Compare commits

..

83 Commits

Author SHA1 Message Date
刘祥超
17af07cce0 edge-node gc命令增加耗时和gc pause时长 2023-12-25 09:24:20 +08:00
刘祥超
cfa57fac66 优化计数器相关测试用例 2023-12-24 16:08:57 +08:00
刘祥超
47523eaa73 优化计数器性能 2023-12-24 15:11:09 +08:00
刘祥超
27a24c6a8a 版本号修改为1.3.2 2023-12-24 11:14:45 +08:00
刘祥超
9bc2b1a651 WAF参数中增加“请求来源” 2023-12-24 10:03:24 +08:00
刘祥超
4f24b7f39c 增加Websocket连接数统计 2023-12-20 11:43:00 +08:00
刘祥超
4607a1f4e7 版本号修改为1.3.1.2 2023-12-18 08:51:22 +08:00
刘祥超
0f2068b161 优化TCP源站错误提示 2023-12-15 18:38:09 +08:00
刘祥超
c039691a71 缓存设置中可以设置缓存主域名,用来复用多域名下的缓存 2023-12-13 18:41:51 +08:00
刘祥超
930ee44065 根据系统环境调整WebP转换线程数 2023-12-12 09:55:18 +08:00
刘祥超
8a9aac7d72 优化代码 2023-12-11 20:35:48 +08:00
刘祥超
e50bbb962d WebP策略变化时只更新相关配置 2023-12-11 11:09:12 +08:00
刘祥超
9ff936d0c1 WebP转换质量转移到WebP策略配置 2023-12-11 10:17:17 +08:00
刘祥超
f53727b09c WebP转换限制为单线程,防止占用系统资源过高 2023-12-11 09:33:04 +08:00
刘祥超
525ce1f923 优化WAF XSS检测,减少对图片内容的误判 2023-12-10 19:40:29 +08:00
刘祥超
16e7cd800c WAF SQL注入检测和XSS注入检测自动进行URL解码 2023-12-10 16:52:54 +08:00
刘祥超
3f34bfc0b0 节点进程停止时,自动保存WAF临时白名单,并在进程重新启动后恢复 2023-12-10 15:41:31 +08:00
刘祥超
548cd1002b 增加WAF相关测试用例 2023-12-10 09:27:29 +08:00
刘祥超
3423865868 优化测试用例 2023-12-10 08:54:39 +08:00
刘祥超
037bc8e0de 优化WAF单词匹配性能 2023-12-09 19:19:29 +08:00
刘祥超
e03292de28 WAF规则模板中XSS注入检测规则使用“包含XSS注入”操作符替代以往的正则表达式 2023-12-09 17:00:21 +08:00
刘祥超
ee2565905e 优化WAF动作“显示网页”显示 2023-12-09 15:55:40 +08:00
刘祥超
05881b457d WAF规则模板中SQL注入规则使用“包含SQL注入”操作符替代以往的正则表达式 2023-12-09 15:28:07 +08:00
刘祥超
b116effc6c WAF SQL注入和XSS检测增加缓存/优化部分WAF相关测试用例 2023-12-09 11:46:50 +08:00
刘祥超
536efeeb9c 提升单词匹配性能 2023-12-09 10:06:07 +08:00
刘祥超
e8638e4bec WAF检查项增加“所有报头名称” 2023-12-08 15:39:23 +08:00
刘祥超
c9db722129 WAF增加“包含XSS注入”操作符 2023-12-08 10:15:18 +08:00
刘祥超
90de472bd5 增加测试用例 2023-12-07 20:47:25 +08:00
刘祥超
50c6c60abf WAF SQL注入检测时支持 (http|https):// 开头的URL 2023-12-07 20:38:06 +08:00
刘祥超
cc10372fe1 WAF增加“包含SQL注入”操作符 2023-12-07 20:25:35 +08:00
刘祥超
05c98a0656 修复一处单词错误 2023-12-07 12:14:04 +08:00
刘祥超
1a790fe391 优化代码 2023-12-07 12:07:06 +08:00
刘祥超
7dbd73cb59 优化WAF中前缀和后缀相关操作符性能 2023-12-07 12:05:08 +08:00
刘祥超
4dfa571547 WAF操作符增加包含任一单词、包含所有单词、不包含任一单词 2023-12-07 11:42:59 +08:00
刘祥超
9f77f62308 WAF checkpoint返回值支持[][]byte 2023-12-05 17:18:53 +08:00
刘祥超
facea1ed96 优化代码 2023-12-05 16:28:10 +08:00
刘祥超
e367814db3 内容压缩级别允许为0 2023-12-05 10:48:17 +08:00
刘祥超
3a15408c98 修复缓存命中率统计测试用例 2023-12-03 14:55:09 +08:00
刘祥超
c504b37118 WAF相关跳转不计入统计 2023-12-03 14:41:11 +08:00
刘祥超
74708dc02f 默认不启用内存分片管理 2023-12-03 14:26:51 +08:00
刘祥超
0c097498bb 优化链表相关代码 2023-12-03 11:27:47 +08:00
刘祥超
981c063eff 优化验证码性能 2023-11-30 17:25:41 +08:00
刘祥超
5e35c50113 页面优化增加例外URL和限制URL 2023-11-30 15:48:50 +08:00
刘祥超
e6c2869ff2 增加“极验-行为验”验证码集成支持 2023-11-29 17:00:06 +08:00
刘祥超
358bec2e9b WAF验证码验证后返回时判断是否已通过验证 2023-11-28 20:39:42 +08:00
刘祥超
1cd644f2eb 优化验证码加载方式,减少不必要的图片生成 2023-11-28 18:07:27 +08:00
刘祥超
f783e5c331 将版本号修改为1.3.1 2023-11-23 17:19:41 +08:00
刘祥超
c39b1c794f 修复清空文件索引Map时产生并发异常 2023-11-23 17:14:50 +08:00
刘祥超
2633d43897 增加最大内存用量 2023-11-22 17:03:42 +08:00
刘祥超
88dca006c4 优化日志 2023-11-22 16:44:06 +08:00
刘祥超
98feb26b79 优化brotli压缩和解压缩性能 2023-11-21 20:18:37 +08:00
刘祥超
ac6683e79d GRPC增加Keepalive参数 2023-11-20 09:56:50 +08:00
刘祥超
99d24afbcd 验证码验证不区分访问路径 2023-11-19 15:34:22 +08:00
刘祥超
ba19a9f4c4 减少一些不必要的访问统计 2023-11-19 09:10:37 +08:00
刘祥超
7fea67a2b5 区域封禁支持观察者模式 2023-11-18 15:02:58 +08:00
刘祥超
ecd2e6955e 当SNI无法读取到ServerName时,尝试使用节点IP搜索网站 2023-11-18 12:08:51 +08:00
刘祥超
09d60a3047 优化内存缓存最大值算法 2023-11-17 19:12:24 +08:00
刘祥超
e24f390412 优化人机识别样式 2023-11-16 08:57:20 +08:00
刘祥超
eeacec1a4e 人机识别增加UA记录 2023-11-16 08:44:07 +08:00
刘祥超
30cd6373c5 修复WAF相关单元测试 2023-11-16 08:43:31 +08:00
刘祥超
87a6ab0559 源站支持404内容自动重试其他源站 2023-11-15 19:06:15 +08:00
刘祥超
59f27215d3 使用泛型优化计数器内存 2023-11-15 15:57:41 +08:00
刘祥超
768384dcf0 优化计数器 2023-11-15 15:17:03 +08:00
刘祥超
3b52ac0fd2 WAF人机识别实现点击验证和滑动解锁验证/单个网站可以设置默认的人机识别方式 2023-11-15 15:10:25 +08:00
刘祥超
41343b2264 版本号修改为1.3.0 2023-11-14 14:47:11 +08:00
刘祥超
d084059f04 缓存索引数据库取消最后访问时间,以提升某些查询速度 2023-11-13 21:43:25 +08:00
刘祥超
9253c44ba5 使用utils.CutPrefix代替strings.CutPrefix 2023-11-13 18:17:32 +08:00
刘祥超
ddec0bf2e0 限制请求域名长度不超过253 2023-11-13 17:20:46 +08:00
刘祥超
aeba1805af 限制统计数据中域名长度 2023-11-13 17:07:55 +08:00
刘祥超
ecff37e080 优化计数器代码 2023-11-13 15:11:11 +08:00
刘祥超
d31dac75be 自定义页面增加例外URL和限制URL设置 2023-11-13 10:46:26 +08:00
刘祥超
4571c84102 自定义页面增加“跳转URL”功能 2023-11-10 16:36:35 +08:00
刘祥超
6a9f59bee0 修复访问节点自定义内容可能无法生效的问题 2023-11-10 11:41:45 +08:00
刘祥超
f1951869f1 URL跳转中增加例外域名和仅限域名 2023-11-10 11:06:24 +08:00
刘祥超
cfd4195c0f 读取缓存时可以使用源站的ETag 2023-11-09 18:20:32 +08:00
刘祥超
d793472b42 调整缓存索引数据库缓存尺寸 2023-11-06 22:10:34 +08:00
刘祥超
1e56247b9c 调整缓存索引数据库缓存尺寸 2023-11-06 20:26:57 +08:00
刘祥超
c34a38857a 增加测试用例 2023-11-06 18:36:11 +08:00
刘祥超
57fa7036dc 修复磁盘占用统计计算错误 2023-11-03 11:51:53 +08:00
刘祥超
b8a3ac750f 上传域名统计时,限制域名长度不能超过64位 2023-11-02 17:23:39 +08:00
刘祥超
9d6692db0c 进一步缩短缓存Key临时缓存时间 2023-11-02 14:14:28 +08:00
刘祥超
ad94327226 实现网络数据包相关统计(商业版本) 2023-10-26 17:18:42 +08:00
刘祥超
aee1ff9609 更新库 2023-10-26 09:53:23 +08:00
109 changed files with 18372 additions and 1371 deletions

View File

@@ -6,10 +6,11 @@ function build() {
VERSION=$(lookup-version "$ROOT"/../internal/const/const.go)
DIST=$ROOT/"../dist/${NAME}"
MUSL_DIR="/usr/local/opt/musl-cross/bin"
SRCDIR=$(realpath "$ROOT/..")
# for macOS users: precompiled gcc can be downloaded from https://github.com/messense/homebrew-macos-cross-toolchains
GCC_X86_64_DIR="/usr/local/gcc/x86_64-unknown-linux-gnu/bin"
GCC_ARM64_DIR="//usr/local/gcc/aarch64-unknown-linux-gnu/bin"
GCC_ARM64_DIR="/usr/local/gcc/aarch64-unknown-linux-gnu/bin"
OS=${1}
ARCH=${2}
@@ -70,6 +71,8 @@ function build() {
CC_PATH=""
CXX_PATH=""
CGO_LDFLAGS=""
CGO_CFLAGS=""
BUILD_TAG=$TAG
if [[ `uname -a` == *"Darwin"* && "${OS}" == "linux" ]]; then
if [ "${ARCH}" == "amd64" ]; then
@@ -79,7 +82,7 @@ function build() {
CC_PATH="x86_64-unknown-linux-gnu-gcc"
CXX_PATH="x86_64-unknown-linux-gnu-g++"
if [ "$TAG" = "plus" ]; then
BUILD_TAG="plus,script"
BUILD_TAG="plus,script,packet"
fi
else
CC_PATH="x86_64-linux-musl-gcc"
@@ -97,7 +100,7 @@ function build() {
CC_PATH="aarch64-unknown-linux-gnu-gcc"
CXX_PATH="aarch64-unknown-linux-gnu-g++"
if [ "$TAG" = "plus" ]; then
BUILD_TAG="plus,script"
BUILD_TAG="plus,script,packet"
fi
else
CC_PATH="aarch64-linux-musl-gcc"
@@ -117,13 +120,26 @@ function build() {
CXX_PATH="mips64el-linux-musl-g++"
fi
fi
# libpcap
if [ "$OS" == "linux" ] && [[ "$ARCH" == "amd64" || "$ARCH" == "arm64" ]] && [ "$TAG" == "plus" ]; then
CGO_LDFLAGS="-L${SRCDIR}/libs/libpcap/${ARCH} -lpcap -L${SRCDIR}/libs/libbrotli/${ARCH} -lbrotlienc -lbrotlidec -lbrotlicommon"
CGO_CFLAGS="-I${SRCDIR}/libs/libpcap/src/libpcap -I${SRCDIR}/libs/libpcap/src/libpcap/pcap -I${SRCDIR}/libs/libbrotli/src/brotli/c/include"
fi
if [ ! -z $CC_PATH ]; then
env CC=$MUSL_DIR/$CC_PATH CXX=$MUSL_DIR/$CXX_PATH GOOS="${OS}" GOARCH="${ARCH}" CGO_ENABLED=1 go build -trimpath -tags $BUILD_TAG -o "$DIST"/bin/${NAME} -ldflags "-linkmode external -extldflags -static -s -w" "$ROOT"/../cmd/edge-node/main.go
env CC=$MUSL_DIR/$CC_PATH \
CXX=$MUSL_DIR/$CXX_PATH GOOS="${OS}" \
GOARCH="${ARCH}" CGO_ENABLED=1 \
CGO_LDFLAGS="${CGO_LDFLAGS}" \
CGO_CFLAGS="${CGO_CFLAGS}" \
go build -trimpath -tags $BUILD_TAG -o "$DIST"/bin/${NAME} -ldflags "-linkmode external -extldflags -static -s -w" "$ROOT"/../cmd/edge-node/main.go
else
if [[ `uname` == *"Linux"* ]] && [ "$OS" == "linux" ] && [[ "$ARCH" == "amd64" || "$ARCH" == "arm64" ]] && [ "$TAG" == "plus" ]; then
BUILD_TAG="plus,script"
BUILD_TAG="plus,script,packet"
fi
env GOOS="${OS}" GOARCH="${ARCH}" CGO_ENABLED=1 go build -trimpath -tags $BUILD_TAG -o "$DIST"/bin/${NAME} -ldflags="-s -w" "$ROOT"/../cmd/edge-node/main.go
env GOOS="${OS}" GOARCH="${ARCH}" CGO_ENABLED=1 CGO_LDFLAGS="${CGO_LDFLAGS}" CGO_CFLAGS="${CGO_CFLAGS}" go build -trimpath -tags $BUILD_TAG -o "$DIST"/bin/${NAME} -ldflags="-s -w" "$ROOT"/../cmd/edge-node/main.go
fi
if [ ! -f "${DIST}/bin/${NAME}" ]; then

View File

@@ -228,11 +228,18 @@ func main() {
})
app.On("gc", func() {
var sock = gosock.NewTmpSock(teaconst.ProcessName)
_, err := sock.Send(&gosock.Command{Code: "gc"})
reply, err := sock.Send(&gosock.Command{Code: "gc"})
if err != nil {
fmt.Println("[ERROR]" + err.Error())
} else {
fmt.Println("ok")
if reply == nil {
fmt.Println("ok")
} else {
var paramMap = maps.NewMap(reply.Params)
var pauseMS = paramMap.GetFloat64("pauseMS")
var costMS = paramMap.GetFloat64("costMS")
fmt.Printf("ok, cost: %.4fms, pause: %.4fms", costMS, pauseMS)
}
}
})
app.On("ip.drop", func() {

29
go.mod
View File

@@ -25,28 +25,27 @@ require (
github.com/iwind/TeaGo v0.0.0-20230630104525-161f0b32996d
github.com/iwind/gofcgi v0.0.0-20210528023741-a92711d45f11
github.com/iwind/gosock v0.0.0-20211103081026-ee4652210ca4
github.com/iwind/gowebp v0.0.0-20230927084601-21954d2e229f
github.com/klauspost/compress v1.16.5
github.com/iwind/gowebp v0.0.0-20231026013903-1c22b0d78cc4
github.com/klauspost/compress v1.17.2
github.com/mattn/go-sqlite3 v1.14.17
github.com/mdlayher/netlink v1.7.1
github.com/miekg/dns v1.1.43
github.com/mssola/useragent v1.0.0
github.com/pires/go-proxyproto v0.6.1
github.com/qiniu/go-sdk/v7 v7.16.0
github.com/quic-go/quic-go v0.39.0
github.com/quic-go/quic-go v0.39.2
github.com/shirou/gopsutil/v3 v3.22.2
github.com/tencentyun/cos-go-sdk-v5 v0.7.41
golang.org/x/image v0.7.0
golang.org/x/image v0.13.0
golang.org/x/net v0.17.0
golang.org/x/sys v0.13.0
google.golang.org/grpc v1.55.0
google.golang.org/protobuf v1.30.0
google.golang.org/grpc v1.59.0
google.golang.org/protobuf v1.31.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/chai2010/webp v1.1.1 // indirect
github.com/clbanning/mxj v1.8.4 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
@@ -54,7 +53,7 @@ require (
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/go-querystring v1.0.0 // indirect
github.com/google/pprof v0.0.0-20230912144702-c363fe2c2ed8 // indirect
github.com/google/pprof v0.0.0-20231023181126-ff6d637d2a7b // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/josharian/native v1.0.0 // indirect
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
@@ -62,10 +61,10 @@ require (
github.com/mdlayher/socket v0.4.0 // indirect
github.com/mitchellh/mapstructure v1.4.3 // indirect
github.com/mozillazg/go-httpheader v0.2.1 // indirect
github.com/onsi/ginkgo/v2 v2.12.1 // indirect
github.com/onsi/ginkgo/v2 v2.13.0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/quic-go/qpack v0.4.0 // indirect
github.com/quic-go/qtls-go1-20 v0.3.4 // indirect
github.com/quic-go/qtls-go1-20 v0.4.1 // indirect
github.com/tdewolff/minify/v2 v2.12.7 // indirect
github.com/tdewolff/parse/v2 v2.6.6 // indirect
github.com/tklauser/go-sysconf v0.3.9 // indirect
@@ -73,11 +72,11 @@ require (
github.com/yusufpapurcu/wmi v1.2.2 // indirect
go.uber.org/mock v0.3.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect
golang.org/x/mod v0.13.0 // indirect
golang.org/x/sync v0.4.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.13.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc // indirect
golang.org/x/tools v0.14.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b // indirect
)

74
go.sum
View File

@@ -15,8 +15,6 @@ github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk=
github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU=
github.com/cheekybits/is v0.0.0-20150225183255-68e9c0620927/go.mod h1:h/aW8ynjgkuj+NQRlZcDbAbM1ORAbXjXX77sX7T289U=
github.com/clbanning/mxj v1.8.4 h1:HuhwZtbyvyOw+3Z1AowPkU87JkJUSv751ELWaiTpj8I=
github.com/clbanning/mxj v1.8.4/go.mod h1:BVjHeAH+rl9rs6f+QIpeRl0tfu10SXn1pUSa5PVGJng=
@@ -55,10 +53,10 @@ github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASu
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/google/pprof v0.0.0-20230912144702-c363fe2c2ed8 h1:gpptm606MZYGaMHMsB4Srmb6EbW/IVHnt04rcMXnkBQ=
github.com/google/pprof v0.0.0-20230912144702-c363fe2c2ed8/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik=
github.com/google/pprof v0.0.0-20231023181126-ff6d637d2a7b h1:RMpPgZTSApbPf7xaVel+QkoGPRLFLrwFO89uDUHEGf0=
github.com/google/pprof v0.0.0-20231023181126-ff6d637d2a7b/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
github.com/huaweicloud/huaweicloud-sdk-go-obs v3.23.4+incompatible h1:XRAk4HBDLCYEdPLWtKf5iZhOi7lfx17aY0oSO9+mcg8=
github.com/huaweicloud/huaweicloud-sdk-go-obs v3.23.4+incompatible/go.mod h1:l7VUhRbTKCzdOacdT4oWCwATKyvZqUOlOqr0Ous3k4s=
github.com/iwind/TeaGo v0.0.0-20230630104525-161f0b32996d h1:XnTIj781NdSipts60fVqbgZorVAKVSRaA6nqVNfBQ1g=
@@ -69,8 +67,8 @@ github.com/iwind/gofcgi v0.0.0-20210528023741-a92711d45f11 h1:DaQjoWZhLNxjhIXedV
github.com/iwind/gofcgi v0.0.0-20210528023741-a92711d45f11/go.mod h1:JtbX20untAjUVjZs1ZBtq80f5rJWvwtQNRL6EnuYRnY=
github.com/iwind/gosock v0.0.0-20211103081026-ee4652210ca4 h1:VWGsCqTzObdlbf7UUE3oceIpcEKi4C/YBUszQXk118A=
github.com/iwind/gosock v0.0.0-20211103081026-ee4652210ca4/go.mod h1:H5Q7SXwbx3a97ecJkaS2sD77gspzE7HFUafBO0peEyA=
github.com/iwind/gowebp v0.0.0-20230927084601-21954d2e229f h1:DCUsOhpZbuKiROTZGc9V9z1uEfm+EbU5nhze+Tv5xo0=
github.com/iwind/gowebp v0.0.0-20230927084601-21954d2e229f/go.mod h1:Re7TEhwL+ygnxFg52fC0PWy01ULAIZp2QR0q5WwEOQA=
github.com/iwind/gowebp v0.0.0-20231026013903-1c22b0d78cc4 h1:eyymORsZg0tZ0niyolYF4nao4sdNUI+Ll40s96tKHBY=
github.com/iwind/gowebp v0.0.0-20231026013903-1c22b0d78cc4/go.mod h1:AYyXDhbbD7q9N6rJff2jrE7pGupaiyvtv3YeyIAQLXk=
github.com/iwind/nftables v0.0.0-20230419014751-9f023a644ad4 h1:RPAH9Sj9l/20zH5zU5/iJGszfwPq6eLjoiC/n/asulA=
github.com/iwind/nftables v0.0.0-20230419014751-9f023a644ad4/go.mod h1:7OLL+86wZKfBnAJxNxmdcZ0ebbgdp/A28fcagx9oJqA=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
@@ -81,8 +79,8 @@ github.com/josharian/native v1.0.0 h1:Ts/E8zCSEsG17dUqv7joXJFybuMLjQfWE04tsBODTx
github.com/josharian/native v1.0.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e h1:LvL4XsI70QxOGHed6yhQtAU34Kx3Qq2wwBzGFKY8zKk=
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
github.com/klauspost/compress v1.16.5 h1:IFV2oUNUzZaz+XyusxpLzpzS8Pt5rh0Z16For/djlyI=
github.com/klauspost/compress v1.16.5/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ixi4=
github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
@@ -111,8 +109,8 @@ github.com/mssola/useragent v1.0.0 h1:WRlDpXyxHDNfvZaPEut5Biveq86Ze4o4EMffyMxmH5
github.com/mssola/useragent v1.0.0/go.mod h1:hz9Cqz4RXusgg1EdI4Al0INR62kP7aPSRNHnpU+b85Y=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo/v2 v2.12.1 h1:uHNEO1RP2SpuZApSkel9nEh1/Mu+hmQe7Q+Pepg5OYA=
github.com/onsi/ginkgo/v2 v2.12.1/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o=
github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4=
github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o=
github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI=
github.com/pires/go-proxyproto v0.6.1 h1:EBupykFmo22SDjv4fQVQd2J9NOoLPmyZA/15ldOGkPw=
github.com/pires/go-proxyproto v0.6.1/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY=
@@ -128,10 +126,10 @@ github.com/qiniu/go-sdk/v7 v7.16.0/go.mod h1:nqoYCNo53ZlGA521RvRethvxUDvXKt4gtYX
github.com/qiniu/x v1.10.5/go.mod h1:03Ni9tj+N2h2aKnAz+6N0Xfl8FwMEDRC2PAlxekASDs=
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
github.com/quic-go/qtls-go1-20 v0.3.4 h1:MfFAPULvst4yoMgY9QmtpYmfij/em7O8UUi+bNVm7Cg=
github.com/quic-go/qtls-go1-20 v0.3.4/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
github.com/quic-go/quic-go v0.39.0 h1:AgP40iThFMY0bj8jGxROhw3S0FMGa8ryqsmi9tBH3So=
github.com/quic-go/quic-go v0.39.0/go.mod h1:T09QsDQWjLiQ74ZmacDfqZmhY/NLnw5BC40MANNNZ1Q=
github.com/quic-go/qtls-go1-20 v0.4.1 h1:D33340mCNDAIKBqXuAvexTNMUByrYmFYVfKfDN5nfFs=
github.com/quic-go/qtls-go1-20 v0.4.1/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
github.com/quic-go/quic-go v0.39.2 h1:hmwAf8zAHlvan0Y5PXxeeBFZEW17IW99sXLry8I2kjk=
github.com/quic-go/quic-go v0.39.2/go.mod h1:T09QsDQWjLiQ74ZmacDfqZmhY/NLnw5BC40MANNNZ1Q=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
@@ -171,36 +169,29 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/image v0.7.0 h1:gzS29xtG1J5ybQlv0PuyfE3nmc6R4qB73m6LUUmvFuw=
golang.org/x/image v0.7.0/go.mod h1:nd/q4ef1AKKYl/4kft7g+6UyGbdiqWqTP1ZAbRoV7Rg=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/image v0.13.0 h1:3cge/F/QTkNLauhf2QoE9zp+7sr+ZcL4HnoZmdwg9sg=
golang.org/x/image v0.13.0/go.mod h1:6mmbMOeV28HuMTgA6OSRkdXKYw/t5W9Uwn2Yv1r3Yxk=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY=
golang.org/x/mod v0.13.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -215,24 +206,18 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
@@ -241,20 +226,19 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.14.0 h1:jvNa2pY0M4r62jkRQ6RwEZZyPcymeL9XZMLBbV7U2nc=
golang.org/x/tools v0.14.0/go.mod h1:uYBEerGOWcJyEORxN+Ek8+TT266gXkNlHdJBwexUsBg=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc h1:XSJ8Vk1SWuNr8S18z1NZSziL0CPIXLCCMDOEFtHBOFc=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA=
google.golang.org/grpc v1.55.0 h1:3Oj82/tFSCeUrRTg/5E/7d/W5A1tj6Ky1ABAuZuv5ag=
google.golang.org/grpc v1.55.0/go.mod h1:iYEXKGkEBhg1PjZQvoYEVPTDkHo1/bjTnfwTeGONTY8=
google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b h1:ZlWIi1wSK56/8hn4QcBp/j9M7Gt3U/3hZw3mC7vDICo=
google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b/go.mod h1:swOH3j0KzcDDgGUWr+SNpyTen5YrXjS3eyPzFYKc6lc=
google.golang.org/grpc v1.59.0 h1:Z5Iec2pjwb+LEOqzpB2MR12/eKFhDPhuqW91O+4bwUk=
google.golang.org/grpc v1.59.0/go.mod h1:aUPDwccQo6OTjy7Hct4AfBPD1GptF4fyUjIkQ9YtF98=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=

View File

@@ -566,7 +566,7 @@ func (this *FileList) UpgradeV3(oldDir string, brokenOnError bool) error {
}
func (this *FileList) maxExpiresAtForMemoryCache(expiresAt int64) int64 {
var maxTimestamp = fasttime.Now().Unix() + 7200
var maxTimestamp = fasttime.Now().Unix() + 3600
if expiresAt > maxTimestamp {
return maxTimestamp
}

View File

@@ -12,7 +12,6 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
"net"
"net/url"
"os"
@@ -48,11 +47,10 @@ type FileListDB struct {
deleteByHashStmt *dbs.Stmt // 根据hash删除数据
deleteByHashSQL string
statStmt *dbs.Stmt // 统计
purgeStmt *dbs.Stmt // 清理
deleteAllStmt *dbs.Stmt // 删除所有数据
listOlderItemsStmt *dbs.Stmt // 读取较早存储的缓存
updateAccessWeekStmt *dbs.Stmt // 修改访问日期
statStmt *dbs.Stmt // 统计
purgeStmt *dbs.Stmt // 清理
deleteAllStmt *dbs.Stmt // 删除所有数据
listOlderItemsStmt *dbs.Stmt // 读取较早存储的缓存
}
func NewFileListDB() *FileListDB {
@@ -65,10 +63,10 @@ func (this *FileListDB) Open(dbPath string) error {
this.dbPath = dbPath
// 动态调整Cache值
var cacheSize = 32000
var cacheSize = 512
var memoryGB = utils.SystemMemoryGB()
if memoryGB >= 8 {
cacheSize += 32000 * memoryGB / 8
if memoryGB >= 1 {
cacheSize = 256 * memoryGB
}
// write db
@@ -136,7 +134,7 @@ func (this *FileListDB) Init() error {
return err
}
this.insertSQL = `INSERT INTO "` + this.itemsTableName + `" ("hash", "key", "headerSize", "bodySize", "metaSize", "expiredAt", "staleAt", "host", "serverId", "createdAt", "accessWeek") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
this.insertSQL = `INSERT INTO "` + this.itemsTableName + `" ("hash", "key", "headerSize", "bodySize", "metaSize", "expiredAt", "staleAt", "host", "serverId", "createdAt") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
this.insertStmt, err = this.writeDB.Prepare(this.insertSQL)
if err != nil {
return err
@@ -173,12 +171,7 @@ func (this *FileListDB) Init() error {
return err
}
this.updateAccessWeekStmt, err = this.writeDB.Prepare(`UPDATE "` + this.itemsTableName + `" SET "accessWeek"=? WHERE "hash"=?`)
if err != nil {
return err
}
this.listOlderItemsStmt, err = this.readDB.Prepare(`SELECT "hash" FROM "` + this.itemsTableName + `" ORDER BY "accessWeek" ASC, "id" ASC LIMIT ?`)
this.listOlderItemsStmt, err = this.readDB.Prepare(`SELECT "hash" FROM "` + this.itemsTableName + `" ORDER BY "id" ASC LIMIT ?`)
if err != nil {
return err
}
@@ -213,7 +206,7 @@ func (this *FileListDB) AddSync(hash string, item *Item) error {
item.StaleAt = item.ExpiredAt
}
_, err := this.insertStmt.Exec(hash, item.Key, item.HeaderSize, item.BodySize, item.MetaSize, item.ExpiredAt, item.StaleAt, item.Host, item.ServerId, fasttime.Now().Unix(), timeutil.Format("YW"))
_, err := this.insertStmt.Exec(hash, item.Key, item.HeaderSize, item.BodySize, item.MetaSize, item.ExpiredAt, item.StaleAt, item.Host, item.ServerId, fasttime.Now().Unix())
if err != nil {
return this.WrapError(err)
}
@@ -309,8 +302,8 @@ func (this *FileListDB) ListHashes(lastId int64) (hashList []string, maxId int64
}
func (this *FileListDB) IncreaseHitAsync(hash string) error {
_, err := this.updateAccessWeekStmt.Exec(timeutil.Format("YW"), hash)
return err
// do nothing
return nil
}
func (this *FileListDB) CleanPrefix(prefix string) error {
@@ -458,9 +451,6 @@ func (this *FileListDB) Close() error {
if this.deleteAllStmt != nil {
_ = this.deleteAllStmt.Close()
}
if this.updateAccessWeekStmt != nil {
_ = this.updateAccessWeekStmt.Close()
}
if this.listOlderItemsStmt != nil {
_ = this.listOlderItemsStmt.Close()
}
@@ -516,8 +506,7 @@ func (this *FileListDB) initTables(times int) error {
"staleAt" integer DEFAULT 0,
"createdAt" integer DEFAULT 0,
"host" varchar(128),
"serverId" integer,
"accessWeek" varchar(6)
"serverId" integer
);
DROP INDEX IF EXISTS "createdAt";
@@ -533,8 +522,6 @@ CREATE INDEX IF NOT EXISTS "hash"
ON "` + this.itemsTableName + `" (
"hash" ASC
);
ALTER TABLE "cacheItems" ADD "accessWeek" varchar(6);
`)
if err != nil {

View File

@@ -4,12 +4,20 @@ package caches_test
import (
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/iwind/TeaGo/Tea"
_ "github.com/iwind/TeaGo/bootstrap"
"runtime"
"runtime/debug"
"testing"
"time"
)
func TestFileListDB_ListLFUItems(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var db = caches.NewFileListDB()
defer func() {
@@ -34,6 +42,10 @@ func TestFileListDB_ListLFUItems(t *testing.T) {
}
func TestFileListDB_CleanMatchKey(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var db = caches.NewFileListDB()
defer func() {
@@ -62,6 +74,10 @@ func TestFileListDB_CleanMatchKey(t *testing.T) {
}
func TestFileListDB_CleanMatchPrefix(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var db = caches.NewFileListDB()
defer func() {
@@ -88,3 +104,67 @@ func TestFileListDB_CleanMatchPrefix(t *testing.T) {
t.Fatal(err)
}
}
func TestFileListDB_Memory(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var db = caches.NewFileListDB()
defer func() {
_ = db.Close()
}()
err := db.Open(Tea.Root + "/data/cache-index/p1/db-0.db")
if err != nil {
t.Fatal(err)
}
err = db.Init()
if err != nil {
t.Fatal(err)
}
t.Log(db.Total())
// load hashes
var maxId int64
var hashList []string
var before = time.Now()
for i := 0; i < 1_000; i++ {
hashList, maxId, err = db.ListHashes(maxId)
if err != nil {
t.Fatal(err)
}
if len(hashList) == 0 {
t.Log("hashes loaded", time.Since(before).Seconds()*1000, "ms")
break
}
if i%100 == 0 {
t.Log(i)
}
}
runtime.GC()
debug.FreeOSMemory()
//time.Sleep(600 * time.Second)
for i := 0; i < 1_000; i++ {
_, err = db.ListLFUItems(5000)
if err != nil {
t.Fatal(err)
}
if i%100 == 0 {
t.Log(i)
}
}
t.Log("loaded")
runtime.GC()
debug.FreeOSMemory()
time.Sleep(600 * time.Second)
}

View File

@@ -133,7 +133,10 @@ func (this *FileListHashMap) Clean() {
this.lockers[i].Lock()
}
this.m = make([]map[uint64]zero.Zero, HashMapSharding)
// 这里不能简单清空 this.m ,避免导致别的数据无法写入 map 而产生 panic
for i := 0; i < HashMapSharding; i++ {
this.m[i] = map[uint64]zero.Zero{}
}
for i := HashMapSharding - 1; i >= 0; i-- {
this.lockers[i].Unlock()

View File

@@ -6,6 +6,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/zero"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
@@ -112,6 +113,25 @@ func TestFileListHashMap_Load(t *testing.T) {
}
}
func TestFileListHashMap_Delete(t *testing.T) {
var a = assert.NewAssertion(t)
var m = caches.NewFileListHashMap()
m.SetIsReady(true)
m.SetIsAvailable(true)
m.Add("a")
a.IsTrue(m.Len() == 1)
m.Delete("a")
a.IsTrue(m.Len() == 0)
}
func TestFileListHashMap_Clean(t *testing.T) {
var m = caches.NewFileListHashMap()
m.SetIsAvailable(true)
m.Clean()
m.Add("a")
}
func Benchmark_BigInt(b *testing.B) {
var hash = stringutil.Md5("123456")
b.ResetTimer()

View File

@@ -7,6 +7,7 @@ import (
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
"golang.org/x/sys/unix"
@@ -35,6 +36,9 @@ type Manager struct {
SubDiskDirs []*serverconfigs.CacheDir
MaxMemoryCapacity *shared.SizeCapacity
CountFileStorages int
CountMemoryStorages int
policyMap map[int64]*serverconfigs.HTTPCachePolicy // policyId => []*Policy
storageMap map[int64]StorageInterface // policyId => *Storage
locker sync.RWMutex
@@ -143,6 +147,16 @@ func (this *Manager) UpdatePolicies(newPolicies []*serverconfigs.HTTPCachePolicy
}
}
}
this.CountFileStorages = 0
this.CountFileStorages = 0
for _, storage := range this.storageMap {
_, isFileStorage := storage.(*FileStorage)
this.CountMemoryStorages++
if isFileStorage {
this.CountFileStorages++
}
}
}
// FindPolicy 获取Policy信息
@@ -172,6 +186,11 @@ func (this *Manager) NewStorageWithPolicy(policy *serverconfigs.HTTPCachePolicy)
return nil
}
// StorageMap 获取已有的存储对象
func (this *Manager) StorageMap() map[int64]StorageInterface {
return this.storageMap
}
// TotalDiskSize 消耗的磁盘尺寸
func (this *Manager) TotalDiskSize() int64 {
this.locker.RLock()
@@ -272,3 +291,17 @@ func (this *Manager) ScanGarbageCaches(callback func(path string) error) error {
}
return nil
}
// MaxSystemMemoryBytesPerStorage 计算单个策略能使用的系统最大内存
func (this *Manager) MaxSystemMemoryBytesPerStorage() int64 {
var count = this.CountMemoryStorages
if count < 1 {
count = 1
}
var resultBytes = int64(utils.SystemMemoryBytes()) / 3 / int64(count) // 1/3 of the system memory
if resultBytes < 1<<30 {
resultBytes = 1 << 30
}
return resultBytes
}

View File

@@ -1,8 +1,9 @@
package caches
package caches_test
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/iwind/TeaGo/Tea"
"testing"
)
@@ -10,7 +11,7 @@ import (
func TestManager_UpdatePolicies(t *testing.T) {
{
var policies = []*serverconfigs.HTTPCachePolicy{}
SharedManager.UpdatePolicies(policies)
caches.SharedManager.UpdatePolicies(policies)
printManager(t)
}
@@ -38,7 +39,7 @@ func TestManager_UpdatePolicies(t *testing.T) {
},
},
}
SharedManager.UpdatePolicies(policies)
caches.SharedManager.UpdatePolicies(policies)
printManager(t)
}
@@ -66,7 +67,7 @@ func TestManager_UpdatePolicies(t *testing.T) {
},
},
}
SharedManager.UpdatePolicies(policies)
caches.SharedManager.UpdatePolicies(policies)
printManager(t)
}
}
@@ -80,8 +81,8 @@ func TestManager_ChangePolicy_Memory(t *testing.T) {
Capacity: &shared.SizeCapacity{Count: 1, Unit: shared.SizeCapacityUnitGB},
},
}
SharedManager.UpdatePolicies(policies)
SharedManager.UpdatePolicies([]*serverconfigs.HTTPCachePolicy{
caches.SharedManager.UpdatePolicies(policies)
caches.SharedManager.UpdatePolicies([]*serverconfigs.HTTPCachePolicy{
{
Id: 1,
Type: serverconfigs.CachePolicyStorageMemory,
@@ -102,8 +103,8 @@ func TestManager_ChangePolicy_File(t *testing.T) {
Capacity: &shared.SizeCapacity{Count: 1, Unit: shared.SizeCapacityUnitGB},
},
}
SharedManager.UpdatePolicies(policies)
SharedManager.UpdatePolicies([]*serverconfigs.HTTPCachePolicy{
caches.SharedManager.UpdatePolicies(policies)
caches.SharedManager.UpdatePolicies([]*serverconfigs.HTTPCachePolicy{
{
Id: 1,
Type: serverconfigs.CachePolicyStorageFile,
@@ -115,10 +116,17 @@ func TestManager_ChangePolicy_File(t *testing.T) {
})
}
func TestManager_MaxSystemMemoryBytesPerStorage(t *testing.T) {
for i := 0; i < 100; i++ {
caches.SharedManager.CountMemoryStorages = i
t.Log(i, caches.SharedManager.MaxSystemMemoryBytesPerStorage()>>30, "GB")
}
}
func printManager(t *testing.T) {
t.Log("===manager==")
t.Log("storage:")
for _, storage := range SharedManager.storageMap {
for _, storage := range caches.SharedManager.StorageMap() {
t.Log(" storage:", storage.Policy().Id)
}
t.Log("===============")

View File

@@ -15,6 +15,7 @@ import (
)
const (
enableFragmentPool = false
minMemoryFragmentPoolItemSize = 8 << 10
maxMemoryFragmentPoolItemSize = 128 << 20
maxItemsInMemoryFragmentPoolBucket = 1024

View File

@@ -517,7 +517,7 @@ func (this *MemoryStorage) flushItem(key string) {
_ = this.Delete(key)
// 重用内存,前提是确保内存不再被引用
if ok && item.IsDone && !item.isReferring && len(item.BodyValue) > 0 {
if enableFragmentPool && ok && item.IsDone && !item.isReferring && len(item.BodyValue) > 0 {
SharedFragmentMemoryPool.Put(item.BodyValue)
}
}()
@@ -584,8 +584,7 @@ func (this *MemoryStorage) flushItem(key string) {
}
func (this *MemoryStorage) memoryCapacityBytes() int64 {
var maxSystemBytes = int64(utils.SystemMemoryBytes()) / 3 // 1/3 of the system memory
var maxSystemBytes = SharedManager.MaxSystemMemoryBytesPerStorage()
if this.policy == nil {
return maxSystemBytes
}
@@ -612,7 +611,6 @@ func (this *MemoryStorage) memoryCapacityBytes() int64 {
}
}
// 1/4 of the system memory
return maxSystemBytes
}

View File

@@ -32,7 +32,9 @@ func NewMemoryWriter(memoryStorage *MemoryStorage, key string, expiredAt int64,
ModifiedAt: fasttime.Now().Unix(),
Status: status,
}
if expectedBodySize > 0 && expectedBodySize <= maxMemoryFragmentPoolItemSize {
if enableFragmentPool &&
expectedBodySize > 0 &&
expectedBodySize <= maxMemoryFragmentPoolItemSize {
bodyBytes, ok := SharedFragmentMemoryPool.Get(expectedBodySize) // try to reuse memory
if ok {
valueItem.BodyValue = bodyBytes
@@ -168,7 +170,8 @@ func (this *MemoryWriter) Discard() error {
this.storage.locker.Lock()
delete(this.storage.valuesMap, this.hash)
if this.item != nil &&
if enableFragmentPool &&
this.item != nil &&
!this.item.isReferring &&
cap(this.item.BodyValue) >= minMemoryFragmentPoolItemSize {
SharedFragmentMemoryPool.Put(this.item.BodyValue)

View File

@@ -1,4 +1,5 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !plus || !linux
package compressions

View File

@@ -1,4 +1,5 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !plus || !linux
package compressions
@@ -27,7 +28,7 @@ func newBrotliWriter(writer io.Writer, level int) (*BrotliWriter, error) {
return &BrotliWriter{
writer: brotli.NewWriterOptions(writer, brotli.WriterOptions{
Quality: level,
LGWin: 13, // TODO 在全局设置里可以设置此值
LGWin: 14, // TODO 在全局设置里可以设置此值
}),
level: level,
}, nil

View File

@@ -19,6 +19,10 @@ func NewZSTDWriter(writer io.Writer, level int) (Writer, error) {
}
func newZSTDWriter(writer io.Writer, level int) (Writer, error) {
if level < 0 {
level = 0
}
var zstdLevel = zstd.EncoderLevelFromZstd(level)
zstdWriter, err := zstd.NewWriter(writer, zstd.WithEncoderLevel(zstdLevel))

View File

@@ -9,6 +9,24 @@ import (
"testing"
)
func TestNewZSTDWriter_Level0(t *testing.T) {
var buf = &bytes.Buffer{}
writer, err := compressions.NewZSTDWriter(buf, 0)
if err != nil {
t.Fatal(err)
}
var originData = []byte(strings.Repeat("Hello", 1024))
_, err = writer.Write(originData)
if err != nil {
t.Fatal(err)
}
err = writer.Close()
if err != nil {
t.Fatal(err)
}
t.Log("origin data:", len(originData), "result:", buf.Len())
}
func TestNewZSTDWriter(t *testing.T) {
var buf = &bytes.Buffer{}
writer, err := compressions.NewZSTDWriter(buf, 10)

View File

@@ -1,7 +1,7 @@
package teaconst
const (
Version = "1.2.10"
Version = "1.3.2"
ProductName = "Edge Node"
ProcessName = "edge-node"

View File

@@ -90,13 +90,13 @@ func (this *APIStream) loop() error {
break
}
message, err := nodeStream.Recv()
if err != nil {
message, streamErr := nodeStream.Recv()
if streamErr != nil {
if this.isQuiting {
remotelogs.Println("API_STREAM", "quit")
return nil
}
return err
return streamErr
}
// 处理消息

View File

@@ -305,7 +305,7 @@ func (this *ClientConn) increaseSYNFlood(synFloodConfig *firewallconfigs.SYNFloo
// 非TLS设置为两倍防止误封
minAttempts = 2 * minAttempts
}
if result >= types.Uint64(minAttempts) {
if result >= types.Uint32(minAttempts) {
var timeout = synFloodConfig.TimeoutSeconds
if timeout <= 0 {
timeout = 600

View File

@@ -85,6 +85,8 @@ type HTTPRequest struct {
isAttack bool // 是否是攻击请求
requestBodyData []byte // 读取的Body内容
isWebsocketResponse bool // 是否为Websocket响应非请求
// WAF相关
firewallPolicyId int64
firewallRuleGroupId int64
@@ -410,6 +412,8 @@ func (this *HTTPRequest) doEnd() {
var countAttacks int64 = 0
var attackBytes int64 = 0
var countWebsocketConnections int64 = 0
if this.isCached {
countCached = 1
cachedBytes = totalBytes
@@ -421,8 +425,11 @@ func (this *HTTPRequest) doEnd() {
attackBytes = totalBytes
}
}
if this.isWebsocketResponse {
countWebsocketConnections = 1
}
stats.SharedTrafficStatManager.Add(this.ReqServer.UserId, this.ReqServer.Id, this.ReqHost, totalBytes, cachedBytes, 1, countCached, countAttacks, attackBytes, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId())
stats.SharedTrafficStatManager.Add(this.ReqServer.UserId, this.ReqServer.Id, this.ReqHost, totalBytes, cachedBytes, 1, countCached, countAttacks, attackBytes, countWebsocketConnections, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId())
// 指标
if metrics.SharedManager.HasHTTPMetrics() {

View File

@@ -3,6 +3,7 @@ package nodes
import (
"bytes"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/compressions"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
@@ -130,7 +131,22 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
var tags = []string{}
// 检查是否有缓存
var key = this.Format(this.cacheRef.Key)
var key string
if this.web.Cache.Key != nil && this.web.Cache.Key.IsOn && len(this.web.Cache.Key.Host) > 0 {
key = configutils.ParseVariables(this.cacheRef.Key, func(varName string) (value string) {
switch varName {
case "scheme":
return this.web.Cache.Key.Scheme
case "host":
return this.web.Cache.Key.Host
default:
return this.Format("${" + varName + "}")
}
})
} else {
key = this.Format(this.cacheRef.Key)
}
if len(key) == 0 {
this.cacheRef = nil
cacheBypassDescription = "BYPASS, empty key"
@@ -274,7 +290,13 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
}
if err != nil {
if err == caches.ErrNotFound {
if errors.Is(err, caches.ErrNotFound) {
// 移除请求中的 If-None-Match 和 If-Modified-Since防止源站返回304而无法缓存
if this.reverseProxy != nil {
this.RawReq.Header.Del("If-None-Match")
this.RawReq.Header.Del("If-Modified-Since")
}
// cache相关变量
this.varMapping["cache.status"] = "MISS"
@@ -365,24 +387,24 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
}
// ETag
// 这里强制设置ETag如果先前源站设置了ETag将会被覆盖避免因为源站的ETag导致源站返回304 Not Modified
var respHeader = this.writer.Header()
var eTag = ""
var eTag = respHeader.Get("ETag")
var lastModifiedAt = reader.LastModified()
if lastModifiedAt > 0 {
if len(tags) > 0 {
eTag = "\"" + strconv.FormatInt(lastModifiedAt, 10) + "_" + strings.Join(tags, "_") + "\""
} else {
eTag = "\"" + strconv.FormatInt(lastModifiedAt, 10) + "\""
}
respHeader.Del("Etag")
if !isPartialCache {
respHeader["ETag"] = []string{eTag}
if len(eTag) == 0 {
if lastModifiedAt > 0 {
if len(tags) > 0 {
eTag = "\"" + strconv.FormatInt(lastModifiedAt, 10) + "_" + strings.Join(tags, "_") + "\""
} else {
eTag = "\"" + strconv.FormatInt(lastModifiedAt, 10) + "\""
}
respHeader.Del("Etag")
if !isPartialCache {
respHeader["ETag"] = []string{eTag}
}
}
}
// 支持 Last-Modified
// 这里强制设置Last-Modified如果先前源站设置了Last-Modified将会被覆盖避免因为源站的Last-Modified导致源站返回304 Not Modified
var modifiedTime = ""
if lastModifiedAt > 0 {
modifiedTime = time.Unix(utils.GMTUnixTime(lastModifiedAt), 0).Format("Mon, 02 Jan 2006 15:04:05") + " GMT"
@@ -490,7 +512,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
if err != nil {
this.varMapping["cache.status"] = "MISS"
if err == caches.ErrInvalidRange {
if errors.Is(err, caches.ErrInvalidRange) {
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true

View File

@@ -25,6 +25,13 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
continue
}
if len(u.ExceptDomains) > 0 && configutils.MatchDomains(u.ExceptDomains, this.ReqHost) {
continue
}
if len(u.OnlyDomains) > 0 && !configutils.MatchDomains(u.OnlyDomains, this.ReqHost) {
continue
}
var status = u.Status
if status <= 0 {
if searchEngineRegex.MatchString(this.RawReq.UserAgent()) {

View File

@@ -2,7 +2,6 @@ package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/Tea"
@@ -46,9 +45,13 @@ func (this *HTTPRequest) doPage(status int) (shouldStop bool) {
}
func (this *HTTPRequest) doPageLookup(pages []*serverconfigs.HTTPPageConfig, status int) (shouldStop bool) {
var url = this.URL()
for _, page := range pages {
if !page.MatchURL(url) {
continue
}
if page.Match(status) {
if len(page.BodyType) == 0 || page.BodyType == shared.BodyTypeURL {
if len(page.BodyType) == 0 || page.BodyType == serverconfigs.HTTPPageBodyTypeURL {
if urlSchemeRegexp.MatchString(page.URL) {
var newStatus = page.NewStatus
if newStatus <= 0 {
@@ -115,7 +118,7 @@ func (this *HTTPRequest) doPageLookup(pages []*serverconfigs.HTTPPageConfig, sta
}
return true
} else if page.BodyType == shared.BodyTypeHTML {
} else if page.BodyType == serverconfigs.HTTPPageBodyTypeHTML {
// 这里需要实现设置Status因为在Format()中可以获取${status}等变量
if page.NewStatus > 0 {
this.writer.statusCode = page.NewStatus
@@ -147,6 +150,18 @@ func (this *HTTPRequest) doPageLookup(pages []*serverconfigs.HTTPPageConfig, sta
this.writer.SetOk()
}
return true
} else if page.BodyType == serverconfigs.HTTPPageBodyTypeRedirectURL {
var newURL = page.URL
if len(newURL) == 0 {
newURL = "/"
}
if page.NewStatus > 0 && httpStatusIsRedirect(page.NewStatus) {
httpRedirect(this.writer, this.RawReq, newURL, page.NewStatus)
} else {
httpRedirect(this.writer, this.RawReq, newURL, http.StatusTemporaryRedirect)
}
this.writer.SetOk()
return true
}
}
}

View File

@@ -27,9 +27,10 @@ func (this *HTTPRequest) doReverseProxy() {
var failedOriginIds []int64
var failedLnNodeIds []int64
var failStatusCode int
for i := 0; i < retries; i++ {
originId, lnNodeId, shouldRetry := this.doOriginRequest(failedOriginIds, failedLnNodeIds, i == 0, i == retries-1)
originId, lnNodeId, shouldRetry := this.doOriginRequest(failedOriginIds, failedLnNodeIds, i == 0, i == retries-1, &failStatusCode)
if !shouldRetry {
break
}
@@ -43,7 +44,7 @@ func (this *HTTPRequest) doReverseProxy() {
}
// 请求源站
func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeIds []int64, isFirstTry bool, isLastRetry bool) (originId int64, lnNodeId int64, shouldRetry bool) {
func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeIds []int64, isFirstTry bool, isLastRetry bool, failStatusCode *int) (originId int64, lnNodeId int64, shouldRetry bool) {
// 对URL的处理
var stripPrefix = this.reverseProxy.StripPrefix
var requestURI = this.reverseProxy.RequestURI
@@ -91,6 +92,10 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
}
if origin == nil {
origin = this.reverseProxy.NextOrigin(requestCall)
if origin != nil && origin.Id > 0 && (*failStatusCode >= 403 && *failStatusCode <= 404) && lists.ContainsInt64(failedOriginIds, origin.Id) {
this.writeCode(*failStatusCode, "", "")
return
}
}
requestCall.CallResponseCallbacks(this.writer)
if origin == nil {
@@ -376,11 +381,11 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
return
}
// 50x
// 40x && 50x
*failStatusCode = resp.StatusCode
if resp != nil &&
resp.StatusCode >= 500 &&
resp.StatusCode < 510 &&
this.reverseProxy.Retry50X &&
((resp.StatusCode >= 500 && resp.StatusCode < 510 && this.reverseProxy.Retry50X) ||
(resp.StatusCode >= 403 && resp.StatusCode <= 404 && this.reverseProxy.Retry40X)) &&
(originId > 0 || (lnNodeId > 0 && hasMultipleLnNodes)) &&
!isLastRetry {
if resp.Body != nil {
@@ -429,7 +434,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
// Page optimization
if this.web.Optimization != nil && resp.Body != nil && this.cacheRef != nil /** must under cache **/ {
err := this.web.Optimization.FilterResponse(resp)
err := this.web.Optimization.FilterResponse(this.URL(), resp)
if err != nil {
this.write50x(err, http.StatusBadGateway, "Page Optimization: Fail to read content from origin", "内容优化:从源站读取内容失败", false)
return

View File

@@ -1,7 +1,7 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/Tea"
@@ -18,7 +18,7 @@ func (this *HTTPRequest) doShutdown() {
return
}
if len(shutdown.BodyType) == 0 || shutdown.BodyType == shared.BodyTypeURL {
if len(shutdown.BodyType) == 0 || shutdown.BodyType == serverconfigs.HTTPPageBodyTypeURL {
// URL
if urlSchemeRegexp.MatchString(shutdown.URL) {
this.doURL(http.MethodGet, shutdown.URL, "", shutdown.Status, true)
@@ -80,7 +80,7 @@ func (this *HTTPRequest) doShutdown() {
} else {
this.writer.SetOk()
}
} else if shutdown.BodyType == shared.BodyTypeHTML {
} else if shutdown.BodyType == serverconfigs.HTTPPageBodyTypeHTML {
// 自定义响应Headers
if shutdown.Status > 0 {
this.ProcessResponseHeaders(this.writer.Header(), shutdown.Status)
@@ -98,5 +98,17 @@ func (this *HTTPRequest) doShutdown() {
} else {
this.writer.SetOk()
}
} else if shutdown.BodyType == serverconfigs.HTTPPageBodyTypeRedirectURL {
var newURL = shutdown.URL
if len(newURL) == 0 {
newURL = "/"
}
if shutdown.Status > 0 && httpStatusIsRedirect(shutdown.Status) {
httpRedirect(this.writer, this.RawReq, newURL, shutdown.Status)
} else {
httpRedirect(this.writer, this.RawReq, newURL, http.StatusTemporaryRedirect)
}
this.writer.SetOk()
}
}

View File

@@ -96,6 +96,8 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
return
}
var isDefendMode = firewallPolicy.Mode == firewallconfigs.FirewallModeDefend
// 检查IP白名单
var remoteAddrs []string
if len(this.remoteAddr) > 0 {
@@ -122,7 +124,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
}
// 检查IP黑名单
if firewallPolicy.Mode == firewallconfigs.FirewallModeDefend {
if isDefendMode {
for _, ref := range inbound.AllDenyListRefs() {
if ref.IsOn && ref.ListId > 0 {
list := iplibrary.SharedIPListManager.FindList(ref.ListId)
@@ -161,19 +163,20 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
}
// 检查地区封禁
if firewallPolicy.Mode == firewallconfigs.FirewallModeDefend {
if firewallPolicy.Inbound.Region != nil && firewallPolicy.Inbound.Region.IsOn {
var regionConfig = firewallPolicy.Inbound.Region
if regionConfig.IsNotEmpty() {
for _, remoteAddr := range remoteAddrs {
var result = iplib.LookupIP(remoteAddr)
if result != nil && result.IsOk() {
var currentURL = this.URL()
if regionConfig.MatchCountryURL(currentURL) {
// 检查国家/地区级别封禁
if !regionConfig.IsAllowedCountry(result.CountryId(), result.ProvinceId()) {
this.firewallPolicyId = firewallPolicy.Id
if firewallPolicy.Inbound.Region != nil && firewallPolicy.Inbound.Region.IsOn {
var regionConfig = firewallPolicy.Inbound.Region
if regionConfig.IsNotEmpty() {
for _, remoteAddr := range remoteAddrs {
var result = iplib.LookupIP(remoteAddr)
if result != nil && result.IsOk() {
var currentURL = this.URL()
if regionConfig.MatchCountryURL(currentURL) {
// 检查国家/地区级别封禁
if !regionConfig.IsAllowedCountry(result.CountryId(), result.ProvinceId()) {
this.firewallPolicyId = firewallPolicy.Id
if isDefendMode {
var promptHTML string
if len(regionConfig.CountryHTML) > 0 {
promptHTML = regionConfig.CountryHTML
@@ -193,23 +196,27 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
// 延时返回,避免攻击
time.Sleep(1 * time.Second)
}
// 停止日志
if !logDenying {
this.disableLog = true
} else {
this.tags = append(this.tags, "denyCountry")
}
// 停止日志
if !logDenying {
this.disableLog = true
} else {
this.tags = append(this.tags, "denyCountry")
}
if isDefendMode {
return true, false
}
}
}
if regionConfig.MatchProvinceURL(currentURL) {
// 检查省份封禁
if !regionConfig.IsAllowedProvince(result.CountryId(), result.ProvinceId()) {
this.firewallPolicyId = firewallPolicy.Id
if regionConfig.MatchProvinceURL(currentURL) {
// 检查省份封禁
if !regionConfig.IsAllowedProvince(result.CountryId(), result.ProvinceId()) {
this.firewallPolicyId = firewallPolicy.Id
if isDefendMode {
var promptHTML string
if len(regionConfig.ProvinceHTML) > 0 {
promptHTML = regionConfig.ProvinceHTML
@@ -229,14 +236,16 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
// 延时返回,避免攻击
time.Sleep(1 * time.Second)
}
// 停止日志
if !logDenying {
this.disableLog = true
} else {
this.tags = append(this.tags, "denyProvince")
}
// 停止日志
if !logDenying {
this.disableLog = true
} else {
this.tags = append(this.tags, "denyProvince")
}
if isDefendMode {
return true, false
}
}
@@ -257,7 +266,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
return
}
goNext, hasRequestBody, ruleGroup, ruleSet, err := w.MatchRequest(this, this.writer)
goNext, hasRequestBody, ruleGroup, ruleSet, err := w.MatchRequest(this, this.writer, this.web.FirewallRef.DefaultCaptchaType)
if forceLog && logRequestBody && hasRequestBody && ruleSet != nil && ruleSet.HasAttackActions() {
this.wafHasRequestBody = true
}
@@ -307,7 +316,7 @@ func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) {
}
if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
blocked := this.checkWAFResponse(this.web.FirewallPolicy, resp, forceLog, forceLogRequestBody, false)
blocked = this.checkWAFResponse(this.web.FirewallPolicy, resp, forceLog, forceLogRequestBody, false)
if blocked {
return true
}
@@ -315,7 +324,7 @@ func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) {
// 公用的防火墙设置
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
blocked := this.checkWAFResponse(this.ReqServer.HTTPFirewallPolicy, resp, forceLog, forceLogRequestBody, this.web.FirewallRef.IgnoreGlobalRules)
blocked = this.checkWAFResponse(this.ReqServer.HTTPFirewallPolicy, resp, forceLog, forceLogRequestBody, this.web.FirewallRef.IgnoreGlobalRules)
if blocked {
return true
}
@@ -469,3 +478,10 @@ func (this *HTTPRequest) WAFMaxRequestSize() int64 {
func (this *HTTPRequest) DisableAccessLog() {
this.disableLog = true
}
// DisableStat 停用统计
func (this *HTTPRequest) DisableStat() {
if this.web != nil {
this.web.StatRef = nil
}
}

View File

@@ -61,6 +61,9 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
}
}
// 标记
this.isWebsocketResponse = true
// 设置指定的来源域
if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
var newRequestOrigin = this.web.Websocket.RequestOrigin
@@ -77,7 +80,6 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
}
// 连接源站
// TODO 增加N次错误重试重试的时候需要尝试不同的源站
originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost)
if err != nil {
if isLastRetry {

View File

@@ -11,7 +11,6 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/compressions"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
@@ -34,22 +33,19 @@ import (
"net/textproto"
"os"
"path/filepath"
"runtime"
"strings"
"sync/atomic"
)
var webpMaxBufferSize int64 = 1_000_000_000
var webpTotalBufferSize int64 = 0
var webpIgnoreURLSet = setutils.NewFixedSet(131072)
var webPThreads int32
var webPMaxThreads int32 = 1
var webPIgnoreURLSet = setutils.NewFixedSet(131072)
func init() {
if !teaconst.IsMain {
return
}
var systemMemory = utils.SystemMemoryGB() / 8
if systemMemory > 0 {
webpMaxBufferSize = int64(systemMemory) << 30
webPMaxThreads = int32(runtime.NumCPU() / 4)
if webPMaxThreads < 1 {
webPMaxThreads = 1
}
}
@@ -80,6 +76,7 @@ type HTTPWriter struct {
// WebP
webpIsEncoding bool
webpOriginContentType string
webpQuality int
// Compression
compressionConfig *serverconfigs.HTTPCompressionConfig
@@ -483,8 +480,8 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
contentTypeWritten = true
}
err := cacheWriter.WriteAt(start, data)
if err != nil {
writeErr := cacheWriter.WriteAt(start, data)
if writeErr != nil {
hasError = true
this.cacheIsFinished = false
}
@@ -531,6 +528,7 @@ func (this *HTTPWriter) PrepareWebP(resp *http.Response, size int64) {
if policy.RequireCache && this.req.cacheRef == nil {
return
}
this.webpQuality = policy.Quality
// 限制最小和最大尺寸
// TODO 需要将reader修改为LimitReader
@@ -550,7 +548,7 @@ func (this *HTTPWriter) PrepareWebP(resp *http.Response, size int64) {
this.req.web.WebP.MatchResponse(contentType, size, filepath.Ext(this.req.Path()), this.req.Format) &&
this.req.web.WebP.MatchAccept(this.req.requestHeader("Accept")) {
// 检查是否已经因为尺寸过大而忽略
if webpIgnoreURLSet.Has(this.req.URL()) {
if webPIgnoreURLSet.Has(this.req.URL()) {
return
}
@@ -560,8 +558,8 @@ func (this *HTTPWriter) PrepareWebP(resp *http.Response, size int64) {
return
}
// 检查内存
if atomic.LoadInt64(&webpTotalBufferSize) >= webpMaxBufferSize {
// 检查当前是否正在转换
if atomic.LoadInt32(&webPThreads) >= webPMaxThreads {
return
}
@@ -622,7 +620,7 @@ func (this *HTTPWriter) PrepareCompression(resp *http.Response, size int64) {
return
}
if this.compressionConfig.Level <= 0 {
if this.compressionConfig.Level < 0 {
return
}
@@ -1020,6 +1018,11 @@ func (this *HTTPWriter) calculateStaleLife() int {
func (this *HTTPWriter) finishWebP() {
// 处理WebP
if this.webpIsEncoding {
atomic.AddInt32(&webPThreads, 1)
defer func() {
atomic.AddInt32(&webPThreads, -1)
}()
var webpCacheWriter caches.Writer
// 准备WebP Cache
@@ -1080,7 +1083,7 @@ func (this *HTTPWriter) finishWebP() {
if isGif {
gifImage, err = gif.DecodeAll(reader)
if gifImage != nil && (gifImage.Config.Width > gowebp.WebPMaxDimension || gifImage.Config.Height > gowebp.WebPMaxDimension) {
webpIgnoreURLSet.Push(this.req.URL())
webPIgnoreURLSet.Push(this.req.URL())
return
}
} else {
@@ -1088,7 +1091,7 @@ func (this *HTTPWriter) finishWebP() {
if imageData != nil {
var bound = imageData.Bounds()
if bound.Max.X > gowebp.WebPMaxDimension || bound.Max.Y > gowebp.WebPMaxDimension {
webpIgnoreURLSet.Push(this.req.URL())
webPIgnoreURLSet.Push(this.req.URL())
return
}
}
@@ -1096,19 +1099,21 @@ func (this *HTTPWriter) finishWebP() {
if err != nil {
// 发生了错误终止处理
webpIgnoreURLSet.Push(this.req.URL())
webPIgnoreURLSet.Push(this.req.URL())
return
}
var totalBytes = reader.TotalBytes()
atomic.AddInt64(&webpTotalBufferSize, totalBytes)
defer func() {
atomic.AddInt64(&webpTotalBufferSize, -totalBytes)
}()
var f = types.Float32(this.req.web.WebP.Quality)
if f > 100 {
f = 100
var f = types.Float32(this.webpQuality)
if f <= 0 || f > 100 {
if this.size > (8<<20) || this.size <= 0 {
f = 30
} else if this.size > (1 << 20) {
f = 50
} else if this.size > (128 << 10) {
f = 60
} else {
f = 75
}
}
if imageData != nil {

View File

@@ -46,7 +46,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
}
}
tlsPolicy, _, err := this.matchSSL(this.helloServerName(clientInfo))
tlsPolicy, _, err := this.matchSSL(this.helloServerNames(clientInfo))
if err != nil {
return nil, err
}
@@ -69,7 +69,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
}
}
tlsPolicy, cert, err := this.matchSSL(this.helloServerName(clientInfo))
tlsPolicy, cert, err := this.matchSSL(this.helloServerNames(clientInfo))
if err != nil {
return nil, err
}
@@ -85,7 +85,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
}
// 根据域名匹配证书
func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) {
func (this *BaseListener) matchSSL(domains []string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) {
var group = this.Group
if group == nil {
@@ -99,7 +99,7 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
// 如果域名为空,则取第一个
// 通常域名为空是因为是直接通过IP访问的
if len(domain) == 0 {
if len(domains) == 0 {
if group.IsHTTPS() && globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly {
return nil, nil, errors.New("no tls server name matched")
}
@@ -116,9 +116,25 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
}
return nil, nil, errors.New("no tls server name found")
}
var firstDomain = domains[0]
// 通过网站域名配置匹配
server, _ := this.findNamedServer(domain)
var server *serverconfigs.ServerConfig
var matchedDomain string
for _, domain := range domains {
server, _ = this.findNamedServer(domain, true)
if server != nil {
matchedDomain = domain
break
}
}
if server == nil {
server, _ = this.findNamedServer(firstDomain, false)
if server != nil {
matchedDomain = firstDomain
}
}
if server == nil {
// 找不到或者此时的服务没有配置证书需要搜索所有的Server通过SSL证书内容中的DNSName匹配
// 此功能仅为了兼容以往版本v1.0.4),不应该作为常态启用
@@ -127,14 +143,14 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
if searchingServer.SSLPolicy() == nil || !searchingServer.SSLPolicy().IsOn {
continue
}
cert, ok := searchingServer.SSLPolicy().MatchDomain(domain)
cert, ok := searchingServer.SSLPolicy().MatchDomain(firstDomain)
if ok {
return searchingServer.SSLPolicy(), cert, nil
}
}
}
return nil, nil, errors.New("no server found for '" + domain + "'")
return nil, nil, errors.New("no server found for '" + firstDomain + "'")
}
if server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
// 找不到或者此时的服务没有配置证书需要搜索所有的Server通过SSL证书内容中的DNSName匹配
@@ -144,32 +160,32 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
if searchingServer.SSLPolicy() == nil || !searchingServer.SSLPolicy().IsOn {
continue
}
cert, ok := searchingServer.SSLPolicy().MatchDomain(domain)
cert, ok := searchingServer.SSLPolicy().MatchDomain(matchedDomain)
if ok {
return searchingServer.SSLPolicy(), cert, nil
}
}
}
return nil, nil, errors.New("no cert found for '" + domain + "'")
return nil, nil, errors.New("no cert found for '" + matchedDomain + "'")
}
// 证书是否匹配
var sslConfig = server.SSLPolicy()
cert, ok := sslConfig.MatchDomain(domain)
cert, ok := sslConfig.MatchDomain(matchedDomain)
if ok {
return sslConfig, cert, nil
}
if len(sslConfig.Certs) == 0 {
remotelogs.ServerError(server.Id, "BASE_LISTENER", "no ssl certs found for '"+domain+"', server id: "+types.String(server.Id), "", nil)
remotelogs.ServerError(server.Id, "BASE_LISTENER", "no ssl certs found for '"+matchedDomain+"', server id: "+types.String(server.Id), "", nil)
}
return sslConfig, sslConfig.FirstCert(), nil
}
// 根据域名来查找匹配的域名
func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconfigs.ServerConfig, serverName string) {
func (this *BaseListener) findNamedServer(name string, exactly bool) (serverConfig *serverconfigs.ServerConfig, serverName string) {
serverConfig, serverName = this.findNamedServerMatched(name)
if serverConfig != nil {
return
@@ -194,18 +210,22 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf
}
}
if matchDomainStrictly && !configutils.MatchDomains(globalServerConfig.HTTPAll.AllowMismatchDomains, name) && (!globalServerConfig.HTTPAll.AllowNodeIP || !utils.IsWildIP(name)) {
if matchDomainStrictly && !configutils.MatchDomains(globalServerConfig.HTTPAll.AllowMismatchDomains, name) && (!globalServerConfig.HTTPAll.AllowNodeIP || (!utils.IsWildIP(name) || globalServerConfig.HTTPAll.NodeIPShowPage)) {
return
}
// 如果没有找到,则匹配到第一个
var group = this.Group
var currentServers = group.Servers()
var countServers = len(currentServers)
if countServers == 0 {
return nil, ""
if !exactly {
// 如果没有找到,则匹配到第一个
var group = this.Group
var currentServers = group.Servers()
var countServers = len(currentServers)
if countServers == 0 {
return nil, ""
}
return currentServers[0], name
}
return currentServers[0], name
return
}
// 严格查找域名
@@ -234,16 +254,23 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser
}
// 从Hello信息中获取服务名称
func (this *BaseListener) helloServerName(clientInfo *tls.ClientHelloInfo) string {
var serverName = clientInfo.ServerName
if len(serverName) == 0 && clientInfo.Conn != nil {
func (this *BaseListener) helloServerNames(clientInfo *tls.ClientHelloInfo) (serverNames []string) {
if len(clientInfo.ServerName) != 0 {
serverNames = append(serverNames, clientInfo.ServerName)
return
}
if clientInfo.Conn != nil {
var localAddr = clientInfo.Conn.LocalAddr()
if localAddr != nil {
tcpAddr, ok := localAddr.(*net.TCPAddr)
if ok {
serverName = tcpAddr.IP.String()
serverNames = append(serverNames, tcpAddr.IP.String())
}
}
}
return serverName
serverNames = append(serverNames, sharedNodeConfig.IPAddresses...)
return
}

View File

@@ -107,15 +107,23 @@ func (this *HTTPListener) Reload(group *serverconfigs.ServerAddressGroup) {
// ServerHTTP 处理HTTP请求
func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.Request) {
if len(rawReq.Host) > 253 {
http.Error(rawWriter, "Host too long.", http.StatusBadRequest)
time.Sleep(1 * time.Second) // make connection slow down
return
}
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
if globalServerConfig != nil && !globalServerConfig.HTTPAll.SupportsLowVersionHTTP && (rawReq.ProtoMajor < 1 /** 0.x **/ || (rawReq.ProtoMajor == 1 && rawReq.ProtoMinor == 0 /** 1.0 **/)) {
http.Error(rawWriter, rawReq.Proto+" request is not supported.", http.StatusBadRequest)
time.Sleep(1 * time.Second) // make connection slow down
return
}
// 不支持Connect
if rawReq.Method == http.MethodConnect {
http.Error(rawWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
time.Sleep(1 * time.Second) // make connection slow down
return
}
@@ -154,7 +162,7 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
domain = reqHost
}
server, serverName := this.findNamedServer(domain)
server, serverName := this.findNamedServer(domain, false)
if server == nil {
if server == nil {
// 增加默认的一个服务

View File

@@ -47,9 +47,13 @@ func (this *TCPListener) Serve() error {
atomic.AddInt64(&this.countActiveConnections, 1)
go func(conn net.Conn) {
err = this.handleConn(conn)
var server = this.Group.FirstServer()
if server == nil {
return
}
err = this.handleConn(server, conn)
if err != nil {
remotelogs.Error("TCP_LISTENER", err.Error())
remotelogs.ServerError(server.Id, "TCP_LISTENER", err.Error(), "", nil)
}
atomic.AddInt64(&this.countActiveConnections, -1)
}(conn)
@@ -63,8 +67,7 @@ func (this *TCPListener) Reload(group *serverconfigs.ServerAddressGroup) {
this.Reset()
}
func (this *TCPListener) handleConn(conn net.Conn) error {
var server = this.Group.FirstServer()
func (this *TCPListener) handleConn(server *serverconfigs.ServerConfig, conn net.Conn) error {
if server == nil {
return errors.New("no server available")
}
@@ -132,14 +135,14 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
serverName = tlsConn.ConnectionState().ServerName
if len(serverName) > 0 {
// 统计
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, serverName, 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, serverName, 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
recordStat = true
}
}
// 统计
if !recordStat {
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
}
originConn, err := this.connectOrigin(server.Id, serverName, server.ReverseProxy, conn.RemoteAddr().String())
@@ -194,7 +197,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
// 记录流量
if server != nil {
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
}
}
if err != nil {

View File

@@ -370,7 +370,7 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyListener
// 统计
if server != nil {
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
}
// 处理ControlMessage
@@ -401,7 +401,7 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyListener
// 记录流量和带宽
if server != nil {
// 流量
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
// 带宽
var userPlanId int64

View File

@@ -146,12 +146,12 @@ func (this *Node) Start() {
remotelogs.Println("NODE", "init config ...")
err = this.syncConfig(0)
if err != nil {
_, err := nodeconfigs.SharedNodeConfig()
_, err = nodeconfigs.SharedNodeConfig()
if err != nil {
// 无本地数据时,会尝试多次读取
tryTimes := 0
for {
err := this.syncConfig(0)
err = this.syncConfig(0)
if err != nil {
tryTimes++
@@ -777,9 +777,19 @@ func (this *Node) listenSock() error {
_ = cmd.ReplyOk()
}
case "gc":
var before = time.Now()
runtime.GC()
debug.FreeOSMemory()
_ = cmd.ReplyOk()
var costSeconds = time.Since(before).Seconds()
var gcStats = &debug.GCStats{}
debug.ReadGCStats(gcStats)
_ = cmd.Reply(&gosock.Command{
Params: map[string]any{
"pauseMS": gcStats.PauseTotal.Seconds() * 1000,
"costMS": costSeconds * 1000,
},
})
case "reload":
err := this.syncConfig(0)
if err != nil {
@@ -1039,7 +1049,7 @@ func (this *Node) reloadServer() {
for serverId, serverConfig := range updatingServerMap {
if serverConfig != nil {
if countUpdatingServers < maxPrintServers {
remotelogs.Debug("NODE", "load server '"+types.String(serverId)+"'")
remotelogs.Debug("NODE", "reload server '"+types.String(serverId)+"'")
}
newNodeConfig.AddServer(serverConfig)
} else {

View File

@@ -223,6 +223,7 @@ func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
// 当前TeaWeb所在的fs
var rootFS = ""
var rootTotal = uint64(0)
var totalUsed = uint64(0)
if lists.ContainsString([]string{"darwin", "linux", "freebsd"}, runtime.GOOS) {
for _, p := range partitions {
if p.Mountpoint == "/" {
@@ -230,6 +231,7 @@ func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
usage, _ := disk.Usage(p.Mountpoint)
if usage != nil {
rootTotal = usage.Total
totalUsed = usage.Used
}
break
}
@@ -237,7 +239,6 @@ func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
}
var total = rootTotal
var totalUsage = uint64(0)
var maxUsage = float64(0)
for _, partition := range partitions {
if runtime.GOOS != "windows" && !strings.Contains(partition.Device, "/") && !strings.Contains(partition.Device, "\\") {
@@ -256,16 +257,16 @@ func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
if partition.Mountpoint != "/" && (usage.Total != rootTotal || total == 0) {
total += usage.Total
}
totalUsage += usage.Used
if usage.UsedPercent >= maxUsage {
maxUsage = usage.UsedPercent
status.DiskMaxUsagePartition = partition.Mountpoint
totalUsed += usage.Used
if usage.UsedPercent >= maxUsage {
maxUsage = usage.UsedPercent
status.DiskMaxUsagePartition = partition.Mountpoint
}
}
}
status.DiskTotal = total
if total > 0 {
status.DiskUsage = float64(totalUsage) / float64(total)
status.DiskUsage = float64(totalUsed) / float64(total)
}
status.DiskMaxUsage = maxUsage / 100

View File

@@ -17,6 +17,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/trackers"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/maps"
@@ -97,6 +98,10 @@ func (this *Node) execTask(rpcClient *rpc.RPCClient, task *pb.NodeTask) error {
err = this.notifyPlusChange()
case "toaChanged":
err = this.execTOAChangedTask()
case "networkSecurityPolicyChanged":
err = this.execNetworkSecurityPolicyChangedTask(rpcClient)
case "webPPolicyChanged":
err = this.execWebPPolicyChangedTask(rpcClient)
default:
// 特殊任务
if strings.HasPrefix(task.Type, "ipListDeleted") { // 删除IP名单
@@ -296,7 +301,7 @@ func (this *Node) execUpdatingServersTask(rpcClient *rpc.RPCClient) error {
// 删除IP名单
func (this *Node) execDeleteIPList(taskType string) error {
optionsString, ok := strings.CutPrefix(taskType, "ipListDeleted@")
optionsString, ok := utils.CutPrefix(taskType, "ipListDeleted@")
if !ok {
return errors.New("invalid task type '" + taskType + "'")
}
@@ -322,6 +327,34 @@ func (this *Node) execDeleteIPList(taskType string) error {
return nil
}
// WebP策略变更
func (this *Node) execWebPPolicyChangedTask(rpcClient *rpc.RPCClient) error {
remotelogs.Println("NODE", "updating webp policies ...")
resp, err := rpcClient.NodeRPC.FindNodeWebPPolicies(rpcClient.Context(), &pb.FindNodeWebPPoliciesRequest{})
if err != nil {
return err
}
var webPPolicyMap = map[int64]*nodeconfigs.WebPImagePolicy{}
for _, policy := range resp.WebPPolicies {
if len(policy.WebPPolicyJSON) > 0 {
var webPPolicy = nodeconfigs.NewWebPImagePolicy()
err = json.Unmarshal(policy.WebPPolicyJSON, webPPolicy)
if err != nil {
remotelogs.Error("NODE", "decode webp policy failed: "+err.Error())
continue
}
err = webPPolicy.Init()
if err != nil {
remotelogs.Error("NODE", "initialize webp policy failed: "+err.Error())
continue
}
webPPolicyMap[policy.NodeClusterId] = webPPolicy
}
}
sharedNodeConfig.UpdateWebPImagePolicies(webPPolicyMap)
return nil
}
// 标记任务完成
func (this *Node) finishTask(taskId int64, taskVersion int64, taskErr error) (success bool) {
if taskId <= 0 {

View File

@@ -29,3 +29,8 @@ func (this *Node) execHTTPPagesPolicyChangedTask(rpcClient *rpc.RPCClient) error
// stub
return nil
}
func (this *Node) execNetworkSecurityPolicyChangedTask(rpcClient *rpc.RPCClient) error {
// stub
return nil
}

View File

@@ -4,7 +4,7 @@ package re
type RuneMap map[rune]*RuneTree
func (this *RuneMap) Lookup(s string, caseInsensitive bool) bool {
func (this RuneMap) Lookup(s string, caseInsensitive bool) bool {
return this.lookup([]rune(s), caseInsensitive, 0)
}

View File

@@ -18,6 +18,7 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/encoding/gzip"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"net/url"
"sync"
@@ -240,12 +241,15 @@ func (this *RPCClient) init() error {
grpc.MaxCallSendMsgSize(512<<20),
grpc.UseCompressor(gzip.Name),
)
var keepaliveParams = grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
})
if u.Scheme == "http" {
conn, err = grpc.Dial(u.Host, grpc.WithTransportCredentials(insecure.NewCredentials()), callOptions)
conn, err = grpc.Dial(u.Host, grpc.WithTransportCredentials(insecure.NewCredentials()), callOptions, keepaliveParams)
} else if u.Scheme == "https" {
conn, err = grpc.Dial(u.Host, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
InsecureSkipVerify: true,
})), callOptions)
})), callOptions, keepaliveParams)
} else {
return errors.New("parse endpoint failed: invalid scheme '" + u.Scheme + "'")
}

View File

@@ -57,12 +57,13 @@ type BandwidthStat struct {
MaxBytes int64 `json:"maxBytes"`
TotalBytes int64 `json:"totalBytes"`
CachedBytes int64 `json:"cachedBytes"`
AttackBytes int64 `json:"attackBytes"`
CountRequests int64 `json:"countRequests"`
CountCachedRequests int64 `json:"countCachedRequests"`
CountAttackRequests int64 `json:"countAttackRequests"`
UserPlanId int64 `json:"userPlanId"`
CachedBytes int64 `json:"cachedBytes"`
AttackBytes int64 `json:"attackBytes"`
CountRequests int64 `json:"countRequests"`
CountCachedRequests int64 `json:"countCachedRequests"`
CountAttackRequests int64 `json:"countAttackRequests"`
CountWebsocketConnections int64 `json:"countWebsocketConnections"`
UserPlanId int64 `json:"userPlanId"`
}
// BandwidthStatManager 服务带宽统计
@@ -142,20 +143,21 @@ func (this *BandwidthStatManager) Loop() error {
}
pbStats = append(pbStats, &pb.ServerBandwidthStat{
Id: 0,
UserId: stat.UserId,
ServerId: stat.ServerId,
Day: stat.Day,
TimeAt: stat.TimeAt,
Bytes: stat.MaxBytes / bandwidthTimestampDelim,
TotalBytes: stat.TotalBytes,
CachedBytes: stat.CachedBytes,
AttackBytes: stat.AttackBytes,
CountRequests: stat.CountRequests,
CountCachedRequests: stat.CountCachedRequests,
CountAttackRequests: stat.CountAttackRequests,
UserPlanId: stat.UserPlanId,
NodeRegionId: regionId,
Id: 0,
UserId: stat.UserId,
ServerId: stat.ServerId,
Day: stat.Day,
TimeAt: stat.TimeAt,
Bytes: stat.MaxBytes / bandwidthTimestampDelim,
TotalBytes: stat.TotalBytes,
CachedBytes: stat.CachedBytes,
AttackBytes: stat.AttackBytes,
CountRequests: stat.CountRequests,
CountCachedRequests: stat.CountCachedRequests,
CountAttackRequests: stat.CountAttackRequests,
CountWebsocketConnections: stat.CountWebsocketConnections,
UserPlanId: stat.UserPlanId,
NodeRegionId: regionId,
})
delete(this.m, key)
}
@@ -231,7 +233,7 @@ func (this *BandwidthStatManager) AddBandwidth(userId int64, userPlanId int64, s
}
// AddTraffic 添加请求数据
func (this *BandwidthStatManager) AddTraffic(serverId int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64) {
func (this *BandwidthStatManager) AddTraffic(serverId int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64, countWebsocketConnections int64) {
var now = fasttime.Now()
var day = now.Ymd()
var timeAt = now.Round5Hi()
@@ -245,6 +247,7 @@ func (this *BandwidthStatManager) AddTraffic(serverId int64, cachedBytes int64,
stat.CountCachedRequests += countCachedRequests
stat.CountAttackRequests += countAttacks
stat.AttackBytes += attackBytes
stat.CountWebsocketConnections += countWebsocketConnections
}
this.locker.Unlock()
}

View File

@@ -53,19 +53,20 @@ func BenchmarkBandwidthStatManager_Slice(b *testing.B) {
for j := 0; j < 100; j++ {
var stat = &stats.BandwidthStat{}
pbStats = append(pbStats, &pb.ServerBandwidthStat{
Id: 0,
UserId: stat.UserId,
ServerId: stat.ServerId,
Day: stat.Day,
TimeAt: stat.TimeAt,
Bytes: stat.MaxBytes / 2,
TotalBytes: stat.TotalBytes,
CachedBytes: stat.CachedBytes,
AttackBytes: stat.AttackBytes,
CountRequests: stat.CountRequests,
CountCachedRequests: stat.CountCachedRequests,
CountAttackRequests: stat.CountAttackRequests,
NodeRegionId: 1,
Id: 0,
UserId: stat.UserId,
ServerId: stat.ServerId,
Day: stat.Day,
TimeAt: stat.TimeAt,
Bytes: stat.MaxBytes / 2,
TotalBytes: stat.TotalBytes,
CachedBytes: stat.CachedBytes,
AttackBytes: stat.AttackBytes,
CountRequests: stat.CountRequests,
CountCachedRequests: stat.CountCachedRequests,
CountAttackRequests: stat.CountAttackRequests,
CountWebsocketConnections: stat.CountWebsocketConnections,
NodeRegionId: 1,
})
}
_ = pbStats

View File

@@ -106,13 +106,13 @@ func (this *TrafficStatManager) Start() {
}
// Add 添加流量
func (this *TrafficStatManager) Add(userId int64, serverId int64, domain string, bytes int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64, checkingTrafficLimit bool, planId int64) {
func (this *TrafficStatManager) Add(userId int64, serverId int64, domain string, bytes int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64, countWebsocketConnections int64, checkingTrafficLimit bool, planId int64) {
if serverId == 0 {
return
}
// 添加到带宽
SharedBandwidthStatManager.AddTraffic(serverId, cachedBytes, countRequests, countCachedRequests, countAttacks, attackBytes)
SharedBandwidthStatManager.AddTraffic(serverId, cachedBytes, countRequests, countCachedRequests, countAttacks, attackBytes, countWebsocketConnections)
if bytes == 0 && countRequests == 0 {
return
@@ -142,24 +142,26 @@ func (this *TrafficStatManager) Add(userId int64, serverId int64, domain string,
item.PlanId = planId
// 单个域名流量
var domainKey = types.String(timestamp) + "@" + domain
serverDomainMap, ok := this.domainsMap[serverId]
if !ok {
serverDomainMap = map[string]*TrafficItem{}
this.domainsMap[serverId] = serverDomainMap
}
if len(domain) < 128 {
var domainKey = types.String(timestamp) + "@" + domain
serverDomainMap, ok := this.domainsMap[serverId]
if !ok {
serverDomainMap = map[string]*TrafficItem{}
this.domainsMap[serverId] = serverDomainMap
}
domainItem, ok := serverDomainMap[domainKey]
if !ok {
domainItem = &TrafficItem{}
serverDomainMap[domainKey] = domainItem
domainItem, ok := serverDomainMap[domainKey]
if !ok {
domainItem = &TrafficItem{}
serverDomainMap[domainKey] = domainItem
}
domainItem.Bytes += bytes
domainItem.CachedBytes += cachedBytes
domainItem.CountRequests += countRequests
domainItem.CountCachedRequests += countCachedRequests
domainItem.CountAttackRequests += countAttacks
domainItem.AttackBytes += attackBytes
}
domainItem.Bytes += bytes
domainItem.CachedBytes += cachedBytes
domainItem.CountRequests += countRequests
domainItem.CountCachedRequests += countCachedRequests
domainItem.CountAttackRequests += countAttacks
domainItem.AttackBytes += attackBytes
this.locker.Unlock()
}

View File

@@ -11,7 +11,7 @@ import (
func TestTrafficStatManager_Add(t *testing.T) {
manager := NewTrafficStatManager()
for i := 0; i < 100; i++ {
manager.Add(1, 1, "goedge.cn", 1, 0, 0, 0, 0, 0, false, 0)
manager.Add(1, 1, "goedge.cn", 1, 0, 0, 0, 0, 0, 0, false, 0)
}
t.Log(manager.itemMap)
}
@@ -19,7 +19,7 @@ func TestTrafficStatManager_Add(t *testing.T) {
func TestTrafficStatManager_Upload(t *testing.T) {
manager := NewTrafficStatManager()
for i := 0; i < 100; i++ {
manager.Add(1, 1, "goedge.cn"+types.String(rands.Int(0, 10)), 1, 0, 1, 0, 0, 0, false, 0)
manager.Add(1, 1, "goedge.cn"+types.String(rands.Int(0, 10)), 1, 0, 1, 0, 0, 0, 0, false, 0)
}
err := manager.Upload()
if err != nil {
@@ -36,7 +36,7 @@ func BenchmarkTrafficStatManager_Add(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
manager.Add(1, 1, "goedge.cn"+types.String(rand.Int63()%10), 1024, 1, 0, 0, 0, 0, false, 0)
manager.Add(1, 1, "goedge.cn"+types.String(rand.Int63()%10), 1024, 1, 0, 0, 0, 0, 0, false, 0)
}
})
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/iwind/TeaGo/Tea"
"sync"
"sync/atomic"
"time"
@@ -138,7 +139,7 @@ func (this *Stat) IsGood(category string) bool {
return true
}
if item.countCached > countSamples && item.timestamp < fasttime.Now().Unix()-600 /** 10 minutes ago **/ {
if item.countCached > countSamples && (Tea.IsTesting() || item.timestamp < fasttime.Now().Unix()-600) /** 10 minutes ago **/ {
var isGood = item.countHits*100/item.countCached >= this.goodRatio
if isGood {
item.isGood = true

View File

@@ -58,10 +58,10 @@ func TestNewStat(t *testing.T) {
{
var stat = cachehits.NewStat(5)
for i := 0; i < 10001; i++ {
for i := 0; i < 100001; i++ {
stat.IncreaseCached("a")
}
for i := 0; i < 499; i++ {
for i := 0; i < 4999; i++ {
stat.IncreaseHit("a")
}

View File

@@ -13,12 +13,16 @@ import (
const maxItemsPerGroup = 50_000
var SharedCounter = NewCounter().WithGC()
var SharedCounter = NewCounter[uint32]().WithGC()
type Counter struct {
type SupportedUIntType interface {
uint32 | uint64
}
type Counter[T SupportedUIntType] struct {
countMaps uint64
locker *syncutils.RWMutex
itemMaps []map[uint64]*Item
itemMaps []map[uint64]*Item[T]
gcTicker *time.Ticker
gcIndex int
@@ -26,18 +30,18 @@ type Counter struct {
}
// NewCounter create new counter
func NewCounter() *Counter {
func NewCounter[T SupportedUIntType]() *Counter[T] {
var count = utils.SystemMemoryGB() * 8
if count < 8 {
count = 8
}
var itemMaps = []map[uint64]*Item{}
var itemMaps = []map[uint64]*Item[T]{}
for i := 0; i < count; i++ {
itemMaps = append(itemMaps, map[uint64]*Item{})
itemMaps = append(itemMaps, map[uint64]*Item[T]{})
}
var counter = &Counter{
var counter = &Counter[T]{
countMaps: uint64(count),
locker: syncutils.NewRWMutex(count),
itemMaps: itemMaps,
@@ -47,7 +51,7 @@ func NewCounter() *Counter {
}
// WithGC start the counter with gc automatically
func (this *Counter) WithGC() *Counter {
func (this *Counter[T]) WithGC() *Counter[T] {
if this.gcTicker != nil {
return this
}
@@ -62,23 +66,17 @@ func (this *Counter) WithGC() *Counter {
}
// Increase key
func (this *Counter) Increase(key uint64, lifeSeconds int) uint64 {
func (this *Counter[T]) Increase(key uint64, lifeSeconds int) T {
var index = int(key % this.countMaps)
this.locker.RLock(index)
var item = this.itemMaps[index][key]
this.locker.RUnlock(index)
if item == nil { // no need to care about duplication
item = NewItem(lifeSeconds)
if item == nil {
// no need to care about duplication
// always insert new item even when itemMap is full
item = NewItem[T](lifeSeconds)
this.locker.Lock(index)
// check again
oldItem, ok := this.itemMaps[index][key]
if !ok {
this.itemMaps[index][key] = item
} else {
item = oldItem
}
this.itemMaps[index][key] = item
this.locker.Unlock(index)
}
@@ -89,12 +87,12 @@ func (this *Counter) Increase(key uint64, lifeSeconds int) uint64 {
}
// IncreaseKey increase string key
func (this *Counter) IncreaseKey(key string, lifeSeconds int) uint64 {
func (this *Counter[T]) IncreaseKey(key string, lifeSeconds int) T {
return this.Increase(this.hash(key), lifeSeconds)
}
// Get value of key
func (this *Counter) Get(key uint64) uint64 {
func (this *Counter[T]) Get(key uint64) T {
var index = int(key % this.countMaps)
this.locker.RLock(index)
defer this.locker.RUnlock(index)
@@ -106,12 +104,12 @@ func (this *Counter) Get(key uint64) uint64 {
}
// GetKey get value of string key
func (this *Counter) GetKey(key string) uint64 {
func (this *Counter[T]) GetKey(key string) T {
return this.Get(this.hash(key))
}
// Reset key
func (this *Counter) Reset(key uint64) {
func (this *Counter[T]) Reset(key uint64) {
var index = int(key % this.countMaps)
this.locker.RLock(index)
var item = this.itemMaps[index][key]
@@ -125,12 +123,12 @@ func (this *Counter) Reset(key uint64) {
}
// ResetKey string key
func (this *Counter) ResetKey(key string) {
func (this *Counter[T]) ResetKey(key string) {
this.Reset(this.hash(key))
}
// TotalItems get items count
func (this *Counter) TotalItems() int {
func (this *Counter[T]) TotalItems() int {
var total = 0
for i := 0; i < int(this.countMaps); i++ {
@@ -143,7 +141,7 @@ func (this *Counter) TotalItems() int {
}
// GC garbage expired items
func (this *Counter) GC() {
func (this *Counter[T]) GC() {
this.gcLocker.Lock()
var gcIndex = this.gcIndex
@@ -192,11 +190,11 @@ func (this *Counter) GC() {
}
}
func (this *Counter) CountMaps() int {
func (this *Counter[T]) CountMaps() int {
return int(this.countMaps)
}
// calculate hash of the key
func (this *Counter) hash(key string) uint64 {
func (this *Counter[T]) hash(key string) uint64 {
return xxhash.Sum64String(key)
}

View File

@@ -19,7 +19,7 @@ import (
func TestCounter_Increase(t *testing.T) {
var a = assert.NewAssertion(t)
var counter = counters.NewCounter()
var counter = counters.NewCounter[uint32]()
a.IsTrue(counter.Increase(1, 10) == 1)
a.IsTrue(counter.Increase(1, 10) == 2)
a.IsTrue(counter.Increase(2, 10) == 1)
@@ -32,7 +32,7 @@ func TestCounter_Increase(t *testing.T) {
func TestCounter_IncreaseKey(t *testing.T) {
var a = assert.NewAssertion(t)
var counter = counters.NewCounter()
var counter = counters.NewCounter[uint32]()
a.IsTrue(counter.IncreaseKey("1", 10) == 1)
a.IsTrue(counter.IncreaseKey("1", 10) == 2)
a.IsTrue(counter.IncreaseKey("2", 10) == 1)
@@ -47,13 +47,14 @@ func TestCounter_GC(t *testing.T) {
return
}
var counter = counters.NewCounter()
var counter = counters.NewCounter[uint32]()
counter.Increase(1, 20)
time.Sleep(1 * time.Second)
counter.Increase(1, 20)
time.Sleep(1 * time.Second)
counter.Increase(1, 20)
counter.GC()
t.Log(counter.Get(1))
}
func TestCounter_GC2(t *testing.T) {
@@ -61,8 +62,8 @@ func TestCounter_GC2(t *testing.T) {
return
}
var counter = counters.NewCounter().WithGC()
for i := 0; i < 1e5; i++ {
var counter = counters.NewCounter[uint32]().WithGC()
for i := 0; i < 100_000; i++ {
counter.Increase(uint64(i), rands.Int(10, 300))
}
@@ -79,7 +80,7 @@ func TestCounterMemory(t *testing.T) {
var stat = &runtime.MemStats{}
runtime.ReadMemStats(stat)
var counter = counters.NewCounter().WithGC()
var counter = counters.NewCounter[uint32]()
for i := 0; i < 1_000_000; i++ {
counter.Increase(uint64(i), rands.Int(10, 300))
}
@@ -90,15 +91,28 @@ func TestCounterMemory(t *testing.T) {
var stat1 = &runtime.MemStats{}
runtime.ReadMemStats(stat1)
t.Log((stat1.TotalAlloc-stat.TotalAlloc)/(1<<20), "MB")
t.Log((stat1.HeapInuse-stat.HeapInuse)/(1<<20), "MB")
t.Log(counter.TotalItems())
var gcPause = func() {
var before = time.Now()
runtime.GC()
var costSeconds = time.Since(before).Seconds()
var stats = &debug.GCStats{}
debug.ReadGCStats(stats)
t.Log("GC pause:", stats.PauseTotal.Seconds()*1000, "ms", "cost:", costSeconds*1000, "ms")
}
gcPause()
_ = counter.TotalItems()
}
func BenchmarkCounter_Increase(b *testing.B) {
runtime.GOMAXPROCS(4)
var counter = counters.NewCounter()
var counter = counters.NewCounter[uint32]()
b.ResetTimer()
var i uint64
@@ -114,7 +128,7 @@ func BenchmarkCounter_Increase(b *testing.B) {
func BenchmarkCounter_IncreaseKey(b *testing.B) {
runtime.GOMAXPROCS(4)
var counter = counters.NewCounter()
var counter = counters.NewCounter[uint32]()
go func() {
var ticker = time.NewTicker(100 * time.Millisecond)
@@ -138,7 +152,7 @@ func BenchmarkCounter_IncreaseKey(b *testing.B) {
func BenchmarkCounter_IncreaseKey2(b *testing.B) {
runtime.GOMAXPROCS(4)
var counter = counters.NewCounter()
var counter = counters.NewCounter[uint32]()
go func() {
var ticker = time.NewTicker(1 * time.Millisecond)
@@ -162,7 +176,7 @@ func BenchmarkCounter_IncreaseKey2(b *testing.B) {
func BenchmarkCounter_GC(b *testing.B) {
runtime.GOMAXPROCS(4)
var counter = counters.NewCounter()
var counter = counters.NewCounter[uint32]()
for i := uint64(0); i < 1e5; i++ {
counter.IncreaseKey(types.String(i), 20)

View File

@@ -6,40 +6,43 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
)
type Item struct {
lifeSeconds int64
spanSeconds int64
spans []uint64
const spanMaxValue = 10_000_000
const maxSpans = 10
type Item[T SupportedUIntType] struct {
spans [maxSpans + 1]T
lastUpdateTime int64
lifeSeconds int64
spanSeconds int64
}
func NewItem(lifeSeconds int) *Item {
func NewItem[T SupportedUIntType](lifeSeconds int) *Item[T] {
if lifeSeconds <= 0 {
lifeSeconds = 60
}
var spanSeconds = lifeSeconds / 10
var spanSeconds = lifeSeconds / maxSpans
if spanSeconds < 1 {
spanSeconds = 1
} else if lifeSeconds > maxSpans && lifeSeconds%maxSpans != 0 {
spanSeconds++
}
var countSpans = lifeSeconds/spanSeconds + 1 /** prevent index out of bounds **/
return &Item{
return &Item[T]{
lifeSeconds: int64(lifeSeconds),
spanSeconds: int64(spanSeconds),
spans: make([]uint64, countSpans),
lastUpdateTime: fasttime.Now().Unix(),
}
}
func (this *Item) Increase() (result uint64) {
func (this *Item[T]) Increase() (result T) {
var currentTime = fasttime.Now().Unix()
var currentSpanIndex = this.calculateSpanIndex(currentTime)
// return quickly
if this.lastUpdateTime == currentTime {
this.spans[currentSpanIndex]++
if this.spans[currentSpanIndex] < spanMaxValue {
this.spans[currentSpanIndex]++
}
for _, count := range this.spans {
result += count
}
@@ -69,7 +72,9 @@ func (this *Item) Increase() (result uint64) {
}
}
this.spans[currentSpanIndex]++
if this.spans[currentSpanIndex] < spanMaxValue {
this.spans[currentSpanIndex]++
}
this.lastUpdateTime = currentTime
for _, count := range this.spans {
@@ -79,7 +84,7 @@ func (this *Item) Increase() (result uint64) {
return
}
func (this *Item) Sum() (result uint64) {
func (this *Item[T]) Sum() (result T) {
if this.lastUpdateTime == 0 {
return 0
}
@@ -104,16 +109,20 @@ func (this *Item) Sum() (result uint64) {
return result
}
func (this *Item) Reset() {
func (this *Item[T]) Reset() {
for index := range this.spans {
this.spans[index] = 0
}
}
func (this *Item) IsExpired(currentTime int64) bool {
func (this *Item[T]) IsExpired(currentTime int64) bool {
return this.lastUpdateTime < currentTime-this.lifeSeconds-this.spanSeconds
}
func (this *Item) calculateSpanIndex(timestamp int64) int {
return int(timestamp % this.lifeSeconds / this.spanSeconds)
func (this *Item[T]) calculateSpanIndex(timestamp int64) int {
var index = int(timestamp % this.lifeSeconds / this.spanSeconds)
if index > maxSpans-1 {
return maxSpans - 1
}
return index
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/utils/counters"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
"runtime"
"testing"
@@ -17,7 +18,7 @@ func TestItem_Increase(t *testing.T) {
return
}
var item = counters.NewItem(10)
var item = counters.NewItem[uint32](10)
t.Log(item.Increase(), item.Sum())
time.Sleep(1 * time.Second)
t.Log(item.Increase(), item.Sum())
@@ -41,9 +42,9 @@ func TestItem_Increase2(t *testing.T) {
var a = assert.NewAssertion(t)
var item = counters.NewItem(20)
var item = counters.NewItem[uint32](23)
for i := 0; i < 100; i++ {
t.Log(item.Increase(), item.Sum(), timeutil.Format("H:i:s"))
t.Log("round "+types.String(i)+":", item.Increase(), item.Sum(), timeutil.Format("H:i:s"))
time.Sleep(2 * time.Second)
}
@@ -56,14 +57,14 @@ func TestItem_IsExpired(t *testing.T) {
return
}
var currentTime = time.Now().Unix()
var item = counters.NewItem(10)
t.Log(item.IsExpired(currentTime))
var item = counters.NewItem[uint32](10)
t.Log(item.IsExpired(time.Now().Unix()))
time.Sleep(10 * time.Second)
t.Log(item.IsExpired(currentTime))
t.Log(item.IsExpired(time.Now().Unix()))
time.Sleep(2 * time.Second)
t.Log(item.IsExpired(currentTime))
t.Log(item.IsExpired(time.Now().Unix()))
time.Sleep(2 * time.Second)
t.Log(item.IsExpired(time.Now().Unix()))
}
func BenchmarkItem_Increase(b *testing.B) {
@@ -73,7 +74,7 @@ func BenchmarkItem_Increase(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
var item = counters.NewItem(60)
var item = counters.NewItem[uint32](60)
item.Increase()
item.Sum()
}

View File

@@ -2,7 +2,7 @@
package linkedlist
type List[T any] struct {
type List[T any] struct {
head *Item[T]
end *Item[T]
count int
@@ -36,6 +36,15 @@ func (this *List[T]) Push(item *Item[T]) {
this.add(item)
}
func (this *List[T]) Shift() *Item[T] {
if this.head != nil {
var old = this.head
this.Remove(this.head)
return old
}
return nil
}
func (this *List[T]) Remove(item *Item[T]) {
if item == nil {
return
@@ -71,6 +80,15 @@ func (this *List[T]) Range(f func(item *Item[T]) (goNext bool)) {
}
}
func (this *List[T]) RangeReverse(f func(item *Item[T]) (goNext bool)) {
for e := this.end; e != nil; e = e.prev {
goNext := f(e)
if !goNext {
break
}
}
}
func (this *List[T]) Reset() {
this.head = nil
this.end = nil

View File

@@ -4,6 +4,7 @@ package linkedlist_test
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/linkedlist"
"github.com/iwind/TeaGo/types"
"runtime"
"strconv"
"testing"
@@ -95,6 +96,48 @@ func TestList_Push(t *testing.T) {
})
}
func TestList_Shift(t *testing.T) {
var list = linkedlist.NewList[int]()
list.Push(linkedlist.NewItem(1))
list.Push(linkedlist.NewItem(2))
list.Push(linkedlist.NewItem(3))
list.Push(linkedlist.NewItem(4))
for i := 0; i < 10; i++ {
t.Log("=== before shift " + types.String(i) + " ===")
list.Range(func(item *linkedlist.Item[int]) (goNext bool) {
t.Log(item.Value)
return true
})
t.Logf("shift: %+v", list.Shift())
t.Log("=== after shift " + types.String(i) + " ===")
list.Range(func(item *linkedlist.Item[int]) (goNext bool) {
t.Log(item.Value)
return true
})
}
}
func TestList_RangeReverse(t *testing.T) {
var list = linkedlist.NewList[int]()
list.Push(linkedlist.NewItem(1))
list.Push(linkedlist.NewItem(2))
var item3 = linkedlist.NewItem(3)
list.Push(item3)
list.Push(linkedlist.NewItem(4))
//list.Push(item3)
//list.Remove(item3)
list.RangeReverse(func(item *linkedlist.Item[int]) (goNext bool) {
t.Log(item.Value)
return true
})
}
func BenchmarkList_Add(b *testing.B) {
var list = linkedlist.NewList[int]()
for i := 0; i < b.N; i++ {

View File

@@ -0,0 +1,170 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package runes
// ContainsAnyWordRunes 直接使用rune检查字符串是否包含任一单词
func ContainsAnyWordRunes(s string, words [][]rune, isCaseInsensitive bool) bool {
var allRunes = []rune(s)
if len(allRunes) == 0 || len(words) == 0 {
return false
}
var lastRune rune // last searching rune in s
var lastIndex = -2 // -2: not started, -1: not found, >=0: rune index
for _, wordRunes := range words {
if len(wordRunes) == 0 {
continue
}
if lastIndex > -2 && lastRune == wordRunes[0] {
if lastIndex >= 0 {
result, _ := ContainsWordRunes(allRunes[lastIndex:], wordRunes, isCaseInsensitive)
if result {
return true
}
}
continue
} else {
result, firstIndex := ContainsWordRunes(allRunes, wordRunes, isCaseInsensitive)
lastIndex = firstIndex
if result {
return true
}
}
lastRune = wordRunes[0]
}
return false
}
// ContainsAnyWord 检查字符串是否包含任一单词
func ContainsAnyWord(s string, words []string, isCaseInsensitive bool) bool {
var allRunes = []rune(s)
if len(allRunes) == 0 || len(words) == 0 {
return false
}
var lastRune rune // last searching rune in s
var lastIndex = -2 // -2: not started, -1: not found, >=0: rune index
for _, word := range words {
var wordRunes = []rune(word)
if len(wordRunes) == 0 {
continue
}
if lastIndex > -2 && lastRune == wordRunes[0] {
if lastIndex >= 0 {
result, _ := ContainsWordRunes(allRunes[lastIndex:], wordRunes, isCaseInsensitive)
if result {
return true
}
}
continue
} else {
result, firstIndex := ContainsWordRunes(allRunes, wordRunes, isCaseInsensitive)
lastIndex = firstIndex
if result {
return true
}
}
lastRune = wordRunes[0]
}
return false
}
// ContainsAllWords 检查字符串是否包含所有单词
func ContainsAllWords(s string, words []string, isCaseInsensitive bool) bool {
var allRunes = []rune(s)
if len(allRunes) == 0 || len(words) == 0 {
return false
}
for _, word := range words {
if result, _ := ContainsWordRunes(allRunes, []rune(word), isCaseInsensitive); !result {
return false
}
}
return true
}
// ContainsWordRunes 检查字符列表是否包含某个单词子字符列表
func ContainsWordRunes(allRunes []rune, subRunes []rune, isCaseInsensitive bool) (result bool, firstIndex int) {
firstIndex = -1
var l = len(subRunes)
if l == 0 {
return false, 0
}
var al = len(allRunes)
for index, r := range allRunes {
if EqualRune(r, subRunes[0], isCaseInsensitive) && (index == 0 || !isChar(allRunes[index-1]) /**boundary check **/) {
if firstIndex < 0 {
firstIndex = index
}
var found = true
if l > 1 {
for i := 1; i < l; i++ {
var subIndex = index + i
if subIndex > al-1 || !EqualRune(allRunes[subIndex], subRunes[i], isCaseInsensitive) {
found = false
break
}
}
}
// check after charset
if found && (al <= index+l || !isChar(allRunes[index+l]) /**boundary check **/) {
return true, firstIndex
}
}
}
return false, firstIndex
}
// ContainsSubRunes 检查字符列表是否包含某个子子字符列表
// 与 ContainsWordRunes 不同,这里不需要检查边界符号
func ContainsSubRunes(allRunes []rune, subRunes []rune, isCaseInsensitive bool) bool {
var l = len(subRunes)
if l == 0 {
return false
}
var al = len(allRunes)
for index, r := range allRunes {
if EqualRune(r, subRunes[0], isCaseInsensitive) {
var found = true
if l > 1 {
for i := 1; i < l; i++ {
var subIndex = index + i
if subIndex > al-1 || !EqualRune(allRunes[subIndex], subRunes[i], isCaseInsensitive) {
found = false
break
}
}
}
// check after charset
if found {
return true
}
}
}
return false
}
// EqualRune 判断两个rune是否相同
func EqualRune(r1 rune, r2 rune, isCaseInsensitive bool) bool {
const d = 'a' - 'A'
return r1 == r2 ||
(isCaseInsensitive && r1 >= 'a' && r1 <= 'z' && r1-r2 == d) ||
(isCaseInsensitive && r1 >= 'A' && r1 <= 'Z' && r1-r2 == -d)
}
func isChar(r rune) bool {
return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9'
}

View File

@@ -0,0 +1,172 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package runes_test
import (
"github.com/TeaOSLab/EdgeNode/internal/re"
"github.com/TeaOSLab/EdgeNode/internal/utils/runes"
"github.com/iwind/TeaGo/assert"
"regexp"
"runtime"
"sort"
"strings"
"testing"
)
func TestContainsAllWords(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(runes.ContainsAllWords("How are you?", []string{"are", "you"}, false))
a.IsFalse(runes.ContainsAllWords("How are you?", []string{"how", "are", "you"}, false))
a.IsTrue(runes.ContainsAllWords("How are you?", []string{"how", "are", "you"}, true))
}
func TestContainsAnyWord(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"are", "you"}, false))
a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"are", "you", "ok"}, false))
a.IsFalse(runes.ContainsAnyWord("How are you?", []string{"how", "ok"}, false))
a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"how"}, true))
a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"how", "ok"}, true))
a.IsTrue(runes.ContainsAnyWord("How-are you?", []string{"how", "ok"}, true))
}
func TestContainsAnyWord_Sort(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"abc", "ant", "arm", "Hit", "Hi", "Pet", "pie", "are"}, false))
}
func TestContainsWordRunes(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsFalse(runes.ContainsWordRunes([]rune(""), []rune("How"), true))
a.IsFalse(runes.ContainsWordRunes([]rune("How are you?"), []rune(""), true))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("How"), true))
a.IsFalse(runes.ContainsWordRunes([]rune("How are you?"), []rune("how"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("you"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("are"), false))
a.IsFalse(runes.ContainsWordRunes([]rune("How are you?"), []rune("re"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you w?"), []rune("w"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("w How are you?"), []rune("w"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are w you?"), []rune("w"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are how you?"), []rune("how"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("how"), true))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("ARE"), true))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you"), []rune("you"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you"), []rune("YOU"), true))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("YOU"), true))
a.IsFalse(runes.ContainsWordRunes([]rune("How are you1?"), []rune("YOU"), true))
a.IsFalse(runes.ContainsWordRunes([]rune("How are you1?"), []rune("YOU YOU YOU YOU YOU YOU YOU"), true))
}
func TestContainsSubRunes(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsFalse(runes.ContainsSubRunes([]rune(""), []rune("How"), true))
a.IsFalse(runes.ContainsSubRunes([]rune("How are you?"), []rune(""), true))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you1?"), []rune("YOU"), true))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you1?"), []rune("ow"), false))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you1?"), []rune("H"), false))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you1?"), []rune("How"), false))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you doing"), []rune("oi"), false))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you doing"), []rune("g"), false))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you doing"), []rune("ing"), false))
a.IsFalse(runes.ContainsSubRunes([]rune("How are you doing"), []rune("int"), false))
}
func TestEqualRune(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(runes.EqualRune('a', 'a', false))
a.IsTrue(runes.EqualRune('a', 'a', true))
a.IsFalse(runes.EqualRune('a', 'A', false))
a.IsTrue(runes.EqualRune('a', 'A', true))
a.IsFalse(runes.EqualRune('c', 'C', false))
a.IsTrue(runes.EqualRune('c', 'C', true))
a.IsTrue(runes.EqualRune('C', 'C', true))
a.IsTrue(runes.EqualRune('C', 'c', true))
a.IsTrue(runes.EqualRune('Z', 'z', true))
a.IsTrue(runes.EqualRune('z', 'Z', true))
a.IsFalse(runes.EqualRune('z', 'z'+('a'-'A'), true))
}
func BenchmarkContainsWordRunes(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, _ = runes.ContainsWordRunes([]rune("How are you"), []rune("YOU"), true)
}
})
}
func BenchmarkContainsAnyWord(b *testing.B) {
runtime.GOMAXPROCS(4)
var words = strings.Split("python\npycurl\nhttp-client\nhttpclient\napachebench\nnethttp\nhttp_request\njava\nperl\nruby\nscrapy\nphp\nrust", "\n")
sort.Strings(words)
var wordRunes = [][]rune{}
for _, word := range words {
wordRunes = append(wordRunes, []rune(word))
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = runes.ContainsAnyWord("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_0_0) AppleWebKit/500.00 (KHTML, like Gecko) Chrome/100.0.0.0", words, true)
}
})
}
func BenchmarkContainsAnyWordRunes(b *testing.B) {
runtime.GOMAXPROCS(4)
var words = strings.Split("python\npycurl\nhttp-client\nhttpclient\napachebench\nnethttp\nhttp_request\njava\nperl\nruby\nscrapy\nphp\nrust", "\n")
sort.Strings(words)
var wordRunes = [][]rune{}
for _, word := range words {
wordRunes = append(wordRunes, []rune(word))
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = runes.ContainsAnyWordRunes("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_0_0) AppleWebKit/500.00 (KHTML, like Gecko) Chrome/100.0.0.0", wordRunes, true)
}
})
}
func BenchmarkContainsAnyWord_Regexp(b *testing.B) {
runtime.GOMAXPROCS(4)
var reg = regexp.MustCompile("(?i)" + strings.ReplaceAll("python\npycurl\nhttp-client\nhttpclient\napachebench\nnethttp\nhttp_request\njava\nperl\nruby\nscrapy\nphp\nrust", "\n", "|"))
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = reg.MatchString("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_0_0) AppleWebKit/500.00 (KHTML, like Gecko) Chrome/100.0.0.0")
}
})
}
func BenchmarkContainsAnyWord_Re(b *testing.B) {
runtime.GOMAXPROCS(4)
var reg = re.MustCompile("(?i)" + strings.ReplaceAll("python\npycurl\nhttp-client\nhttpclient\napachebench\nnethttp\nhttp_request\njava\nperl\nruby\nscrapy\nphp\nrust", "\n", "|"))
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = reg.MatchString("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_0_0) AppleWebKit/500.00 (KHTML, like Gecko) Chrome/100.0.0.0")
}
})
}
func BenchmarkContainsSubRunes(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = runes.ContainsSubRunes([]rune("How are you"), []rune("YOU"), true)
}
})
}

View File

@@ -56,3 +56,16 @@ func EqualStrings(s1 []string, s2 []string) bool {
}
return true
}
// CutPrefix returns s without the provided leading prefix string
// and reports whether it found the prefix.
// If s doesn't start with prefix, CutPrefix returns s, false.
// If prefix is the empty string, CutPrefix returns s, true.
//
// copy from go source
func CutPrefix(s, prefix string) (after string, found bool) {
if !strings.HasPrefix(s, prefix) {
return s, false
}
return s[len(prefix):], true
}

View File

@@ -13,6 +13,6 @@ func setMaxMemory(memoryGB int) {
memoryGB = 1
}
var maxMemoryBytes = (int64(memoryGB) << 30) * 75 / 100 // 默认 75%
var maxMemoryBytes = (int64(memoryGB) << 30) * 80 / 100 // 默认 80%
debug.SetMemoryLimit(maxMemoryBytes)
}

View File

@@ -1,11 +1,12 @@
package waf
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
wafutils "github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"net/http"
"net/url"
"strings"
@@ -27,6 +28,8 @@ type CaptchaAction struct {
CountLetters int8 `yaml:"countLetters" json:"countLetters"`
CaptchaType firewallconfigs.CaptchaType `yaml:"captchaType" json:"captchaType"`
UIIsOn bool `yaml:"uiIsOn" json:"uiIsOn"` // 是否使用自定义UI
UITitle string `yaml:"uiTitle" json:"uiTitle"` // 消息标题
UIPrompt string `yaml:"uiPrompt" json:"uiPrompt"` // 消息提示
@@ -36,6 +39,24 @@ type CaptchaAction struct {
UIFooter string `yaml:"uiFooter" json:"uiFooter"` // 页脚
UIBody string `yaml:"uiBody" json:"uiBody"` // 内容轮廓
OneClickUIIsOn bool `yaml:"oneClickUIIsOn" json:"oneClickUIIsOn"` // 是否使用自定义UI
OneClickUITitle string `yaml:"oneClickUITitle" json:"oneClickUITitle"` // 消息标题
OneClickUIPrompt string `yaml:"oneClickUIPrompt" json:"oneClickUIPrompt"` // 消息提示
OneClickUIShowRequestId bool `yaml:"oneClickUIShowRequestId" json:"oneClickUIShowRequestId"` // 是否显示请求ID
OneClickUICss string `yaml:"oneClickUICss" json:"oneClickUICss"` // CSS样式
OneClickUIFooter string `yaml:"oneClickUIFooter" json:"oneClickUIFooter"` // 页脚
OneClickUIBody string `yaml:"oneClickUIBody" json:"oneClickUIBody"` // 内容轮廓
SlideUIIsOn bool `yaml:"sliceUIIsOn" json:"sliceUIIsOn"` // 是否使用自定义UI
SlideUITitle string `yaml:"slideUITitle" json:"slideUITitle"` // 消息标题
SlideUIPrompt string `yaml:"slideUIPrompt" json:"slideUIPrompt"` // 消息提示
SlideUIShowRequestId bool `yaml:"SlideUIShowRequestId" json:"SlideUIShowRequestId"` // 是否显示请求ID
SlideUICss string `yaml:"slideUICss" json:"slideUICss"` // CSS样式
SlideUIFooter string `yaml:"slideUIFooter" json:"slideUIFooter"` // 页脚
SlideUIBody string `yaml:"slideUIBody" json:"slideUIBody"` // 内容轮廓
GeeTestConfig *firewallconfigs.GeeTestConfig `yaml:"geeTestConfig" json:"geeTestConfig"` // 极验设置 MUST be struct
Lang string `yaml:"lang" json:"lang"` // 语言zh-CN, en-US ...
AddToWhiteList bool `yaml:"addToWhiteList" json:"addToWhiteList"` // 是否加入到白名单
Scope string `yaml:"scope" json:"scope"`
@@ -81,6 +102,10 @@ func (this *CaptchaAction) Init(waf *WAF) error {
if len(this.Lang) == 0 {
this.Lang = waf.DefaultCaptchaAction.Lang
}
if len(this.CaptchaType) == 0 {
this.CaptchaType = waf.DefaultCaptchaAction.CaptchaType
}
}
return nil
@@ -100,7 +125,7 @@ func (this *CaptchaAction) WillChange() bool {
func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
// 是否在白名单中
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, req.WAFServerId(), req.WAFRemoteIP()) {
if SharedIPWhiteList.Contains(wafutils.ComposeIPType(set.Id, req), this.Scope, req.WAFServerId(), req.WAFRemoteIP()) {
return true, false
}
@@ -134,6 +159,7 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
// 占用一次失败次数
CaptchaIncreaseFails(req, this, waf.Id, group.Id, set.Id, CaptchaPageCodeInit)
req.DisableStat()
req.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(writer, req.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect)

View File

@@ -67,6 +67,7 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
return true, false
}
request.DisableStat()
request.ProcessResponseHeaders(writer.Header(), http.StatusFound)
http.Redirect(writer, request.WAFRaw(), Get302Path+"?info="+url.QueryEscape(info), http.StatusFound)

View File

@@ -36,10 +36,30 @@ func (this *PageAction) WillChange() bool {
// Perform the action
func (this *PageAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
if writer == nil {
return
}
request.ProcessResponseHeaders(writer.Header(), this.Status)
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(this.Status)
_, _ = writer.Write([]byte(request.Format(this.Body)))
var body = this.Body
if len(body) == 0 {
body = `<!DOCTYPE html>
<html lang="en">
<title>403 Forbidden</title>
<style>
address { line-height: 1.8; }
</style>
<body>
<h1>403 Forbidden By WAF</h1>
<address>Connection: ${remoteAddr} (Client) -&gt; ${serverAddr} (Server)</address>
<address>Request ID: ${requestId}</address>
</body>
</html>`
}
_, _ = writer.Write([]byte(request.Format(body)))
return false, false
}

View File

@@ -92,6 +92,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
Value: info,
})
request.DisableStat()
request.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusTemporaryRedirect)

View File

@@ -1,6 +1,7 @@
package waf
package waf_test
import (
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/maps"
@@ -11,22 +12,22 @@ import (
func TestFindActionInstance(t *testing.T) {
a := assert.NewAssertion(t)
t.Logf("ActionBlock: %p", FindActionInstance(ActionBlock, nil))
t.Logf("ActionBlock: %p", FindActionInstance(ActionBlock, nil))
t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil))
t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil))
t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b"}))
t.Logf("ActionBlock: %p", waf.FindActionInstance(waf.ActionBlock, nil))
t.Logf("ActionBlock: %p", waf.FindActionInstance(waf.ActionBlock, nil))
t.Logf("ActionGoGroup: %p", waf.FindActionInstance(waf.ActionGoGroup, nil))
t.Logf("ActionGoGroup: %p", waf.FindActionInstance(waf.ActionGoGroup, nil))
t.Logf("ActionGoSet: %p", waf.FindActionInstance(waf.ActionGoSet, nil))
t.Logf("ActionGoSet: %p", waf.FindActionInstance(waf.ActionGoSet, nil))
t.Logf("ActionGoSet: %#v", waf.FindActionInstance(waf.ActionGoSet, maps.Map{"groupId": "a", "setId": "b"}))
a.IsTrue(FindActionInstance(ActionGoSet, nil) != FindActionInstance(ActionGoSet, nil))
a.IsTrue(waf.FindActionInstance(waf.ActionGoSet, nil) != waf.FindActionInstance(waf.ActionGoSet, nil))
}
func TestFindActionInstance_Options(t *testing.T) {
//t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{}))
//t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{}))
//logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{}), t)
logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{
logs.PrintAsJSON(waf.FindActionInstance(waf.ActionBlock, maps.Map{
"timeout": 3600,
}), t)
}
@@ -34,6 +35,6 @@ func TestFindActionInstance_Options(t *testing.T) {
func BenchmarkFindActionInstance(b *testing.B) {
runtime.GOMAXPROCS(1)
for i := 0; i < b.N; i++ {
FindActionInstance(ActionGoSet, nil)
waf.FindActionInstance(waf.ActionGoSet, nil)
}
}

View File

@@ -4,7 +4,6 @@ package waf
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/counters"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/types"
@@ -16,6 +15,7 @@ type CaptchaPageCode = string
const (
CaptchaPageCodeInit CaptchaPageCode = "init"
CaptchaPageCodeShow CaptchaPageCode = "show"
CaptchaPageCodeImage CaptchaPageCode = "image"
CaptchaPageCodeSubmit CaptchaPageCode = "submit"
)
@@ -40,19 +40,11 @@ func CaptchaIncreaseFails(req requests.Request, actionConfig *CaptchaAction, pol
func CaptchaDeleteCacheKey(req requests.Request) {
counters.SharedCounter.ResetKey(CaptchaCacheKey(req, CaptchaPageCodeInit))
counters.SharedCounter.ResetKey(CaptchaCacheKey(req, CaptchaPageCodeShow))
counters.SharedCounter.ResetKey(CaptchaCacheKey(req, CaptchaPageCodeImage))
counters.SharedCounter.ResetKey(CaptchaCacheKey(req, CaptchaPageCodeSubmit))
}
// CaptchaCacheKey 获取Captcha缓存Key
func CaptchaCacheKey(req requests.Request, pageCode CaptchaPageCode) string {
var requestPath = req.WAFRaw().URL.Path
if req.WAFRaw().URL.Path == CaptchaPath {
m, err := utils.SimpleDecryptMap(req.WAFRaw().URL.Query().Get("info"))
if err == nil && m != nil {
requestPath = m.GetString("url")
}
}
return "WAF:CAPTCHA:FAILS:" + pageCode + ":" + req.WAFRemoteIP() + ":" + types.String(req.WAFServerId()) + ":" + requestPath
return "WAF:CAPTCHA:FAILS:" + pageCode + ":" + req.WAFRemoteIP() + ":" + types.String(req.WAFServerId())
}

View File

@@ -0,0 +1,71 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package waf
import (
"bytes"
"github.com/dchest/captcha"
"github.com/iwind/TeaGo/rands"
"io"
"time"
)
// CaptchaGenerator captcha generator
type CaptchaGenerator struct {
store captcha.Store
}
func NewCaptchaGenerator() *CaptchaGenerator {
return &CaptchaGenerator{
store: captcha.NewMemoryStore(100_000, 5*time.Minute),
}
}
// NewCaptcha create new captcha
func (this *CaptchaGenerator) NewCaptcha(length int) (captchaId string) {
captchaId = rands.HexString(16)
if length <= 0 || length > 20 {
length = 4
}
this.store.Set(captchaId, captcha.RandomDigits(length))
return
}
// WriteImage write image to front writer
func (this *CaptchaGenerator) WriteImage(w io.Writer, id string, width, height int) error {
var d = this.store.Get(id, false)
if d == nil {
return captcha.ErrNotFound
}
_, err := captcha.NewImage(id, d, width, height).WriteTo(w)
return err
}
// Verify user input
func (this *CaptchaGenerator) Verify(id string, digits string) bool {
var countDigits = len(digits)
if countDigits == 0 {
return false
}
var value = this.store.Get(id, true)
if len(value) != countDigits {
return false
}
var nb = make([]byte, countDigits)
for i := 0; i < countDigits; i++ {
var d = digits[i]
if d >= '0' && d <= '9' {
nb[i] = d - '0'
}
}
return bytes.Equal(nb, value)
}
// Get captcha data
func (this *CaptchaGenerator) Get(id string) []byte {
return this.store.Get(id, false)
}

View File

@@ -0,0 +1,87 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package waf_test
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/types"
"runtime"
"strings"
"testing"
"time"
)
func TestCaptchaGenerator_NewCaptcha(t *testing.T) {
var a = assert.NewAssertion(t)
var generator = waf.NewCaptchaGenerator()
var captchaId = generator.NewCaptcha(6)
t.Log("captchaId:", captchaId)
var digits = generator.Get(captchaId)
var s []string
for _, digit := range digits {
s = append(s, types.String(digit))
}
t.Log(strings.Join(s, " "))
a.IsTrue(generator.Verify(captchaId, strings.Join(s, "")))
a.IsFalse(generator.Verify(captchaId, strings.Join(s, "")))
}
func TestCaptchaGenerator_NewCaptcha_UTF8(t *testing.T) {
var a = assert.NewAssertion(t)
var generator = waf.NewCaptchaGenerator()
var captchaId = generator.NewCaptcha(6)
t.Log("captchaId:", captchaId)
var digits = generator.Get(captchaId)
var s []string
for _, digit := range digits {
s = append(s, types.String(digit))
}
t.Log(strings.Join(s, " "))
a.IsFalse(generator.Verify(captchaId, "中文真的很长"))
}
func TestCaptchaGenerator_NewCaptcha_Memory(t *testing.T) {
runtime.GC()
var stat1 = &runtime.MemStats{}
runtime.ReadMemStats(stat1)
var generator = waf.NewCaptchaGenerator()
for i := 0; i < 1_000_000; i++ {
generator.NewCaptcha(6)
}
if testutils.IsSingleTesting() {
time.Sleep(1 * time.Second)
}
runtime.GC()
var stat2 = &runtime.MemStats{}
runtime.ReadMemStats(stat2)
t.Log((stat2.HeapInuse-stat1.HeapInuse)>>10, "KiB")
_ = generator
}
func BenchmarkNewCaptchaGenerator(b *testing.B) {
runtime.GOMAXPROCS(4)
var generator = waf.NewCaptchaGenerator()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
generator.NewCaptcha(6)
}
})
}

View File

@@ -0,0 +1,70 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package waf_test
import (
"bytes"
"fmt"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/dchest/captcha"
"runtime"
"testing"
"time"
)
func TestCaptchaMemory(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var stat1 = &runtime.MemStats{}
runtime.ReadMemStats(stat1)
var count = 5_000
var before = time.Now()
for i := 0; i < count; i++ {
var id = captcha.NewLen(6)
var writer = &bytes.Buffer{}
err := captcha.WriteImage(writer, id, 200, 100)
if err != nil {
t.Fatal(err)
}
captcha.VerifyString(id, "abc")
}
var stat2 = &runtime.MemStats{}
runtime.ReadMemStats(stat2)
t.Log((stat2.HeapInuse-stat1.HeapInuse)>>20, "MB", fmt.Sprintf("%.0f QPS", float64(count)/time.Since(before).Seconds()))
}
func BenchmarkCaptcha_VerifyCode_100_50(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
var id = captcha.NewLen(6)
var writer = &bytes.Buffer{}
err := captcha.WriteImage(writer, id, 100, 50)
if err != nil {
b.Fatal(err)
}
}
})
}
func BenchmarkCaptcha_VerifyCode_200_100(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
var id = captcha.NewLen(6)
var writer = &bytes.Buffer{}
err := captcha.WriteImage(writer, id, 200, 100)
if err != nil {
b.Fatal(err)
}
_ = id
}
})
}

File diff suppressed because one or more lines are too long

View File

@@ -76,7 +76,8 @@ func (this *CC2Checkpoint) RequestValue(req requests.Request, param string, opti
}
var ccKey = "WAF-CC-" + types.String(ruleId) + "-" + strings.Join(keyValues, "@")
value = counters.SharedCounter.IncreaseKey(ccKey, period)
var ccValue = counters.SharedCounter.IncreaseKey(ccKey, period)
value = ccValue
// 基于指纹统计
var enableFingerprint = true
@@ -96,7 +97,7 @@ func (this *CC2Checkpoint) RequestValue(req requests.Request, param string, opti
}
var fpCCKey = "WAF-CC-" + types.String(ruleId) + "-" + strings.Join(fpKeyValues, "@")
var fpValue = counters.SharedCounter.IncreaseKey(fpCCKey, period)
if fpValue > value.(uint64) {
if fpValue > ccValue {
value = fpValue
}
}

View File

@@ -12,11 +12,11 @@ type RequestAllCheckpoint struct {
}
func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value any, hasRequestBody bool, sysErr error, userErr error) {
var valueBytes = []byte{}
var valueBytes = [][]byte{}
if len(req.WAFRaw().RequestURI) > 0 {
valueBytes = append(valueBytes, req.WAFRaw().RequestURI...)
valueBytes = append(valueBytes, []byte(req.WAFRaw().RequestURI))
} else if req.WAFRaw().URL != nil {
valueBytes = append(valueBytes, req.WAFRaw().URL.RequestURI()...)
valueBytes = append(valueBytes, []byte(req.WAFRaw().URL.RequestURI()))
}
if this.RequestBodyIsEmpty(req) {
@@ -25,8 +25,6 @@ func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param strin
}
if req.WAFRaw().Body != nil {
valueBytes = append(valueBytes, ' ')
var bodyData = req.WAFGetCacheBody()
hasRequestBody = true
if len(bodyData) == 0 {
@@ -39,7 +37,9 @@ func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param strin
req.WAFSetCacheBody(data)
req.WAFRestoreBody(data)
}
valueBytes = append(valueBytes, bodyData...)
if len(bodyData) > 0 {
valueBytes = append(valueBytes, bodyData)
}
}
value = valueBytes

View File

@@ -25,8 +25,14 @@ func TestRequestAllCheckpoint_RequestValue(t *testing.T) {
if userErr != nil {
t.Fatal(userErr)
}
t.Log(v)
t.Log(types.String(v))
if v != nil {
vv, ok := v.([][]byte)
if ok {
for _, v2 := range vv {
t.Log(string(v2), ":", v2)
}
}
}
body, err := io.ReadAll(req.Body)
if err != nil {

View File

@@ -0,0 +1,32 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/iwind/TeaGo/maps"
"strings"
)
type RequestHeaderNamesCheckpoint struct {
Checkpoint
}
func (this *RequestHeaderNamesCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value any, hasRequestBody bool, sysErr error, userErr error) {
var headerNames = []string{}
for k := range req.WAFRaw().Header {
headerNames = append(headerNames, k)
}
value = strings.Join(headerNames, "\n")
return
}
func (this *RequestHeaderNamesCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value any, hasRequestBody bool, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options, ruleId)
}
return
}
func (this *RequestHeaderNamesCheckpoint) CacheLife() utils.CacheLife {
return utils.CacheShortLife
}

View File

@@ -0,0 +1,23 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package checkpoints_test
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/checkpoints"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/http"
"testing"
)
func TestRequestHeaderNamesCheckpoint_RequestValue(t *testing.T) {
var checkpoint = &checkpoints.RequestHeaderNamesCheckpoint{}
rawReq, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
if err != nil {
t.Fatal(err)
}
rawReq.Header.Set("Accept", "text/html")
rawReq.Header.Set("User-Agent", "Chrome")
rawReq.Header.Set("Accept-Encoding", "br, gzip")
var req = requests.NewTestRequest(rawReq)
t.Log(checkpoint.RequestValue(req, "", nil, 0))
}

View File

@@ -23,5 +23,5 @@ func (this *RequestRefererCheckpoint) ResponseValue(req requests.Request, resp *
}
func (this *RequestRefererCheckpoint) CacheLife() utils.CacheLife {
return utils.CacheShortLife
return utils.CacheMiddleLife
}

View File

@@ -0,0 +1,44 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/iwind/TeaGo/maps"
)
type RequestRefererOriginCheckpoint struct {
Checkpoint
}
func (this *RequestRefererOriginCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value any, hasRequestBody bool, sysErr error, userErr error) {
var s []string
var referer = req.WAFRaw().Referer()
if len(referer) > 0 {
s = append(s, referer)
}
var origin = req.WAFRaw().Header.Get("Origin")
if len(origin) > 0 {
s = append(s, origin)
}
if len(s) > 0 {
value = s
} else {
value = ""
}
return
}
func (this *RequestRefererOriginCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value any, hasRequestBody bool, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options, ruleId)
}
return
}
func (this *RequestRefererOriginCheckpoint) CacheLife() utils.CacheLife {
return utils.CacheMiddleLife
}

View File

@@ -0,0 +1,38 @@
package checkpoints_test
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/checkpoints"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/http"
"testing"
)
func TestRequestRefererOriginCheckpoint_RequestValue(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
if err != nil {
t.Fatal(err)
}
var req = requests.NewTestRequest(rawReq)
var checkpoint = &checkpoints.RequestRefererOriginCheckpoint{}
{
t.Log(checkpoint.RequestValue(req, "", nil, 0))
}
{
rawReq.Header.Set("Referer", "https://example.com/hello.yaml")
t.Log(checkpoint.RequestValue(req, "", nil, 0))
}
{
rawReq.Header.Set("Origin", "https://example.com/world.yaml")
t.Log(checkpoint.RequestValue(req, "", nil, 0))
}
{
rawReq.Header.Del("Referer")
rawReq.Header.Set("Origin", "https://example.com/world.yaml")
t.Log(checkpoint.RequestValue(req, "", nil, 0))
}
}

View File

@@ -163,7 +163,15 @@ var AllCheckpoints = []*CheckpointDefinition{
Priority: 100,
},
{
Name: "请求来源URL",
Name: "请求来源",
Prefix: "refererOrigin",
Description: "请求报头中的Referer或Origin值",
HasParams: false,
Instance: new(RequestRefererOriginCheckpoint),
Priority: 100,
},
{
Name: "请求来源Referer",
Prefix: "referer",
Description: "请求Header中的Referer值",
HasParams: false,
@@ -226,6 +234,14 @@ var AllCheckpoints = []*CheckpointDefinition{
Instance: new(RequestHeadersCheckpoint),
Priority: 100,
},
{
Name: "所有请求报头名称",
Prefix: "headerNames",
Description: "使用换行符(\\n隔开的报头名称字符串每行一个名称",
HasParams: false,
Instance: new(RequestHeaderNamesCheckpoint),
Priority: 100,
},
{
Name: "单个Header值",
Prefix: "header",

View File

@@ -0,0 +1,32 @@
Copyright (c) 2012-2016, Nick Galbreath
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
https://github.com/client9/libinjection
http://opensource.org/licenses/BSD-3-Clause

View File

@@ -0,0 +1 @@
copy from https://github.com/libinjection/libinjection

View File

@@ -0,0 +1,65 @@
/**
* Copyright 2012-2016 Nick Galbreath
* nickg@client9.com
* BSD License -- see COPYING.txt for details
*
* https://libinjection.client9.com/
*
*/
#ifndef LIBINJECTION_H
#define LIBINJECTION_H
#ifdef __cplusplus
# define LIBINJECTION_BEGIN_DECLS extern "C" {
# define LIBINJECTION_END_DECLS }
#else
# define LIBINJECTION_BEGIN_DECLS
# define LIBINJECTION_END_DECLS
#endif
LIBINJECTION_BEGIN_DECLS
/*
* Pull in size_t
*/
#include <string.h>
/*
* Version info.
*
* This is moved into a function to allow SWIG and other auto-generated
* binding to not be modified during minor release changes. We change
* change the version number in the c source file, and not regenerated
* the binding
*
* See python's normalized version
* http://www.python.org/dev/peps/pep-0386/#normalizedversion
*/
const char* libinjection_version(void);
/**
* Simple API for SQLi detection - returns a SQLi fingerprint or NULL
* is benign input
*
* \param[in] s input string, may contain nulls, does not need to be null-terminated
* \param[in] slen input string length
* \param[out] fingerprint buffer of 8+ characters. c-string,
* \return 1 if SQLi, 0 if benign. fingerprint will be set or set to empty string.
*/
int libinjection_sqli(const char* s, size_t slen, char fingerprint[]);
/** ALPHA version of xss detector.
*
* NOT DONE.
*
* \param[in] s input string, may contain nulls, does not need to be null-terminated
* \param[in] slen input string length
* \return 1 if XSS found, 0 if benign
*
*/
int libinjection_xss(const char* s, size_t slen);
LIBINJECTION_END_DECLS
#endif /* LIBINJECTION_H */

View File

@@ -0,0 +1,868 @@
#include "libinjection_html5.h"
#include <string.h>
#include <assert.h>
#ifdef DEBUG
#include <stdio.h>
#define TRACE() printf("%s:%d\n", __FUNCTION__, __LINE__)
#else
#define TRACE()
#endif
#define CHAR_EOF -1
#define CHAR_NULL 0
#define CHAR_BANG 33
#define CHAR_DOUBLE 34
#define CHAR_PERCENT 37
#define CHAR_SINGLE 39
#define CHAR_DASH 45
#define CHAR_SLASH 47
#define CHAR_LT 60
#define CHAR_EQUALS 61
#define CHAR_GT 62
#define CHAR_QUESTION 63
#define CHAR_RIGHTB 93
#define CHAR_TICK 96
/* prototypes */
static int h5_skip_white(h5_state_t* hs);
static int h5_is_white(char ch);
static int h5_state_eof(h5_state_t* hs);
static int h5_state_data(h5_state_t* hs);
static int h5_state_tag_open(h5_state_t* hs);
static int h5_state_tag_name(h5_state_t* hs);
static int h5_state_tag_name_close(h5_state_t* hs);
static int h5_state_end_tag_open(h5_state_t* hs);
static int h5_state_self_closing_start_tag(h5_state_t* hs);
static int h5_state_attribute_name(h5_state_t* hs);
static int h5_state_after_attribute_name(h5_state_t* hs);
static int h5_state_before_attribute_name(h5_state_t* hs);
static int h5_state_before_attribute_value(h5_state_t* hs);
static int h5_state_attribute_value_double_quote(h5_state_t* hs);
static int h5_state_attribute_value_single_quote(h5_state_t* hs);
static int h5_state_attribute_value_back_quote(h5_state_t* hs);
static int h5_state_attribute_value_no_quote(h5_state_t* hs);
static int h5_state_after_attribute_value_quoted_state(h5_state_t* hs);
static int h5_state_comment(h5_state_t* hs);
static int h5_state_cdata(h5_state_t* hs);
/* 12.2.4.44 */
static int h5_state_bogus_comment(h5_state_t* hs);
static int h5_state_bogus_comment2(h5_state_t* hs);
/* 12.2.4.45 */
static int h5_state_markup_declaration_open(h5_state_t* hs);
/* 8.2.4.52 */
static int h5_state_doctype(h5_state_t* hs);
/**
* public function
*/
void libinjection_h5_init(h5_state_t* hs, const char* s, size_t len, enum html5_flags flags)
{
memset(hs, 0, sizeof(h5_state_t));
hs->s = s;
hs->len = len;
switch (flags) {
case DATA_STATE:
hs->state = h5_state_data;
break;
case VALUE_NO_QUOTE:
hs->state = h5_state_before_attribute_name;
break;
case VALUE_SINGLE_QUOTE:
hs->state = h5_state_attribute_value_single_quote;
break;
case VALUE_DOUBLE_QUOTE:
hs->state = h5_state_attribute_value_double_quote;
break;
case VALUE_BACK_QUOTE:
hs->state = h5_state_attribute_value_back_quote;
break;
}
}
/**
* public function
*/
int libinjection_h5_next(h5_state_t* hs)
{
assert(hs->state != NULL);
return (*hs->state)(hs);
}
/**
* Everything below here is private
*
*/
static int h5_is_white(char ch)
{
/*
* \t = horizontal tab = 0x09
* \n = newline = 0x0A
* \v = vertical tab = 0x0B
* \f = form feed = 0x0C
* \r = cr = 0x0D
*/
return strchr(" \t\n\v\f\r", ch) != NULL;
}
static int h5_skip_white(h5_state_t* hs)
{
char ch;
while (hs->pos < hs->len) {
ch = hs->s[hs->pos];
switch (ch) {
case 0x00: /* IE only */
case 0x20:
case 0x09:
case 0x0A:
case 0x0B: /* IE only */
case 0x0C:
case 0x0D: /* IE only */
hs->pos += 1;
break;
default:
return ch;
}
}
return CHAR_EOF;
}
static int h5_state_eof(h5_state_t* hs)
{
/* eliminate unused function argument warning */
(void)hs;
return 0;
}
static int h5_state_data(h5_state_t* hs)
{
const char* idx;
TRACE();
assert(hs->len >= hs->pos);
idx = (const char*) memchr(hs->s + hs->pos, CHAR_LT, hs->len - hs->pos);
if (idx == NULL) {
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = DATA_TEXT;
hs->state = h5_state_eof;
if (hs->token_len == 0) {
return 0;
}
} else {
hs->token_start = hs->s + hs->pos;
hs->token_type = DATA_TEXT;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->pos = (size_t)(idx - hs->s) + 1;
hs->state = h5_state_tag_open;
if (hs->token_len == 0) {
return h5_state_tag_open(hs);
}
}
return 1;
}
/**
* 12 2.4.8
*/
static int h5_state_tag_open(h5_state_t* hs)
{
char ch;
TRACE();
if (hs->pos >= hs->len) {
return 0;
}
ch = hs->s[hs->pos];
if (ch == CHAR_BANG) {
hs->pos += 1;
return h5_state_markup_declaration_open(hs);
} else if (ch == CHAR_SLASH) {
hs->pos += 1;
hs->is_close = 1;
return h5_state_end_tag_open(hs);
} else if (ch == CHAR_QUESTION) {
hs->pos += 1;
return h5_state_bogus_comment(hs);
} else if (ch == CHAR_PERCENT) {
/* this is not in spec.. alternative comment format used
by IE <= 9 and Safari < 4.0.3 */
hs->pos += 1;
return h5_state_bogus_comment2(hs);
} else if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')) {
return h5_state_tag_name(hs);
} else if (ch == CHAR_NULL) {
/* IE-ism NULL characters are ignored */
return h5_state_tag_name(hs);
} else {
/* user input mistake in configuring state */
if (hs->pos == 0) {
return h5_state_data(hs);
}
hs->token_start = hs->s + hs->pos - 1;
hs->token_len = 1;
hs->token_type = DATA_TEXT;
hs->state = h5_state_data;
return 1;
}
}
/**
* 12.2.4.9
*/
static int h5_state_end_tag_open(h5_state_t* hs)
{
char ch;
TRACE();
if (hs->pos >= hs->len) {
return 0;
}
ch = hs->s[hs->pos];
if (ch == CHAR_GT) {
return h5_state_data(hs);
} else if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')) {
return h5_state_tag_name(hs);
}
hs->is_close = 0;
return h5_state_bogus_comment(hs);
}
/*
*
*/
static int h5_state_tag_name_close(h5_state_t* hs)
{
TRACE();
hs->is_close = 0;
hs->token_start = hs->s + hs->pos;
hs->token_len = 1;
hs->token_type = TAG_NAME_CLOSE;
hs->pos += 1;
if (hs->pos < hs->len) {
hs->state = h5_state_data;
} else {
hs->state = h5_state_eof;
}
return 1;
}
/**
* 12.2.4.10
*/
static int h5_state_tag_name(h5_state_t* hs)
{
char ch;
size_t pos;
TRACE();
pos = hs->pos;
while (pos < hs->len) {
ch = hs->s[pos];
if (ch == 0) {
/* special non-standard case */
/* allow nulls in tag name */
/* some old browsers apparently allow and ignore them */
pos += 1;
} else if (h5_is_white(ch)) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->token_type = TAG_NAME_OPEN;
hs->pos = pos + 1;
hs->state = h5_state_before_attribute_name;
return 1;
} else if (ch == CHAR_SLASH) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->token_type = TAG_NAME_OPEN;
hs->pos = pos + 1;
hs->state = h5_state_self_closing_start_tag;
return 1;
} else if (ch == CHAR_GT) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
if (hs->is_close) {
hs->pos = pos + 1;
hs->is_close = 0;
hs->token_type = TAG_CLOSE;
hs->state = h5_state_data;
} else {
hs->pos = pos;
hs->token_type = TAG_NAME_OPEN;
hs->state = h5_state_tag_name_close;
}
return 1;
} else {
pos += 1;
}
}
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = TAG_NAME_OPEN;
hs->state = h5_state_eof;
return 1;
}
/**
* 12.2.4.34
*/
static int h5_state_before_attribute_name(h5_state_t* hs)
{
int ch;
TRACE();
/* for manual tail call optimization, see comment below */
tail_call:;
ch = h5_skip_white(hs);
switch (ch) {
case CHAR_EOF: {
return 0;
}
case CHAR_SLASH: {
hs->pos += 1;
/* Logically, We want to call h5_state_self_closing_start_tag(hs) here.
As this function may call us back and the compiler
might not implement automatic tail call optimization,
this might result in a deep recursion.
We detect this case here and start over with the current state.
*/
if (hs->pos < hs->len && hs->s[hs->pos] != CHAR_GT) {
goto tail_call;
}
return h5_state_self_closing_start_tag(hs);
}
case CHAR_GT: {
hs->state = h5_state_data;
hs->token_start = hs->s + hs->pos;
hs->token_len = 1;
hs->token_type = TAG_NAME_CLOSE;
hs->pos += 1;
return 1;
}
default: {
return h5_state_attribute_name(hs);
}
}
}
static int h5_state_attribute_name(h5_state_t* hs)
{
char ch;
size_t pos;
TRACE();
pos = hs->pos + 1;
while (pos < hs->len) {
ch = hs->s[pos];
if (h5_is_white(ch)) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->token_type = ATTR_NAME;
hs->state = h5_state_after_attribute_name;
hs->pos = pos + 1;
return 1;
} else if (ch == CHAR_SLASH) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->token_type = ATTR_NAME;
hs->state = h5_state_self_closing_start_tag;
hs->pos = pos + 1;
return 1;
} else if (ch == CHAR_EQUALS) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->token_type = ATTR_NAME;
hs->state = h5_state_before_attribute_value;
hs->pos = pos + 1;
return 1;
} else if (ch == CHAR_GT) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->token_type = ATTR_NAME;
hs->state = h5_state_tag_name_close;
hs->pos = pos;
return 1;
} else {
pos += 1;
}
}
/* EOF */
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = ATTR_NAME;
hs->state = h5_state_eof;
hs->pos = hs->len;
return 1;
}
/**
* 12.2.4.36
*/
static int h5_state_after_attribute_name(h5_state_t* hs)
{
int c;
TRACE();
c = h5_skip_white(hs);
switch (c) {
case CHAR_EOF: {
return 0;
}
case CHAR_SLASH: {
hs->pos += 1;
return h5_state_self_closing_start_tag(hs);
}
case CHAR_EQUALS: {
hs->pos += 1;
return h5_state_before_attribute_value(hs);
}
case CHAR_GT: {
return h5_state_tag_name_close(hs);
}
default: {
return h5_state_attribute_name(hs);
}
}
}
/**
* 12.2.4.37
*/
static int h5_state_before_attribute_value(h5_state_t* hs)
{
int c;
TRACE();
c = h5_skip_white(hs);
if (c == CHAR_EOF) {
hs->state = h5_state_eof;
return 0;
}
if (c == CHAR_DOUBLE) {
return h5_state_attribute_value_double_quote(hs);
} else if (c == CHAR_SINGLE) {
return h5_state_attribute_value_single_quote(hs);
} else if (c == CHAR_TICK) {
/* NON STANDARD IE */
return h5_state_attribute_value_back_quote(hs);
} else {
return h5_state_attribute_value_no_quote(hs);
}
}
static int h5_state_attribute_value_quote(h5_state_t* hs, char qchar)
{
const char* idx;
TRACE();
/* skip initial quote in normal case.
* don't do this "if (pos == 0)" since it means we have started
* in a non-data state. given an input of '><foo
* we want to make 0-length attribute name
*/
if (hs->pos > 0) {
hs->pos += 1;
}
idx = (const char*) memchr(hs->s + hs->pos, qchar, hs->len - hs->pos);
if (idx == NULL) {
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = ATTR_VALUE;
hs->state = h5_state_eof;
} else {
hs->token_start = hs->s + hs->pos;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->token_type = ATTR_VALUE;
hs->state = h5_state_after_attribute_value_quoted_state;
hs->pos += hs->token_len + 1;
}
return 1;
}
static
int h5_state_attribute_value_double_quote(h5_state_t* hs)
{
TRACE();
return h5_state_attribute_value_quote(hs, CHAR_DOUBLE);
}
static
int h5_state_attribute_value_single_quote(h5_state_t* hs)
{
TRACE();
return h5_state_attribute_value_quote(hs, CHAR_SINGLE);
}
static
int h5_state_attribute_value_back_quote(h5_state_t* hs)
{
TRACE();
return h5_state_attribute_value_quote(hs, CHAR_TICK);
}
static int h5_state_attribute_value_no_quote(h5_state_t* hs)
{
char ch;
size_t pos;
TRACE();
pos = hs->pos;
while (pos < hs->len) {
ch = hs->s[pos];
if (h5_is_white(ch)) {
hs->token_type = ATTR_VALUE;
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->pos = pos + 1;
hs->state = h5_state_before_attribute_name;
return 1;
} else if (ch == CHAR_GT) {
hs->token_type = ATTR_VALUE;
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->pos = pos;
hs->state = h5_state_tag_name_close;
return 1;
}
pos += 1;
}
TRACE();
/* EOF */
hs->state = h5_state_eof;
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = ATTR_VALUE;
return 1;
}
/**
* 12.2.4.41
*/
static int h5_state_after_attribute_value_quoted_state(h5_state_t* hs)
{
char ch;
TRACE();
if (hs->pos >= hs->len) {
return 0;
}
ch = hs->s[hs->pos];
if (h5_is_white(ch)) {
hs->pos += 1;
return h5_state_before_attribute_name(hs);
} else if (ch == CHAR_SLASH) {
hs->pos += 1;
return h5_state_self_closing_start_tag(hs);
} else if (ch == CHAR_GT) {
hs->token_start = hs->s + hs->pos;
hs->token_len = 1;
hs->token_type = TAG_NAME_CLOSE;
hs->pos += 1;
hs->state = h5_state_data;
return 1;
} else {
return h5_state_before_attribute_name(hs);
}
}
/**
* 12.2.4.43
*
* WARNING: This function is partially inlined into h5_state_before_attribute_name()
*/
static int h5_state_self_closing_start_tag(h5_state_t* hs)
{
char ch;
TRACE();
if (hs->pos >= hs->len) {
return 0;
}
ch = hs->s[hs->pos];
if (ch == CHAR_GT) {
assert(hs->pos > 0);
hs->token_start = hs->s + hs->pos -1;
hs->token_len = 2;
hs->token_type = TAG_NAME_SELFCLOSE;
hs->state = h5_state_data;
hs->pos += 1;
return 1;
} else {
return h5_state_before_attribute_name(hs);
}
}
/**
* 12.2.4.44
*/
static int h5_state_bogus_comment(h5_state_t* hs)
{
const char* idx;
TRACE();
idx = (const char*) memchr(hs->s + hs->pos, CHAR_GT, hs->len - hs->pos);
if (idx == NULL) {
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->pos = hs->len;
hs->state = h5_state_eof;
} else {
hs->token_start = hs->s + hs->pos;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->pos = (size_t)(idx - hs->s) + 1;
hs->state = h5_state_data;
}
hs->token_type = TAG_COMMENT;
return 1;
}
/**
* 12.2.4.44 ALT
*/
static int h5_state_bogus_comment2(h5_state_t* hs)
{
const char* idx;
size_t pos;
TRACE();
pos = hs->pos;
while (1) {
idx = (const char*) memchr(hs->s + pos, CHAR_PERCENT, hs->len - pos);
if (idx == NULL || (idx + 1 >= hs->s + hs->len)) {
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->pos = hs->len;
hs->token_type = TAG_COMMENT;
hs->state = h5_state_eof;
return 1;
}
if (*(idx +1) != CHAR_GT) {
pos = (size_t)(idx - hs->s) + 1;
continue;
}
/* ends in %> */
hs->token_start = hs->s + hs->pos;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->pos = (size_t)(idx - hs->s) + 2;
hs->state = h5_state_data;
hs->token_type = TAG_COMMENT;
return 1;
}
}
/**
* 8.2.4.45
*/
static int h5_state_markup_declaration_open(h5_state_t* hs)
{
size_t remaining;
TRACE();
remaining = hs->len - hs->pos;
if (remaining >= 7 &&
/* case insensitive */
(hs->s[hs->pos + 0] == 'D' || hs->s[hs->pos + 0] == 'd') &&
(hs->s[hs->pos + 1] == 'O' || hs->s[hs->pos + 1] == 'o') &&
(hs->s[hs->pos + 2] == 'C' || hs->s[hs->pos + 2] == 'c') &&
(hs->s[hs->pos + 3] == 'T' || hs->s[hs->pos + 3] == 't') &&
(hs->s[hs->pos + 4] == 'Y' || hs->s[hs->pos + 4] == 'y') &&
(hs->s[hs->pos + 5] == 'P' || hs->s[hs->pos + 5] == 'p') &&
(hs->s[hs->pos + 6] == 'E' || hs->s[hs->pos + 6] == 'e')
) {
return h5_state_doctype(hs);
} else if (remaining >= 7 &&
/* upper case required */
hs->s[hs->pos + 0] == '[' &&
hs->s[hs->pos + 1] == 'C' &&
hs->s[hs->pos + 2] == 'D' &&
hs->s[hs->pos + 3] == 'A' &&
hs->s[hs->pos + 4] == 'T' &&
hs->s[hs->pos + 5] == 'A' &&
hs->s[hs->pos + 6] == '['
) {
hs->pos += 7;
return h5_state_cdata(hs);
} else if (remaining >= 2 &&
hs->s[hs->pos + 0] == '-' &&
hs->s[hs->pos + 1] == '-') {
hs->pos += 2;
return h5_state_comment(hs);
}
return h5_state_bogus_comment(hs);
}
/**
* 12.2.4.48
* 12.2.4.49
* 12.2.4.50
* 12.2.4.51
* state machine spec is confusing since it can only look
* at one character at a time but simply it's comments end by:
* 1) EOF
* 2) ending in -->
* 3) ending in -!>
*/
static int h5_state_comment(h5_state_t* hs)
{
char ch;
const char* idx;
size_t pos;
size_t offset;
const char* end = hs->s + hs->len;
TRACE();
pos = hs->pos;
while (1) {
idx = (const char*) memchr(hs->s + pos, CHAR_DASH, hs->len - pos);
/* did not find anything or has less than 3 chars left */
if (idx == NULL || idx > hs->s + hs->len - 3) {
hs->state = h5_state_eof;
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = TAG_COMMENT;
return 1;
}
offset = 1;
/* skip all nulls */
while (idx + offset < end && *(idx + offset) == 0) {
offset += 1;
}
if (idx + offset == end) {
hs->state = h5_state_eof;
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = TAG_COMMENT;
return 1;
}
ch = *(idx + offset);
if (ch != CHAR_DASH && ch != CHAR_BANG) {
pos = (size_t)(idx - hs->s) + 1;
continue;
}
/* need to test */
#if 0
/* skip all nulls */
while (idx + offset < end && *(idx + offset) == 0) {
offset += 1;
}
if (idx + offset == end) {
hs->state = h5_state_eof;
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = TAG_COMMENT;
return 1;
}
#endif
offset += 1;
if (idx + offset == end) {
hs->state = h5_state_eof;
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = TAG_COMMENT;
return 1;
}
ch = *(idx + offset);
if (ch != CHAR_GT) {
pos = (size_t)(idx - hs->s) + 1;
continue;
}
offset += 1;
/* ends in --> or -!> */
hs->token_start = hs->s + hs->pos;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->pos = (size_t)(idx + offset - hs->s);
hs->state = h5_state_data;
hs->token_type = TAG_COMMENT;
return 1;
}
}
static int h5_state_cdata(h5_state_t* hs)
{
const char* idx;
size_t pos;
TRACE();
pos = hs->pos;
while (1) {
idx = (const char*) memchr(hs->s + pos, CHAR_RIGHTB, hs->len - pos);
/* did not find anything or has less than 3 chars left */
if (idx == NULL || idx > hs->s + hs->len - 3) {
hs->state = h5_state_eof;
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = DATA_TEXT;
return 1;
} else if ( *(idx+1) == CHAR_RIGHTB && *(idx+2) == CHAR_GT) {
hs->state = h5_state_data;
hs->token_start = hs->s + hs->pos;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->pos = (size_t)(idx - hs->s) + 3;
hs->token_type = DATA_TEXT;
return 1;
} else {
pos = (size_t)(idx - hs->s) + 1;
}
}
}
/**
* 8.2.4.52
* http://www.w3.org/html/wg/drafts/html/master/syntax.html#doctype-state
*/
static int h5_state_doctype(h5_state_t* hs)
{
const char* idx;
TRACE();
hs->token_start = hs->s + hs->pos;
hs->token_type = DOCTYPE;
idx = (const char*) memchr(hs->s + hs->pos, CHAR_GT, hs->len - hs->pos);
if (idx == NULL) {
hs->state = h5_state_eof;
hs->token_len = hs->len - hs->pos;
} else {
hs->state = h5_state_data;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->pos = (size_t)(idx - hs->s) + 1;
}
return 1;
}

View File

@@ -0,0 +1,54 @@
#ifndef LIBINJECTION_HTML5
#define LIBINJECTION_HTML5
#ifdef __cplusplus
extern "C" {
#endif
/* pull in size_t */
#include <stddef.h>
enum html5_type {
DATA_TEXT
, TAG_NAME_OPEN
, TAG_NAME_CLOSE
, TAG_NAME_SELFCLOSE
, TAG_DATA
, TAG_CLOSE
, ATTR_NAME
, ATTR_VALUE
, TAG_COMMENT
, DOCTYPE
};
enum html5_flags {
DATA_STATE
, VALUE_NO_QUOTE
, VALUE_SINGLE_QUOTE
, VALUE_DOUBLE_QUOTE
, VALUE_BACK_QUOTE
};
struct h5_state;
typedef int (*ptr_html5_state)(struct h5_state*);
typedef struct h5_state {
const char* s;
size_t len;
size_t pos;
int is_close;
ptr_html5_state state;
const char* token_start;
size_t token_len;
enum html5_type token_type;
} h5_state_t;
void libinjection_h5_init(h5_state_t* hs, const char* s, size_t len, enum html5_flags);
int libinjection_h5_next(h5_state_t* hs);
#ifdef __cplusplus
}
#endif
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,294 @@
/**
* Copyright 2012-2016 Nick Galbreath
* nickg@client9.com
* BSD License -- see `COPYING.txt` for details
*
* https://libinjection.client9.com/
*
*/
#ifndef LIBINJECTION_SQLI_H
#define LIBINJECTION_SQLI_H
#ifdef __cplusplus
extern "C" {
#endif
/*
* Pull in size_t
*/
#include <string.h>
enum sqli_flags {
FLAG_NONE = 0
, FLAG_QUOTE_NONE = 1 /* 1 << 0 */
, FLAG_QUOTE_SINGLE = 2 /* 1 << 1 */
, FLAG_QUOTE_DOUBLE = 4 /* 1 << 2 */
, FLAG_SQL_ANSI = 8 /* 1 << 3 */
, FLAG_SQL_MYSQL = 16 /* 1 << 4 */
};
enum lookup_type {
LOOKUP_WORD = 1
, LOOKUP_TYPE = 2
, LOOKUP_OPERATOR = 3
, LOOKUP_FINGERPRINT = 4
};
struct libinjection_sqli_token {
#ifdef SWIG
%immutable;
#endif
/*
* position and length of token
* in original string
*/
size_t pos;
size_t len;
/* count:
* in type 'v', used for number of opening '@'
* but maybe used in other contexts
*/
int count;
char type;
char str_open;
char str_close;
char val[32];
};
typedef struct libinjection_sqli_token stoken_t;
/**
* Pointer to function, takes c-string input,
* returns '\0' for no match, else a char
*/
struct libinjection_sqli_state;
typedef char (*ptr_lookup_fn)(struct libinjection_sqli_state*, int lookuptype, const char* word, size_t len);
struct libinjection_sqli_state {
#ifdef SWIG
%immutable;
#endif
/*
* input, does not need to be null terminated.
* it is also not modified.
*/
const char *s;
/*
* input length
*/
size_t slen;
/*
* How to lookup a word or fingerprint
*/
ptr_lookup_fn lookup;
void* userdata;
/*
*
*/
int flags;
/*
* pos is the index in the string during tokenization
*/
size_t pos;
#ifndef SWIG
/* for SWIG.. don't use this.. use functional API instead */
/* MAX TOKENS + 1 since we use one extra token
* to determine the type of the previous token
*/
struct libinjection_sqli_token tokenvec[8];
#endif
/*
* Pointer to token position in tokenvec, above
*/
struct libinjection_sqli_token *current;
/*
* fingerprint pattern c-string
* +1 for ending null
* Minimum of 8 bytes to add gcc's -fstack-protector to work
*/
char fingerprint[8];
/*
* Line number of code that said decided if the input was SQLi or
* not. Most of the time it's line that said "it's not a matching
* fingerprint" but there is other logic that sometimes approves
* an input. This is only useful for debugging.
*
*/
int reason;
/* Number of ddw (dash-dash-white) comments
* These comments are in the form of
* '--[whitespace]' or '--[EOF]'
*
* All databases treat this as a comment.
*/
int stats_comment_ddw;
/* Number of ddx (dash-dash-[notwhite]) comments
*
* ANSI SQL treats these are comments, MySQL treats this as
* two unary operators '-' '-'
*
* If you are parsing result returns FALSE and
* stats_comment_dd > 0, you should reparse with
* COMMENT_MYSQL
*
*/
int stats_comment_ddx;
/*
* c-style comments found /x .. x/
*/
int stats_comment_c;
/* '#' operators or MySQL EOL comments found
*
*/
int stats_comment_hash;
/*
* number of tokens folded away
*/
int stats_folds;
/*
* total tokens processed
*/
int stats_tokens;
};
typedef struct libinjection_sqli_state sfilter;
struct libinjection_sqli_token* libinjection_sqli_get_token(
struct libinjection_sqli_state* sql_state, int i);
/*
* Version info.
*
* This is moved into a function to allow SWIG and other auto-generated
* binding to not be modified during minor release changes. We change
* change the version number in the c source file, and not regenerated
* the binding
*
* See python's normalized version
* http://www.python.org/dev/peps/pep-0386/#normalizedversion
*/
const char* libinjection_version(void);
/**
*
*/
void libinjection_sqli_init(struct libinjection_sqli_state *sf,
const char* s, size_t len,
int flags);
/**
* Main API: tests for SQLi in three possible contexts, no quotes,
* single quote and double quote
*
* \param sql_state core data structure
*
* \return 1 (true) if SQLi, 0 (false) if benign
*/
int libinjection_is_sqli(struct libinjection_sqli_state* sql_state);
/* FOR HACKERS ONLY
* provides deep hooks into the decision making process
*/
void libinjection_sqli_callback(struct libinjection_sqli_state *sf,
ptr_lookup_fn fn,
void* userdata);
/*
* Resets state, but keeps initial string and callbacks
*/
void libinjection_sqli_reset(struct libinjection_sqli_state *sf,
int flags);
/**
*
*/
/**
* This detects SQLi in a single context, mostly useful for custom
* logic and debugging.
*
* \param sql_state Main data structure
* \param flags flags to adjust parsing
*
* \returns a pointer to sfilter.fingerprint as convenience
* do not free!
*
*/
const char* libinjection_sqli_fingerprint(struct libinjection_sqli_state *sql_state,
int flags);
/**
* The default "word" to token-type or fingerprint function. This
* uses a ASCII case-insensitive binary tree.
*/
char libinjection_sqli_lookup_word(struct libinjection_sqli_state *sql_state,
int lookup_type,
const char* str,
size_t len);
/* Streaming tokenization interface.
*
* sql_state->current is updated with the current token.
*
* \returns 1, has a token, keep going, or 0 no tokens
*
*/
int libinjection_sqli_tokenize(struct libinjection_sqli_state *sf);
/**
* parses and folds input, up to 5 tokens
*
*/
int libinjection_sqli_fold(struct libinjection_sqli_state *sf);
/** The built-in default function to match fingerprints
* and do false negative/positive analysis. This calls the following
* two functions. With this, you over-ride one part or the other.
*
* return libinjection_sqli_blacklist(sql_state) &&
* libinjection_sqli_not_whitelist(sql_state);
*
* \param sql_state should be filled out after libinjection_sqli_fingerprint is called
*/
int libinjection_sqli_check_fingerprint(struct libinjection_sqli_state * sql_state);
/* Given a pattern determine if it's a SQLi pattern.
*
* \return TRUE if sqli, false otherwise
*/
int libinjection_sqli_blacklist(struct libinjection_sqli_state* sql_state);
/* Given a positive match for a pattern (i.e. pattern is SQLi), this function
* does additional analysis to reduce false positives.
*
* \return TRUE if SQLi, false otherwise
*/
int libinjection_sqli_not_whitelist(struct libinjection_sqli_state * sql_state);
#ifdef __cplusplus
}
#endif
#endif /* LIBINJECTION_SQLI_H */

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,857 @@
#include "libinjection.h"
#include "libinjection_xss.h"
#include "libinjection_html5.h"
#include <assert.h>
#include <stdio.h>
typedef enum attribute {
TYPE_NONE
, TYPE_BLACK /* ban always */
, TYPE_ATTR_URL /* attribute value takes a URL-like object */
, TYPE_STYLE
, TYPE_ATTR_INDIRECT /* attribute *name* is given in *value* */
} attribute_t;
static attribute_t is_black_attr(const char* s, size_t len);
static int is_black_tag(const char* s, size_t len);
static int is_black_url(const char* s, size_t len);
static int cstrcasecmp_with_null(const char *a, const char *b, size_t n);
static int html_decode_char_at(const char* src, size_t len, size_t* consumed);
static int htmlencode_startswith(const char *a/* prefix */, const char *b /* src */, size_t n);
typedef struct stringtype {
const char* name;
attribute_t atype;
} stringtype_t;
static const int gsHexDecodeMap[256] = {
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 256, 256,
256, 256, 256, 256, 256, 10, 11, 12, 13, 14, 15, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 10, 11, 12, 13, 14, 15, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256,
256, 256, 256, 256
};
static int html_decode_char_at(const char* src, size_t len, size_t* consumed)
{
int val = 0;
size_t i;
int ch;
if (len == 0 || src == NULL) {
*consumed = 0;
return -1;
}
*consumed = 1;
if (*src != '&' || len < 2) {
return (unsigned char)(*src);
}
if (*(src+1) != '#') {
/* normally this would be for named entities
* but for this case we don't actually care
*/
return '&';
}
if (*(src+2) == 'x' || *(src+2) == 'X') {
ch = (unsigned char) (*(src+3));
ch = gsHexDecodeMap[ch];
if (ch == 256) {
/* degenerate case '&#[?]' */
return '&';
}
val = ch;
i = 4;
while (i < len) {
ch = (unsigned char) src[i];
if (ch == ';') {
*consumed = i + 1;
return val;
}
ch = gsHexDecodeMap[ch];
if (ch == 256) {
*consumed = i;
return val;
}
val = (val * 16) + ch;
if (val > 0x1000FF) {
return '&';
}
++i;
}
*consumed = i;
return val;
} else {
i = 2;
ch = (unsigned char) src[i];
if (ch < '0' || ch > '9') {
return '&';
}
val = ch - '0';
i += 1;
while (i < len) {
ch = (unsigned char) src[i];
if (ch == ';') {
*consumed = i + 1;
return val;
}
if (ch < '0' || ch > '9') {
*consumed = i;
return val;
}
val = (val * 10) + (ch - '0');
if (val > 0x1000FF) {
return '&';
}
++i;
}
*consumed = i;
return val;
}
}
/*
* These were mostly extracted from: https://raw.githubusercontent.com/WebKit/WebKit/main/Source/WebCore/dom/EventNames.h
*
* view-source:
* data:
* javascript:
* events:
*/
static stringtype_t BLACKATTREVENT[] = {
{ "ABORT", TYPE_BLACK }
, { "ACTIVATE", TYPE_BLACK }
, { "ACTIVE", TYPE_BLACK }
, { "ADDSOURCEBUFFER", TYPE_BLACK }
, { "ADDSTREAM", TYPE_BLACK }
, { "ADDTRACK", TYPE_BLACK }
, { "AFTERPRINT", TYPE_BLACK }
, { "ANIMATIONCANCEL", TYPE_BLACK }
, { "ANIMATIONEND", TYPE_BLACK }
, { "ANIMATIONITERATION", TYPE_BLACK }
, { "ANIMATIONSTART", TYPE_BLACK }
, { "AUDIOEND", TYPE_BLACK }
, { "AUDIOPROCESS", TYPE_BLACK }
, { "AUDIOSTART", TYPE_BLACK }
, { "AUTOCOMPLETEERROR", TYPE_BLACK }
, { "AUTOCOMPLETE", TYPE_BLACK }
, { "BEFOREACTIVATE", TYPE_BLACK }
, { "BEFORECOPY", TYPE_BLACK }
, { "BEFORECUT", TYPE_BLACK }
, { "BEFOREINPUT", TYPE_BLACK }
, { "BEFORELOAD", TYPE_BLACK }
, { "BEFOREPASTE", TYPE_BLACK }
, { "BEFOREPRINT", TYPE_BLACK }
, { "BEFOREUNLOAD", TYPE_BLACK }
, { "BEGINEVENT", TYPE_BLACK }
, { "BLOCKED", TYPE_BLACK }
, { "BLUR", TYPE_BLACK }
, { "BOUNDARY", TYPE_BLACK }
, { "BUFFEREDAMOUNTLOW", TYPE_BLACK }
, { "CACHED", TYPE_BLACK }
, { "CANCEL", TYPE_BLACK }
, { "CANPLAYTHROUGH", TYPE_BLACK }
, { "CANPLAY", TYPE_BLACK }
, { "CHANGE", TYPE_BLACK }
, { "CHARGINGCHANGE", TYPE_BLACK }
, { "CHARGINGTIMECHANGE", TYPE_BLACK }
, { "CHECKING", TYPE_BLACK }
, { "CLICK", TYPE_BLACK }
, { "CLOSE", TYPE_BLACK }
, { "COMPLETE", TYPE_BLACK }
, { "COMPOSITIONEND", TYPE_BLACK }
, { "COMPOSITIONSTART", TYPE_BLACK }
, { "COMPOSITIONUPDATE", TYPE_BLACK }
, { "CONNECTING", TYPE_BLACK }
, { "CONNECTIONSTATECHANGE", TYPE_BLACK }
, { "CONNECT", TYPE_BLACK }
, { "CONTEXTMENU", TYPE_BLACK }
, { "CONTROLLERCHANGE", TYPE_BLACK }
, { "COPY", TYPE_BLACK }
, { "CUECHANGE", TYPE_BLACK }
, { "CUT", TYPE_BLACK }
, { "DATAAVAILABLE", TYPE_BLACK }
, { "DATACHANNEL", TYPE_BLACK }
, { "DBLCLICK", TYPE_BLACK }
, { "DEVICECHANGE", TYPE_BLACK }
, { "DEVICEMOTION", TYPE_BLACK }
, { "DEVICEORIENTATION", TYPE_BLACK }
, { "DISCHARGINGTIMECHANGE", TYPE_BLACK }
, { "DISCONNECT", TYPE_BLACK }
, { "DOMACTIVATE", TYPE_BLACK }
, { "DOMCHARACTERDATAMODIFIED", TYPE_BLACK }
, { "DOMCONTENTLOADED", TYPE_BLACK }
, { "DOMFOCUSIN", TYPE_BLACK }
, { "DOMFOCUSOUT", TYPE_BLACK }
, { "DOMNODEINSERTEDINTODOCUMENT", TYPE_BLACK }
, { "DOMNODEINSERTED", TYPE_BLACK }
, { "DOMNODEREMOVEDFROMDOCUMENT", TYPE_BLACK }
, { "DOMNODEREMOVED", TYPE_BLACK }
, { "DOMSUBTREEMODIFIED", TYPE_BLACK }
, { "DOWNLOADING", TYPE_BLACK }
, { "DRAGEND", TYPE_BLACK }
, { "DRAGENTER", TYPE_BLACK }
, { "DRAGLEAVE", TYPE_BLACK }
, { "DRAGOVER", TYPE_BLACK }
, { "DRAGSTART", TYPE_BLACK }
, { "DRAG", TYPE_BLACK }
, { "DROP", TYPE_BLACK }
, { "DURATIONCHANGE", TYPE_BLACK }
, { "EMPTIED", TYPE_BLACK }
, { "ENCRYPTED", TYPE_BLACK }
, { "ENDED", TYPE_BLACK }
, { "ENDEVENT", TYPE_BLACK }
, { "END", TYPE_BLACK }
, { "ENTERPICTUREINPICTURE", TYPE_BLACK }
, { "ENTER", TYPE_BLACK }
, { "ERROR", TYPE_BLACK }
, { "EXIT", TYPE_BLACK }
, { "FETCH", TYPE_BLACK }
, { "FINISH", TYPE_BLACK }
, { "FOCUSIN", TYPE_BLACK }
, { "FOCUSOUT", TYPE_BLACK }
, { "FOCUS", TYPE_BLACK }
, { "FORMCHANGE", TYPE_BLACK }
, { "FORMINPUT", TYPE_BLACK }
, { "GAMEPADCONNECTED", TYPE_BLACK }
, { "GAMEPADDISCONNECTED", TYPE_BLACK }
, { "GESTURECHANGE", TYPE_BLACK }
, { "GESTUREEND", TYPE_BLACK }
, { "GESTURESCROLLEND", TYPE_BLACK }
, { "GESTURESCROLLSTART", TYPE_BLACK }
, { "GESTURESCROLLUPDATE", TYPE_BLACK }
, { "GESTURESTART", TYPE_BLACK }
, { "GESTURETAPDOWN", TYPE_BLACK }
, { "GESTURETAP", TYPE_BLACK }
, { "GOTPOINTERCAPTURE", TYPE_BLACK }
, { "HASHCHANGE", TYPE_BLACK }
, { "ICECANDIDATEERROR", TYPE_BLACK }
, { "ICECANDIDATE", TYPE_BLACK }
, { "ICECONNECTIONSTATECHANGE", TYPE_BLACK }
, { "ICEGATHERINGSTATECHANGE", TYPE_BLACK }
, { "INACTIVE", TYPE_BLACK }
, { "INPUTSOURCESCHANGE", TYPE_BLACK }
, { "INPUT", TYPE_BLACK }
, { "INSTALL", TYPE_BLACK }
, { "INVALID", TYPE_BLACK }
, { "KEYDOWN", TYPE_BLACK }
, { "KEYPRESS", TYPE_BLACK }
, { "KEYSTATUSESCHANGE", TYPE_BLACK }
, { "KEYUP", TYPE_BLACK }
, { "LANGUAGECHANGE", TYPE_BLACK }
, { "LEAVEPICTUREINPICTURE", TYPE_BLACK }
, { "LEVELCHANGE", TYPE_BLACK }
, { "LOADEDDATA", TYPE_BLACK }
, { "LOADEDMETADATA", TYPE_BLACK }
, { "LOADEND", TYPE_BLACK }
, { "LOADINGDONE", TYPE_BLACK }
, { "LOADINGERROR", TYPE_BLACK }
, { "LOADING", TYPE_BLACK }
, { "LOADSTART", TYPE_BLACK }
, { "LOAD", TYPE_BLACK }
, { "LOSTPOINTERCAPTURE", TYPE_BLACK }
, { "MARK", TYPE_BLACK }
, { "MERCHANTVALIDATION", TYPE_BLACK }
, { "MESSAGEERROR", TYPE_BLACK }
, { "MESSAGE", TYPE_BLACK }
, { "MOUSEDOWN", TYPE_BLACK }
, { "MOUSEENTER", TYPE_BLACK }
, { "MOUSELEAVE", TYPE_BLACK }
, { "MOUSEMOVE", TYPE_BLACK }
, { "MOUSEOUT", TYPE_BLACK }
, { "MOUSEOVER", TYPE_BLACK }
, { "MOUSEUP", TYPE_BLACK }
, { "MOUSEWHEEL", TYPE_BLACK }
, { "MUTE", TYPE_BLACK }
, { "NEGOTIATIONNEEDED", TYPE_BLACK }
, { "NEXTTRACK", TYPE_BLACK }
, { "NOMATCH", TYPE_BLACK }
, { "NOUPDATE", TYPE_BLACK }
, { "OBSOLETE", TYPE_BLACK }
, { "OFFLINE", TYPE_BLACK }
, { "ONLINE", TYPE_BLACK }
, { "OPEN", TYPE_BLACK }
, { "ORIENTATIONCHANGE", TYPE_BLACK }
, { "OVERCONSTRAINED", TYPE_BLACK }
, { "OVERFLOWCHANGED", TYPE_BLACK }
, { "PAGEHIDE", TYPE_BLACK }
, { "PAGESHOW", TYPE_BLACK }
, { "PASTE", TYPE_BLACK }
, { "PAUSE", TYPE_BLACK }
, { "PAYERDETAILCHANGE", TYPE_BLACK }
, { "PAYMENTAUTHORIZED", TYPE_BLACK }
, { "PAYMENTMETHODCHANGE", TYPE_BLACK }
, { "PAYMENTMETHODSELECTED", TYPE_BLACK }
, { "PLAYING", TYPE_BLACK }
, { "PLAY", TYPE_BLACK }
, { "POINTERCANCEL", TYPE_BLACK }
, { "POINTERDOWN", TYPE_BLACK }
, { "POINTERENTER", TYPE_BLACK }
, { "POINTERLEAVE", TYPE_BLACK }
, { "POINTERLOCKCHANGE", TYPE_BLACK }
, { "POINTERLOCKERROR", TYPE_BLACK }
, { "POINTERMOVE", TYPE_BLACK }
, { "POINTEROUT", TYPE_BLACK }
, { "POINTEROVER", TYPE_BLACK }
, { "POINTERUP", TYPE_BLACK }
, { "POPSTATE", TYPE_BLACK }
, { "PREVIOUSTRACK", TYPE_BLACK }
, { "PROCESSORERROR", TYPE_BLACK }
, { "PROGRESS", TYPE_BLACK }
, { "PROPERTYCHANGE", TYPE_BLACK }
, { "RATECHANGE", TYPE_BLACK }
, { "READYSTATECHANGE", TYPE_BLACK }
, { "REJECTIONHANDLED", TYPE_BLACK }
, { "REMOVESOURCEBUFFER", TYPE_BLACK }
, { "REMOVESTREAM", TYPE_BLACK }
, { "REMOVETRACK", TYPE_BLACK }
, { "REMOVE", TYPE_BLACK }
, { "RESET", TYPE_BLACK }
, { "RESIZE", TYPE_BLACK }
, { "RESOURCETIMINGBUFFERFULL", TYPE_BLACK }
, { "RESULT", TYPE_BLACK }
, { "RESUME", TYPE_BLACK }
, { "SCROLL", TYPE_BLACK }
, { "SEARCH", TYPE_BLACK }
, { "SECURITYPOLICYVIOLATION", TYPE_BLACK }
, { "SEEKED", TYPE_BLACK }
, { "SEEKING", TYPE_BLACK }
, { "SELECTEND", TYPE_BLACK }
, { "SELECTIONCHANGE", TYPE_BLACK }
, { "SELECTSTART", TYPE_BLACK }
, { "SELECT", TYPE_BLACK }
, { "SHIPPINGADDRESSCHANGE", TYPE_BLACK }
, { "SHIPPINGCONTACTSELECTED", TYPE_BLACK }
, { "SHIPPINGMETHODSELECTED", TYPE_BLACK }
, { "SHIPPINGOPTIONCHANGE", TYPE_BLACK }
, { "SHOW", TYPE_BLACK }
, { "SIGNALINGSTATECHANGE", TYPE_BLACK }
, { "SLOTCHANGE", TYPE_BLACK }
, { "SOUNDEND", TYPE_BLACK }
, { "SOUNDSTART", TYPE_BLACK }
, { "SOURCECLOSE", TYPE_BLACK }
, { "SOURCEENDED", TYPE_BLACK }
, { "SOURCEOPEN", TYPE_BLACK }
, { "SPEECHEND", TYPE_BLACK }
, { "SPEECHSTART", TYPE_BLACK }
, { "SQUEEZEEND", TYPE_BLACK }
, { "SQUEEZESTART", TYPE_BLACK }
, { "SQUEEZE", TYPE_BLACK }
, { "STALLED", TYPE_BLACK }
, { "STARTED", TYPE_BLACK }
, { "START", TYPE_BLACK }
, { "STATECHANGE", TYPE_BLACK }
, { "STOP", TYPE_BLACK }
, { "STORAGE", TYPE_BLACK }
, { "SUBMIT", TYPE_BLACK }
, { "SUCCESS", TYPE_BLACK }
, { "SUSPEND", TYPE_BLACK }
, { "TEXTINPUT", TYPE_BLACK }
, { "TIMEOUT", TYPE_BLACK }
, { "TIMEUPDATE", TYPE_BLACK }
, { "TOGGLE", TYPE_BLACK }
, { "TOGGLE", TYPE_BLACK }
, { "TONECHANGE", TYPE_BLACK }
, { "TOUCHCANCEL", TYPE_BLACK }
, { "TOUCHEND", TYPE_BLACK }
, { "TOUCHFORCECHANGE", TYPE_BLACK }
, { "TOUCHMOVE", TYPE_BLACK }
, { "TOUCHSTART", TYPE_BLACK }
, { "TRACK", TYPE_BLACK }
, { "TRANSITIONCANCEL", TYPE_BLACK }
, { "TRANSITIONEND", TYPE_BLACK }
, { "TRANSITIONRUN", TYPE_BLACK }
, { "TRANSITIONSTART", TYPE_BLACK }
, { "UNCAPTUREDERROR", TYPE_BLACK }
, { "UNHANDLEDREJECTION", TYPE_BLACK }
, { "UNLOAD", TYPE_BLACK }
, { "UNMUTE", TYPE_BLACK }
, { "UPDATEEND", TYPE_BLACK }
, { "UPDATEFOUND", TYPE_BLACK }
, { "UPDATEREADY", TYPE_BLACK }
, { "UPDATESTART", TYPE_BLACK }
, { "UPDATE", TYPE_BLACK }
, { "UPGRADENEEDED", TYPE_BLACK }
, { "VALIDATEMERCHANT", TYPE_BLACK }
, { "VERSIONCHANGE", TYPE_BLACK }
, { "VISIBILITYCHANGE", TYPE_BLACK }
, { "VOLUMECHANGE", TYPE_BLACK }
, { "WAITINGFORKEY", TYPE_BLACK }
, { "WAITING", TYPE_BLACK }
, { "WEBGLCONTEXTCHANGED", TYPE_BLACK }
, { "WEBGLCONTEXTCREATIONERROR", TYPE_BLACK }
, { "WEBGLCONTEXTLOST", TYPE_BLACK }
, { "WEBGLCONTEXTRESTORED", TYPE_BLACK }
, { "WEBKITANIMATIONEND", TYPE_BLACK }
, { "WEBKITANIMATIONITERATION", TYPE_BLACK }
, { "WEBKITANIMATIONSTART", TYPE_BLACK }
, { "WEBKITBEFORETEXTINSERTED", TYPE_BLACK }
, { "WEBKITBEGINFULLSCREEN", TYPE_BLACK }
, { "WEBKITCURRENTPLAYBACKTARGETISWIRELESSCHANGED", TYPE_BLACK }
, { "WEBKITENDFULLSCREEN", TYPE_BLACK }
, { "WEBKITFULLSCREENCHANGE", TYPE_BLACK }
, { "WEBKITFULLSCREENERROR", TYPE_BLACK }
, { "WEBKITKEYADDED", TYPE_BLACK }
, { "WEBKITKEYERROR", TYPE_BLACK }
, { "WEBKITKEYMESSAGE", TYPE_BLACK }
, { "WEBKITMOUSEFORCECHANGED", TYPE_BLACK }
, { "WEBKITMOUSEFORCEDOWN", TYPE_BLACK }
, { "WEBKITMOUSEFORCEUP", TYPE_BLACK }
, { "WEBKITMOUSEFORCEWILLBEGIN", TYPE_BLACK }
, { "WEBKITNEEDKEY", TYPE_BLACK }
, { "WEBKITNETWORKINFOCHANGE", TYPE_BLACK }
, { "WEBKITPLAYBACKTARGETAVAILABILITYCHANGED", TYPE_BLACK }
, { "WEBKITPRESENTATIONMODECHANGED", TYPE_BLACK }
, { "WEBKITREGIONOVERSETCHANGE", TYPE_BLACK }
, { "WEBKITREMOVESOURCEBUFFER", TYPE_BLACK }
, { "WEBKITSOURCECLOSE", TYPE_BLACK }
, { "WEBKITSOURCEENDED", TYPE_BLACK }
, { "WEBKITSOURCEOPEN", TYPE_BLACK }
, { "WEBKITSPEECHCHANGE", TYPE_BLACK }
, { "WEBKITTRANSITIONEND", TYPE_BLACK }
, { "WEBKITWILLREVEALBOTTOM", TYPE_BLACK }
, { "WEBKITWILLREVEALLEFT", TYPE_BLACK }
, { "WEBKITWILLREVEALRIGHT", TYPE_BLACK }
, { "WEBKITWILLREVEALTOP", TYPE_BLACK }
, { "WHEEL", TYPE_BLACK }
, { "WRITEEND", TYPE_BLACK }
, { "WRITESTART", TYPE_BLACK }
, { "WRITE", TYPE_BLACK }
, { "ZOOM", TYPE_BLACK }
, { NULL, TYPE_NONE }
};
/*
* view-source:
* data:
* javascript:
*/
static stringtype_t BLACKATTR[] = {
{ "ACTION", TYPE_ATTR_URL } /* form */
, { "ATTRIBUTENAME", TYPE_ATTR_INDIRECT } /* SVG allow indirection of attribute names */
, { "BY", TYPE_ATTR_URL } /* SVG */
, { "BACKGROUND", TYPE_ATTR_URL } /* IE6, O11 */
, { "DATAFORMATAS", TYPE_BLACK } /* IE */
, { "DATASRC", TYPE_BLACK } /* IE */
, { "DYNSRC", TYPE_ATTR_URL } /* Obsolete img attribute */
, { "FILTER", TYPE_STYLE } /* Opera, SVG inline style */
, { "FORMACTION", TYPE_ATTR_URL } /* HTML 5 */
, { "FOLDER", TYPE_ATTR_URL } /* Only on A tags, IE-only */
, { "FROM", TYPE_ATTR_URL } /* SVG */
, { "HANDLER", TYPE_ATTR_URL } /* SVG Tiny, Opera */
, { "HREF", TYPE_ATTR_URL }
, { "LOWSRC", TYPE_ATTR_URL } /* Obsolete img attribute */
, { "POSTER", TYPE_ATTR_URL } /* Opera 10,11 */
, { "SRC", TYPE_ATTR_URL }
, { "STYLE", TYPE_STYLE }
, { "TO", TYPE_ATTR_URL } /* SVG */
, { "VALUES", TYPE_ATTR_URL } /* SVG */
, { "XLINK:HREF", TYPE_ATTR_URL }
, { NULL, TYPE_NONE }
};
/* xmlns */
/* `xml-stylesheet` > <eval>, <if expr=> */
/*
static const char* BLACKATTR[] = {
"ATTRIBUTENAME",
"BACKGROUND",
"DATAFORMATAS",
"HREF",
"SCROLL",
"SRC",
"STYLE",
"SRCDOC",
NULL
};
*/
static const char* BLACKTAG[] = {
"APPLET"
/* , "AUDIO" */
, "BASE"
, "COMMENT" /* IE http://html5sec.org/#38 */
, "EMBED"
/* , "FORM" */
, "FRAME"
, "FRAMESET"
, "HANDLER" /* Opera SVG, effectively a script tag */
, "IFRAME"
, "IMPORT"
, "ISINDEX"
, "LINK"
, "LISTENER"
/* , "MARQUEE" */
, "META"
, "NOSCRIPT"
, "OBJECT"
, "SCRIPT"
, "STYLE"
/* , "VIDEO" */
, "VMLFRAME"
, "XML"
, "XSS"
, NULL
};
static int cstrcasecmp_with_null(const char *a, const char *b, size_t n)
{
char ca;
char cb;
/* printf("Comparing to %s %.*s\n", a, (int)n, b); */
while (n-- > 0) {
cb = *b++;
if (cb == '\0') continue;
ca = *a++;
if (cb >= 'a' && cb <= 'z') {
cb -= 0x20;
}
/* printf("Comparing %c vs %c with %d left\n", ca, cb, (int)n); */
if (ca != cb) {
return 1;
}
}
if (*a == 0) {
/* printf(" MATCH \n"); */
return 0;
} else {
return 1;
}
}
/*
* Does an HTML encoded binary string (const char*, length) start with
* a all uppercase c-string (null terminated), case insensitive!
*
* also ignore any embedded nulls in the HTML string!
*
* return 1 if match / starts with
* return 0 if not
*/
static int htmlencode_startswith(const char *a, const char *b, size_t n)
{
size_t consumed;
int cb;
int first = 1;
/* printf("Comparing %s with %.*s\n", a,(int)n,b); */
while (n > 0) {
if (*a == 0) {
/* printf("Match EOL!\n"); */
return 1;
}
cb = html_decode_char_at(b, n, &consumed);
b += consumed;
n -= consumed;
if (first && cb <= 32) {
/* ignore all leading whitespace and control characters */
continue;
}
first = 0;
if (cb == 0) {
/* always ignore null characters in user input */
continue;
}
if (cb == 10) {
/* always ignore vertical tab characters in user input */
/* who allows this?? */
continue;
}
if (cb >= 'a' && cb <= 'z') {
/* upcase */
cb -= 0x20;
}
if (*a != (char) cb) {
/* printf(" %c != %c\n", *a, cb); */
/* mismatch */
return 0;
}
a++;
}
return (*a == 0) ? 1 : 0;
}
static int is_black_tag(const char* s, size_t len)
{
const char** black;
if (len < 3) {
return 0;
}
black = BLACKTAG;
while (*black != NULL) {
if (cstrcasecmp_with_null(*black, s, len) == 0) {
/* printf("Got black tag %s\n", *black); */
return 1;
}
black += 1;
}
/* anything SVG related */
if ((s[0] == 's' || s[0] == 'S') &&
(s[1] == 'v' || s[1] == 'V') &&
(s[2] == 'g' || s[2] == 'G')) {
/* printf("Got SVG tag \n"); */
return 1;
}
/* Anything XSL(t) related */
if ((s[0] == 'x' || s[0] == 'X') &&
(s[1] == 's' || s[1] == 'S') &&
(s[2] == 'l' || s[2] == 'L')) {
/* printf("Got XSL tag\n"); */
return 1;
}
return 0;
}
static attribute_t is_black_attr(const char* s, size_t len)
{
stringtype_t* black;
if (len < 2) {
return TYPE_NONE;
}
if (len >= 5) {
/* JavaScript on.* event handlers */
if ((s[0] == 'o' || s[0] == 'O') && (s[1] == 'n' || s[1] == 'N')) {
black = BLACKATTREVENT;
const char *s_without_on = &s[2]; // start comparing from the third char
while (black->name != NULL) {
if (cstrcasecmp_with_null(black->name, s_without_on, strlen(black->name)) == 0) {
/* printf("Got banned attribute name %s\n", black->name); */
return black->atype;
}
black += 1;
}
}
/* XMLNS can be used to create arbitrary tags */
// goedge: commented for photo uploading
//if (cstrcasecmp_with_null("XMLNS", s, 5) == 0 || cstrcasecmp_with_null("XLINK", s, 5) == 0) {
/* printf("Got XMLNS and XLINK tags\n"); */
// return TYPE_BLACK;
//}
}
black = BLACKATTR;
while (black->name != NULL) {
if (cstrcasecmp_with_null(black->name, s, len) == 0) {
/* printf("Got banned attribute name %s\n", black->name); */
return black->atype;
}
black += 1;
}
return TYPE_NONE;
}
static int is_black_url(const char* s, size_t len)
{
static const char* data_url = "DATA";
static const char* viewsource_url = "VIEW-SOURCE";
/* obsolete but interesting signal */
static const char* vbscript_url = "VBSCRIPT";
/* covers JAVA, JAVASCRIPT, + colon */
static const char* javascript_url = "JAVA";
/* skip whitespace */
while (len > 0 && (*s <= 32 || *s >= 127)) {
/*
* HEY: this is a signed character.
* We are intentionally skipping high-bit characters too
* since they are not ASCII, and Opera sometimes uses UTF-8 whitespace.
*
* Also in EUC-JP some of the high bytes are just ignored.
*/
++s;
--len;
}
if (htmlencode_startswith(data_url, s, len)) {
return 1;
}
if (htmlencode_startswith(viewsource_url, s, len)) {
return 1;
}
if (htmlencode_startswith(javascript_url, s, len)) {
return 1;
}
if (htmlencode_startswith(vbscript_url, s, len)) {
return 1;
}
return 0;
}
int libinjection_is_xss(const char* s, size_t len, int flags)
{
h5_state_t h5;
attribute_t attr = TYPE_NONE;
libinjection_h5_init(&h5, s, len, (enum html5_flags) flags);
while (libinjection_h5_next(&h5)) {
if (h5.token_type != ATTR_VALUE) {
attr = TYPE_NONE;
}
if (h5.token_type == DOCTYPE) {
return 1;
} else if (h5.token_type == TAG_NAME_OPEN) {
if (is_black_tag(h5.token_start, h5.token_len)) {
return 1;
}
} else if (h5.token_type == ATTR_NAME) {
attr = is_black_attr(h5.token_start, h5.token_len);
} else if (h5.token_type == ATTR_VALUE) {
/*
* IE6,7,8 parsing works a bit differently so
* a whole <script> or other black tag might be hiding
* inside an attribute value under HTML 5 parsing
* See http://html5sec.org/#102
* to avoid doing a full reparse of the value, just
* look for "<". This probably need adjusting to
* handle escaped characters
*/
/*
if (memchr(h5.token_start, '<', h5.token_len) != NULL) {
return 1;
}
*/
switch (attr) {
case TYPE_NONE:
break;
case TYPE_BLACK:
return 1;
case TYPE_ATTR_URL:
if (is_black_url(h5.token_start, h5.token_len)) {
return 1;
}
break;
case TYPE_STYLE:
return 1;
case TYPE_ATTR_INDIRECT:
/* an attribute name is specified in a _value_ */
if (is_black_attr(h5.token_start, h5.token_len)) {
return 1;
}
break;
/*
default:
assert(0);
*/
}
attr = TYPE_NONE;
} else if (h5.token_type == TAG_COMMENT) {
/* IE uses a "`" as a tag ending char */
// goedge: commented for photo uploading
/**if (memchr(h5.token_start, '`', h5.token_len) != NULL) {
return 1;
}**/
/* IE conditional comment */
if (h5.token_len > 3) {
if (h5.token_start[0] == '[' &&
(h5.token_start[1] == 'i' || h5.token_start[1] == 'I') &&
(h5.token_start[2] == 'f' || h5.token_start[2] == 'F')) {
return 1;
}
if ((h5.token_start[0] == 'x' || h5.token_start[0] == 'X') &&
(h5.token_start[1] == 'm' || h5.token_start[1] == 'M') &&
(h5.token_start[2] == 'l' || h5.token_start[2] == 'L')) {
return 1;
}
}
if (h5.token_len > 5) {
/* IE <?import pseudo-tag */
if (cstrcasecmp_with_null("IMPORT", h5.token_start, 6) == 0) {
return 1;
}
/* XML Entity definition */
if (cstrcasecmp_with_null("ENTITY", h5.token_start, 6) == 0) {
return 1;
}
}
}
}
return 0;
}
/*
* wrapper
*
*
* const char* s: input string, may contain nulls, does not need to be null-terminated.
* size_t len: input string length.
*
*
*/
int libinjection_xss(const char* s, size_t slen)
{
if (libinjection_is_xss(s, slen, DATA_STATE)) {
return 1;
}
if (libinjection_is_xss(s, slen, VALUE_NO_QUOTE)) {
return 1;
}
if (libinjection_is_xss(s, slen, VALUE_SINGLE_QUOTE)) {
return 1;
}
if (libinjection_is_xss(s, slen, VALUE_DOUBLE_QUOTE)) {
return 1;
}
if (libinjection_is_xss(s, slen, VALUE_BACK_QUOTE)) {
return 1;
}
return 0;
}

View File

@@ -0,0 +1,21 @@
#ifndef LIBINJECTION_XSS
#define LIBINJECTION_XSS
#ifdef __cplusplus
extern "C" {
#endif
/**
* HEY THIS ISN'T DONE
*/
/* pull in size_t */
#include <string.h>
int libinjection_is_xss(const char* s, size_t len, int flags);
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,313 @@
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <assert.h>
#include "libinjection.h"
#include "libinjection_sqli.h"
#include "libinjection_xss.h"
#ifndef TRUE
#define TRUE 1
#endif
#ifndef FALSE
#define FALSE 0
#endif
static int g_test_ok = 0;
static int g_test_fail = 0;
typedef enum {
MODE_SQLI,
MODE_XSS
} detect_mode_t;
static void usage(const char* program_name);
size_t modp_rtrim(char* str, size_t len);
void modp_toprint(char* str, size_t len);
void test_positive(FILE * fd, const char *fname, detect_mode_t mode,
int flag_invert, int flag_true, int flag_quiet);
int urlcharmap(char ch);
size_t modp_url_decode(char* dest, const char* s, size_t len);
int urlcharmap(char ch) {
switch (ch) {
case '0': return 0;
case '1': return 1;
case '2': return 2;
case '3': return 3;
case '4': return 4;
case '5': return 5;
case '6': return 6;
case '7': return 7;
case '8': return 8;
case '9': return 9;
case 'a': case 'A': return 10;
case 'b': case 'B': return 11;
case 'c': case 'C': return 12;
case 'd': case 'D': return 13;
case 'e': case 'E': return 14;
case 'f': case 'F': return 15;
default:
return 256;
}
}
size_t modp_url_decode(char* dest, const char* s, size_t len)
{
const char* deststart = dest;
size_t i = 0;
int d = 0;
while (i < len) {
switch (s[i]) {
case '+':
*dest++ = ' ';
i += 1;
break;
case '%':
if (i+2 < len) {
d = (urlcharmap(s[i+1]) << 4) | urlcharmap(s[i+2]);
if ( d < 256) {
*dest = (char) d;
dest++;
i += 3; /* loop will increment one time */
} else {
*dest++ = '%';
i += 1;
}
} else {
*dest++ = '%';
i += 1;
}
break;
default:
*dest++ = s[i];
i += 1;
}
}
*dest = '\0';
return (size_t)(dest - deststart); /* compute "strlen" of dest */
}
void modp_toprint(char* str, size_t len)
{
size_t i;
for (i = 0; i < len; ++i) {
if (str[i] < 32 || str[i] > 126) {
str[i] = '?';
}
}
}
size_t modp_rtrim(char* str, size_t len)
{
while (len) {
char c = str[len -1];
if (c == ' ' || c == '\n' || c == '\t' || c == '\r') {
str[len -1] = '\0';
len -= 1;
} else {
break;
}
}
return len;
}
void test_positive(FILE * fd, const char *fname, detect_mode_t mode,
int flag_invert, int flag_true, int flag_quiet)
{
char linebuf[8192];
int issqli = 0;
int linenum = 0;
size_t len;
sfilter sf;
while (fgets(linebuf, sizeof(linebuf), fd)) {
linenum += 1;
len = modp_rtrim(linebuf, strlen(linebuf));
if (len == 0) {
continue;
}
if (linebuf[0] == '#') {
continue;
}
len = modp_url_decode(linebuf, linebuf, len);
switch (mode) {
case MODE_SQLI: {
libinjection_sqli_init(&sf, linebuf, len, 0);
issqli = libinjection_is_sqli(&sf);
break;
}
case MODE_XSS: {
issqli = libinjection_xss(linebuf, len);
break;
}
default:
assert(0);
}
if (issqli) {
g_test_ok += 1;
} else {
g_test_fail += 1;
}
if (!flag_quiet) {
if ((issqli && flag_true && ! flag_invert) ||
(!issqli && flag_true && flag_invert) ||
!flag_true) {
modp_toprint(linebuf, len);
switch (mode) {
case MODE_SQLI: {
/*
* if we didn't find a SQLi and fingerprint from
* sqlstats is is 'sns' or 'snsns' then redo using
* plain context
*/
if (!issqli && (strcmp(sf.fingerprint, "sns") == 0 ||
strcmp(sf.fingerprint, "snsns") == 0)) {
libinjection_sqli_fingerprint(&sf, 0);
}
fprintf(stdout, "%s\t%d\t%s\t%s\t%s\n",
fname, linenum,
(issqli ? "True" : "False"), sf.fingerprint, linebuf);
break;
}
case MODE_XSS: {
fprintf(stdout, "%s\t%d\t%s\t%s\n",
fname, linenum,
(issqli ? "True" : "False"), linebuf);
break;
}
default:
assert(0);
}
}
}
}
}
static void usage(const char* program_name)
{
fprintf(stdout, "usage: %s [flags] [files...]\n", program_name);
fprintf(stdout, "%s\n", "");
fprintf(stdout, "%s\n", "-q --quiet : quiet mode");
fprintf(stdout, "%s\n", "-m --max-fails : number of failed cases need to fail entire test");
fprintf(stdout, "%s\n", "-s INTEGER : repeat each test N time "
"(for performance testing)");
fprintf(stdout, "%s\n", "-t : only print positive matches");
fprintf(stdout, "%s\n", "-x --mode-xss : test input for XSS");
fprintf(stdout, "%s\n", "-i --invert : invert test logic "
"(input is tested for being safe)");
fprintf(stdout, "%s\n", "");
fprintf(stdout, "%s\n", "-? -h -help --help : this page");
fprintf(stdout, "%s\n", "");
}
int main(int argc, const char *argv[])
{
/*
* invert output, by
*/
int flag_invert = FALSE;
/*
* don't print anything.. useful for
* performance monitors, gprof.
*/
int flag_quiet = FALSE;
/*
* only print positive results
* with invert, only print negative results
*/
int flag_true = FALSE;
detect_mode_t mode = MODE_SQLI;
int flag_slow = 1;
int count = 0;
int max = -1;
int i, j;
int offset = 1;
while (offset < argc) {
if (strcmp(argv[offset], "-?") == 0 ||
strcmp(argv[offset], "-h") == 0 ||
strcmp(argv[offset], "-help") == 0 ||
strcmp(argv[offset], "--help") == 0) {
usage(argv[0]);
exit(0);
}
if (strcmp(argv[offset], "-i") == 0) {
offset += 1;
flag_invert = TRUE;
} else if (strcmp(argv[offset], "-q") == 0 ||
strcmp(argv[offset], "--quiet") == 0) {
offset += 1;
flag_quiet = TRUE;
} else if (strcmp(argv[offset], "-t") == 0) {
offset += 1;
flag_true = TRUE;
} else if (strcmp(argv[offset], "-s") == 0) {
offset += 1;
flag_slow = 100;
} else if (strcmp(argv[offset], "-m") == 0 ||
strcmp(argv[offset], "--max-fails") == 0) {
offset += 1;
max = atoi(argv[offset]);
offset += 1;
} else if (strcmp(argv[offset], "-x") == 0 ||
strcmp(argv[offset], "--mode-xss") == 0) {
mode = MODE_XSS;
offset += 1;
} else {
break;
}
}
if (offset == argc) {
test_positive(stdin, "stdin", mode, flag_invert, flag_true, flag_quiet);
} else {
for (j = 0; j < flag_slow; ++j) {
for (i = offset; i < argc; ++i) {
FILE* fd = fopen(argv[i], "r");
if (fd) {
test_positive(fd, argv[i], mode, flag_invert, flag_true, flag_quiet);
fclose(fd);
}
}
}
}
if (!flag_quiet) {
fprintf(stdout, "%s", "\n");
fprintf(stdout, "SQLI : %d\n", g_test_ok);
fprintf(stdout, "SAFE : %d\n", g_test_fail);
fprintf(stdout, "TOTAL : %d\n", g_test_ok + g_test_fail);
}
if (max == -1) {
return 0;
}
count = g_test_ok;
if (flag_invert) {
count = g_test_fail;
}
if (count > max) {
printf("\nThreshold is %d, got %d, failing.\n", max, count);
return 1;
} else {
printf("\nThreshold is %d, got %d, passing.\n", max, count);
return 0;
}
}

View File

@@ -0,0 +1,165 @@
/**
* Copyright 2012, 2013 Nick Galbreath
* nickg@client9.com
* BSD License -- see COPYING.txt for details
*
* This is for testing against files in ../data/ *.txt
* Reads from stdin or a list of files, and emits if a line
* is a SQLi attack or not, and does basic statistics
*
*/
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include "libinjection.h"
#include "libinjection_sqli.h"
void print_string(stoken_t* t);
void print_var(stoken_t* t);
void print_token(stoken_t *t);
void usage(void);
void print_string(stoken_t* t)
{
/* print opening quote */
if (t->str_open != '\0') {
printf("%c", t->str_open);
}
/* print content */
printf("%s", t->val);
/* print closing quote */
if (t->str_close != '\0') {
printf("%c", t->str_close);
}
}
void print_var(stoken_t* t)
{
if (t->count >= 1) {
printf("%c", '@');
}
if (t->count == 2) {
printf("%c", '@');
}
print_string(t);
}
void print_token(stoken_t *t) {
printf("%c ", t->type);
switch (t->type) {
case 's':
print_string(t);
break;
case 'v':
print_var(t);
break;
default:
printf("%s", t->val);
}
printf("%s", "\n");
}
void usage(void) {
printf("\n");
printf("libinjection sqli tester\n");
printf("\n");
printf(" -ca parse as ANSI SQL\n");
printf(" -cm parse as MySQL SQL\n");
printf(" -q0 parse as is\n");
printf(" -q1 parse in single-quote mode\n");
printf(" -q2 parse in doiuble-quote mode\n");
printf("\n");
printf(" -f --fold fold results\n");
printf("\n");
printf(" -d --detect detect SQLI. empty reply = not detected\n");
printf("\n");
}
int main(int argc, const char* argv[])
{
size_t slen;
char* copy;
int flags = 0;
int fold = 0;
int detect = 0;
int i;
int count;
int offset = 1;
int issqli;
sfilter sf;
if (argc < 2) {
usage();
return 1;
}
while (1) {
if (strcmp(argv[offset], "-h") == 0 || strcmp(argv[offset], "-?") == 0 || strcmp(argv[offset], "--help") == 0) {
usage();
return 1;
}
if (strcmp(argv[offset], "-m") == 0) {
flags |= FLAG_SQL_MYSQL;
offset += 1;
}
else if (strcmp(argv[offset], "-f") == 0 || strcmp(argv[offset], "--fold") == 0) {
fold = 1;
offset += 1;
} else if (strcmp(argv[offset], "-d") == 0 || strcmp(argv[offset], "--detect") == 0) {
detect = 1;
offset += 1;
} else if (strcmp(argv[offset], "-ca") == 0) {
flags |= FLAG_SQL_ANSI;
offset += 1;
} else if (strcmp(argv[offset], "-cm") == 0) {
flags |= FLAG_SQL_MYSQL;
offset += 1;
} else if (strcmp(argv[offset], "-q0") == 0) {
flags |= FLAG_QUOTE_NONE;
offset += 1;
} else if (strcmp(argv[offset], "-q1") == 0) {
flags |= FLAG_QUOTE_SINGLE;
offset += 1;
} else if (strcmp(argv[offset], "-q2") == 0) {
flags |= FLAG_QUOTE_DOUBLE;
offset += 1;
} else {
break;
}
}
/* ATTENTION: argv is a C-string, null terminated. We copy this
* to it's own location, WITHOUT null byte. This way, valgrind
* can see if we run past the buffer.
*/
slen = strlen(argv[offset]);
copy = (char* ) malloc(slen);
memcpy(copy, argv[offset], slen);
libinjection_sqli_init(&sf, copy, slen, flags);
if (detect == 1) {
issqli = libinjection_is_sqli(&sf);
if (issqli) {
printf("%s\n", sf.fingerprint);
}
} else if (fold == 1) {
count = libinjection_sqli_fold(&sf);
for (i = 0; i < count; ++i) {
print_token(&(sf.tokenvec[i]));
}
} else {
while (libinjection_sqli_tokenize(&sf)) {
print_token(sf.current);
}
}
free(copy);
return 0;
}

View File

@@ -0,0 +1,132 @@
#!/usr/bin/env python3
#
# Copyright 2012, 2013 Nick Galbreath
# nickg@client9.com
# BSD License -- see COPYING.txt for details
#
"""
Converts a libinjection JSON data file to a C header (.h) file
"""
import sys
def toc(obj):
""" main routine """
print("""
#ifndef LIBINJECTION_SQLI_DATA_H
#define LIBINJECTION_SQLI_DATA_H
#include "libinjection.h"
#include "libinjection_sqli.h"
typedef struct {
const char *word;
char type;
} keyword_t;
static size_t parse_money(sfilter * sf);
static size_t parse_other(sfilter * sf);
static size_t parse_white(sfilter * sf);
static size_t parse_operator1(sfilter *sf);
static size_t parse_char(sfilter *sf);
static size_t parse_hash(sfilter *sf);
static size_t parse_dash(sfilter *sf);
static size_t parse_slash(sfilter *sf);
static size_t parse_backslash(sfilter * sf);
static size_t parse_operator2(sfilter *sf);
static size_t parse_string(sfilter *sf);
static size_t parse_word(sfilter * sf);
static size_t parse_var(sfilter * sf);
static size_t parse_number(sfilter * sf);
static size_t parse_tick(sfilter * sf);
static size_t parse_ustring(sfilter * sf);
static size_t parse_qstring(sfilter * sf);
static size_t parse_nqstring(sfilter * sf);
static size_t parse_xstring(sfilter * sf);
static size_t parse_bstring(sfilter * sf);
static size_t parse_estring(sfilter * sf);
static size_t parse_bword(sfilter * sf);
""")
#
# Mapping of character to function
#
fnmap = {
'CHAR_WORD' : 'parse_word',
'CHAR_WHITE': 'parse_white',
'CHAR_OP1' : 'parse_operator1',
'CHAR_UNARY': 'parse_operator1',
'CHAR_OP2' : 'parse_operator2',
'CHAR_BANG' : 'parse_operator2',
'CHAR_BACK' : 'parse_backslash',
'CHAR_DASH' : 'parse_dash',
'CHAR_STR' : 'parse_string',
'CHAR_HASH' : 'parse_hash',
'CHAR_NUM' : 'parse_number',
'CHAR_SLASH': 'parse_slash',
'CHAR_SEMICOLON' : 'parse_char',
'CHAR_COMMA': 'parse_char',
'CHAR_LEFTPARENS': 'parse_char',
'CHAR_RIGHTPARENS': 'parse_char',
'CHAR_LEFTBRACE': 'parse_char',
'CHAR_RIGHTBRACE': 'parse_char',
'CHAR_VAR' : 'parse_var',
'CHAR_OTHER': 'parse_other',
'CHAR_MONEY': 'parse_money',
'CHAR_TICK' : 'parse_tick',
'CHAR_UNDERSCORE': 'parse_underscore',
'CHAR_USTRING' : 'parse_ustring',
'CHAR_QSTRING' : 'parse_qstring',
'CHAR_NQSTRING' : 'parse_nqstring',
'CHAR_XSTRING' : 'parse_xstring',
'CHAR_BSTRING' : 'parse_bstring',
'CHAR_ESTRING' : 'parse_estring',
'CHAR_BWORD' : 'parse_bword'
}
print()
print("typedef size_t (*pt2Function)(sfilter *sf);")
print("static const pt2Function char_parse_map[] = {")
pos = 0
for character in obj['charmap']:
print(" &%s, /* %d */" % (fnmap[character], pos))
pos += 1
print("};")
print()
# keywords
# load them
keywords = obj['keywords']
for fingerprint in list(obj['fingerprints']):
fingerprint = '0' + fingerprint.upper()
keywords[fingerprint] = 'F'
needhelp = []
for key in keywords.keys():
if key != key.upper():
needhelp.append(key)
for key in needhelp:
tmpv = keywords[key]
del keywords[key]
keywords[key.upper()] = tmpv
print("static const keyword_t sql_keywords[] = {")
for k in sorted(keywords.keys()):
if len(k) > 31:
sys.stderr.write("ERROR: keyword greater than 32 chars\n")
sys.exit(1)
print(" {\"%s\", '%s'}," % (k, keywords[k]))
print("};")
print("static const size_t sql_keywords_sz = %d;" % (len(keywords), ))
print("#endif")
return 0
if __name__ == '__main__':
import json
sys.exit(toc(json.load(sys.stdin)))

View File

@@ -0,0 +1,3 @@
#define LIBINJECTION_VERSION "3.9.1"
#include "libinjection/src/libinjection_sqli.c"

View File

@@ -0,0 +1,6 @@
#define LIBINJECTION_VERSION "3.9.1"
#include "libinjection/src/libinjection_xss.c"
#include "libinjection/src/libinjection_html5.c"
#define GOEDGE_VERSION "23" // last version is for GoEdge change

View File

@@ -0,0 +1,93 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package injectionutils
/*
#cgo CFLAGS: -I./libinjection/src
#include <libinjection.h>
#include <stdlib.h>
*/
import "C"
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/cespare/xxhash/v2"
"net/url"
"strconv"
"strings"
"unsafe"
)
// DetectSQLInjectionCache detect sql injection in string with cache
func DetectSQLInjectionCache(input string, cacheLife utils.CacheLife) bool {
var l = len(input)
if l == 0 {
return false
}
if cacheLife <= 0 || l < 128 || l > utils.MaxCacheDataSize {
return DetectSQLInjection(input)
}
var hash = xxhash.Sum64String(input)
var key = "WAF@SQLI@" + strconv.FormatUint(hash, 10)
var item = utils.SharedCache.Read(key)
if item != nil {
return item.Value == 1
}
var result = DetectSQLInjection(input)
if result {
utils.SharedCache.Write(key, 1, fasttime.Now().Unix()+cacheLife)
} else {
utils.SharedCache.Write(key, 0, fasttime.Now().Unix()+cacheLife)
}
return result
}
// DetectSQLInjection detect sql injection in string
func DetectSQLInjection(input string) bool {
if len(input) == 0 {
return false
}
if detectSQLInjectionOne(input) {
return true
}
// 兼容 /PATH?URI
if (input[0] == '/' || strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://")) && len(input) < 1024 {
var argsIndex = strings.Index(input, "?")
if argsIndex > 0 {
var args = input[argsIndex+1:]
unescapeArgs, err := url.QueryUnescape(args)
if err == nil && args != unescapeArgs {
return detectSQLInjectionOne(args) || detectSQLInjectionOne(unescapeArgs)
} else {
return detectSQLInjectionOne(args)
}
}
} else {
unescapedInput, err := url.QueryUnescape(input)
if err == nil && input != unescapedInput {
return detectSQLInjectionOne(unescapedInput)
}
}
return false
}
func detectSQLInjectionOne(input string) bool {
if len(input) == 0 {
return false
}
var fingerprint [8]C.char
var fingerprintPtr = (*C.char)(unsafe.Pointer(&fingerprint[0]))
var cInput = C.CString(input)
defer C.free(unsafe.Pointer(cInput))
return C.libinjection_sqli(cInput, C.size_t(len(input)), fingerprintPtr) == 1
}

View File

@@ -0,0 +1,128 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package injectionutils_test
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/injectionutils"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
"runtime"
"strings"
"testing"
)
func TestDetectSQLInjection(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(injectionutils.DetectSQLInjection("' UNION SELECT * FROM myTable"))
a.IsTrue(injectionutils.DetectSQLInjection("id=1 ' UNION select * from a"))
a.IsTrue(injectionutils.DetectSQLInjection("asdf asd ; -1' and 1=1 union/* foo */select load_file('/etc/passwd')--"))
a.IsFalse(injectionutils.DetectSQLInjection("' UNION SELECT1 * FROM myTable"))
a.IsFalse(injectionutils.DetectSQLInjection("1234"))
a.IsFalse(injectionutils.DetectSQLInjection(""))
a.IsTrue(injectionutils.DetectSQLInjection("id=123 OR 1=1&b=2"))
a.IsTrue(injectionutils.DetectSQLInjection("id=123&b=456&c=1' or 2=2"))
a.IsFalse(injectionutils.DetectSQLInjection("?"))
a.IsFalse(injectionutils.DetectSQLInjection("/hello?age=22"))
a.IsTrue(injectionutils.DetectSQLInjection("/sql/injection?id=123 or 1=1"))
a.IsTrue(injectionutils.DetectSQLInjection("/sql/injection?id=123%20or%201=1"))
a.IsTrue(injectionutils.DetectSQLInjection("https://example.com/sql/injection?id=123%20or%201=1"))
a.IsTrue(injectionutils.DetectSQLInjection("id=123%20or%201=1"))
a.IsTrue(injectionutils.DetectSQLInjection("https://example.com/' or 1=1"))
}
func BenchmarkDetectSQLInjection(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjection("asdf asd ; -1' and 1=1 union/* foo */select load_file('/etc/passwd')--")
}
})
}
func BenchmarkDetectSQLInjection_URL(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjection("/sql/injection?id=123 or 1=1")
}
})
}
func BenchmarkDetectSQLInjection_Normal_Small(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjection("a/sql/injection?id=1234")
}
})
}
func BenchmarkDetectSQLInjection_URL_Normal_Small(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjection("/sql/injection?id=" + types.String(rands.Int64()%10000))
}
})
}
func BenchmarkDetectSQLInjection_URL_Normal_Middle(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjection("/search?q=libinjection+fingerprint&newwindow=1&sca_esv=589290862&sxsrf=AMwHvKnxuLoejn2XlNniffC12E_xc35M7Q%3A1702090118361&ei=htvzzebfFZfo1e8PvLGggAk&ved=0ahUKEwjTsYmnq4GDAxUWdPOHHbwkCJAQ4ddDCBA&uact=5&oq=libinjection+fingerprint&gs_lp=Egxnd3Mtd2l6LXNlcnAiGIxpYmluamVjdGlvbmBmaW5nKXJwcmludTIEEAAYHjIGVAAYCBgeSiEaUPkRWKFZcAJ4AZABAJgBHgGgAfoEqgwDMC40uAEGyAEA-AEBwgIKEAFYTxjWMuiwA-IDBBgAVteIBgGQBgI&sclient=gws-wiz-serp#ip=1")
}
})
}
func BenchmarkDetectSQLInjection_URL_Normal_Small_Cache(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjectionCache("/sql/injection?id="+types.String(rands.Int64()%10000), utils.CacheMiddleLife)
}
})
}
func BenchmarkDetectSQLInjection_Normal_Large(b *testing.B) {
runtime.GOMAXPROCS(4)
var s = strings.Repeat("A", 512)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjection("a/sql/injection?id=" + types.String(rands.Int64()%10000) + "&s=" + s + "&v=%20")
}
})
}
func BenchmarkDetectSQLInjection_Normal_Large_Cache(b *testing.B) {
runtime.GOMAXPROCS(4)
var s = strings.Repeat("A", 512)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjectionCache("a/sql/injection?id="+types.String(rands.Int64()%10000)+"&s="+s, utils.CacheMiddleLife)
}
})
}
func BenchmarkDetectSQLInjection_URL_Unescape(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjection("/sql/injection?id=123%20or%201=1")
}
})
}

View File

@@ -0,0 +1,90 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package injectionutils
/*
#cgo CFLAGS: -I./libinjection/src
#include <libinjection.h>
#include <stdlib.h>
*/
import "C"
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/cespare/xxhash/v2"
"net/url"
"strconv"
"strings"
"unsafe"
)
func DetectXSSCache(input string, cacheLife utils.CacheLife) bool {
var l = len(input)
if l == 0 {
return false
}
if cacheLife <= 0 || l < 512 || l > utils.MaxCacheDataSize {
return DetectXSS(input)
}
var hash = xxhash.Sum64String(input)
var key = "WAF@XSS@" + strconv.FormatUint(hash, 10)
var item = utils.SharedCache.Read(key)
if item != nil {
return item.Value == 1
}
var result = DetectXSS(input)
if result {
utils.SharedCache.Write(key, 1, fasttime.Now().Unix()+cacheLife)
} else {
utils.SharedCache.Write(key, 0, fasttime.Now().Unix()+cacheLife)
}
return result
}
// DetectXSS detect XSS in string
func DetectXSS(input string) bool {
if len(input) == 0 {
return false
}
if detectXSSOne(input) {
return true
}
// 兼容 /PATH?URI
if (input[0] == '/' || strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://")) && len(input) < 1024 {
var argsIndex = strings.Index(input, "?")
if argsIndex > 0 {
var args = input[argsIndex+1:]
unescapeArgs, err := url.QueryUnescape(args)
if err == nil && args != unescapeArgs {
return detectXSSOne(args) || detectXSSOne(unescapeArgs)
} else {
return detectXSSOne(args)
}
}
} else {
unescapedInput, err := url.QueryUnescape(input)
if err == nil && input != unescapedInput {
return detectXSSOne(unescapedInput)
}
}
return false
}
func detectXSSOne(input string) bool {
if len(input) == 0 {
return false
}
var cInput = C.CString(input)
defer C.free(unsafe.Pointer(cInput))
return C.libinjection_xss(cInput, C.size_t(len(input))) == 1
}

View File

@@ -0,0 +1,80 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package injectionutils_test
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/injectionutils"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/iwind/TeaGo/assert"
"runtime"
"testing"
)
func TestDetectXSS(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsFalse(injectionutils.DetectXSS(""))
a.IsFalse(injectionutils.DetectXSS("abc"))
a.IsTrue(injectionutils.DetectXSS("<script>"))
a.IsTrue(injectionutils.DetectXSS("<link>"))
a.IsFalse(injectionutils.DetectXSS("<html><span>"))
a.IsFalse(injectionutils.DetectXSS("&lt;script&gt;"))
a.IsTrue(injectionutils.DetectXSS("/path?onmousedown=a"))
a.IsTrue(injectionutils.DetectXSS("/path?onkeyup=a"))
a.IsTrue(injectionutils.DetectXSS("onkeyup=a"))
a.IsTrue(injectionutils.DetectXSS("<iframe scrolling='no'>"))
a.IsFalse(injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>"))
a.IsTrue(injectionutils.DetectXSS("name=s&description=%3Cscript+src%3D%22a.js%22%3Edddd%3C%2Fscript%3E"))
a.IsFalse(injectionutils.DetectXSS(`<x:xmpmeta xmlns:x="adobe:ns:meta/" x:xmptk="XMP Core 6.0.0">
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
<rdf:Description rdf:about=""
xmlns:tiff="http://ns.adobe.com/tiff/1.0/">
<tiff:Orientation>1</tiff:Orientation>
</rdf:Description>
</rdf:RDF>
</x:xmpmeta>`)) // included in some photo files
}
func BenchmarkDetectXSS_MISS(b *testing.B) {
var result = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>")
if result {
b.Fatal("'result' should not be 'true'")
}
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>")
}
})
}
func BenchmarkDetectXSS_MISS_Cache(b *testing.B) {
var result = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>")
if result {
b.Fatal("'result' should not be 'true'")
}
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectXSSCache("<html><body><span>RequestId: 1234567890</span></body></html>", utils.CacheMiddleLife)
}
})
}
func BenchmarkDetectXSS_HIT(b *testing.B) {
var result = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span><script src=\"\"></script></body></html>")
if !result {
b.Fatal("'result' should not be 'false'")
}
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span><script src=\"\"></script></body></html>")
}
})
}

View File

@@ -3,12 +3,18 @@
package waf
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/conns"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"os"
"sync"
"sync/atomic"
)
@@ -25,11 +31,30 @@ const (
const IPTypeAll = "*"
func init() {
if !teaconst.IsMain {
return
}
var cacheFile = Tea.Root + "/data/waf_white_list.cache"
// save
events.On(events.EventTerminated, func() {
_ = SharedIPWhiteList.Save(cacheFile)
})
// load
go func() {
_ = SharedIPWhiteList.Load(cacheFile)
_ = os.Remove(cacheFile)
}()
}
// IPList IP列表管理
type IPList struct {
expireList *expires.List
ipMap map[string]uint64 // ip => id
idMap map[uint64]string // id => ip
ipMap map[string]uint64 // ip info => id
idMap map[uint64]string // id => ip info
listType IPListType
id uint64
@@ -47,7 +72,7 @@ func NewIPList(listType IPListType) *IPList {
listType: listType,
}
e := expires.NewList()
var e = expires.NewList()
list.expireList = e
e.OnGC(func(itemId uint64) {
@@ -206,6 +231,85 @@ func (this *IPList) RemoveIP(ip string, serverId int64, shouldExecute bool) {
}
}
// Save to local file
func (this *IPList) Save(path string) error {
var itemMaps = []maps.Map{} // [ {ip info, expiresAt }, ... ]
this.locker.Lock()
defer this.locker.Unlock()
// prevent too many items
if len(this.ipMap) > 100_000 {
return nil
}
for ipInfo, id := range this.ipMap {
var expiresAt = this.expireList.ExpiresAt(id)
if expiresAt <= 0 {
continue
}
itemMaps = append(itemMaps, maps.Map{
"ip": ipInfo,
"expiresAt": expiresAt,
})
}
itemMapsJSON, err := json.Marshal(itemMaps)
if err != nil {
return err
}
return os.WriteFile(path, itemMapsJSON, 0666)
}
// Load from local file
func (this *IPList) Load(path string) error {
data, err := os.ReadFile(path)
if err != nil {
return err
}
if len(data) == 0 {
return nil
}
var itemMaps = []maps.Map{}
err = json.Unmarshal(data, &itemMaps)
if err != nil {
return err
}
this.locker.Lock()
defer this.locker.Unlock()
for _, itemMap := range itemMaps {
var ip = itemMap.GetString("ip")
var expiresAt = itemMap.GetInt64("expiresAt")
if len(ip) == 0 || expiresAt < fasttime.Now().Unix()+10 /** seconds **/ {
continue
}
var id = this.nextId()
this.expireList.Add(id, expiresAt)
this.ipMap[ip] = id
this.idMap[id] = ip
}
return nil
}
// IPMap get ipMap
func (this *IPList) IPMap() map[string]uint64 {
this.locker.RLock()
defer this.locker.RUnlock()
return this.ipMap
}
// IdMap get idMap
func (this *IPList) IdMap() map[uint64]string {
this.locker.RLock()
defer this.locker.RUnlock()
return this.idMap
}
func (this *IPList) remove(id uint64) {
this.locker.Lock()
ip, ok := this.idMap[id]

View File

@@ -1,12 +1,16 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package waf
package waf_test
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/assert"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/logs"
timeutil "github.com/iwind/TeaGo/utils/time"
"os"
"runtime"
"strconv"
"testing"
@@ -14,35 +18,33 @@ import (
)
func TestNewIPList(t *testing.T) {
var list = NewIPList(IPListTypeDeny)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix())
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeService, 1, "127.0.0.3", time.Now().Unix()+3)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+10)
var list = waf.NewIPList(waf.IPListTypeDeny)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix())
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeService, 1, "127.0.0.3", time.Now().Unix()+3)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+10)
list.RemoveIP("127.0.0.1", 1, false)
logs.PrintAsJSON(list.ipMap, t)
logs.PrintAsJSON(list.idMap, t)
logs.PrintAsJSON(list.IPMap(), t)
logs.PrintAsJSON(list.IdMap(), t)
}
func TestIPList_Expire(t *testing.T) {
var list = NewIPList(IPListTypeDeny)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix())
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.3", time.Now().Unix()+3)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+6)
var list = waf.NewIPList(waf.IPListTypeDeny)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix())
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.3", time.Now().Unix()+3)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+6)
var ticker = time.NewTicker(1 * time.Second)
for range ticker.C {
t.Log("====")
list.locker.Lock()
logs.PrintAsJSON(list.ipMap, t)
logs.PrintAsJSON(list.idMap, t)
list.locker.Unlock()
if len(list.idMap) == 0 {
logs.PrintAsJSON(list.IPMap(), t)
logs.PrintAsJSON(list.IdMap(), t)
if len(list.IdMap()) == 0 {
break
}
}
@@ -51,54 +53,78 @@ func TestIPList_Expire(t *testing.T) {
func TestIPList_Contains(t *testing.T) {
var a = assert.NewAssertion(t)
var list = NewIPList(IPListTypeDeny)
var list = waf.NewIPList(waf.IPListTypeDeny)
for i := 0; i < 1_0000; i++ {
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
}
//list.RemoveIP("192.168.1.100")
{
a.IsTrue(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100"))
a.IsTrue(list.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100"))
}
{
a.IsFalse(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.2.100"))
a.IsFalse(list.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.2.100"))
}
}
func TestIPList_ContainsExpires(t *testing.T) {
var list = NewIPList(IPListTypeDeny)
var list = waf.NewIPList(waf.IPListTypeDeny)
for i := 0; i < 1_0000; i++ {
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
}
// list.RemoveIP("192.168.1.100", 1, false)
for _, ip := range []string{"192.168.1.100", "192.168.2.100"} {
expiresAt, ok := list.ContainsExpires(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, ip)
expiresAt, ok := list.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, ip)
t.Log(ok, expiresAt, timeutil.FormatTime("Y-m-d H:i:s", expiresAt))
}
}
func TestIPList_Save(t *testing.T) {
var a = assert.NewAssertion(t)
var list = waf.NewIPList(waf.IPListTypeAllow)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100", time.Now().Unix()+3600)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 2, "192.168.1.101", time.Now().Unix()+3600)
var file = Tea.Root + "/data/waf.iplist.json"
err := list.Save(file)
if err != nil {
t.Fatal(err)
}
var newList = waf.NewIPList(waf.IPListTypeAllow)
err = newList.Load(file)
if err != nil {
t.Fatal(err)
}
a.IsTrue(newList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100"))
_ = os.Remove(file)
}
func BenchmarkIPList_Add(b *testing.B) {
runtime.GOMAXPROCS(1)
var list = NewIPList(IPListTypeDeny)
var list = waf.NewIPList(waf.IPListTypeDeny)
for i := 0; i < b.N; i++ {
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
}
b.Log(len(list.ipMap))
b.Log(len(list.IPMap()))
}
func BenchmarkIPList_Has(b *testing.B) {
runtime.GOMAXPROCS(1)
var list = NewIPList(IPListTypeDeny)
var list = waf.NewIPList(waf.IPListTypeDeny)
b.ResetTimer()
for i := 0; i < 1_0000; i++ {
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
}
for i := 0; i < b.N; i++ {
list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100")
list.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100")
}
}

View File

@@ -46,4 +46,7 @@ type Request interface {
// DisableAccessLog 在当前请求中不使用访问日志
DisableAccessLog()
// DisableStat 在当前请求中停用统计
DisableStat()
}

View File

@@ -85,6 +85,10 @@ func (this *TestRequest) DisableAccessLog() {
}
func (this *TestRequest) DisableStat() {
}
func (this *TestRequest) ProcessResponseHeaders(headers http.Header, status int) {
}

View File

@@ -9,17 +9,20 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/filterconfigs"
"github.com/TeaOSLab/EdgeNode/internal/re"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils/runes"
"github.com/TeaOSLab/EdgeNode/internal/waf/checkpoints"
"github.com/TeaOSLab/EdgeNode/internal/waf/injectionutils"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf/values"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"github.com/iwind/TeaGo/utils/string"
stringutil "github.com/iwind/TeaGo/utils/string"
"net"
"reflect"
"regexp"
"sort"
"strings"
)
@@ -29,14 +32,14 @@ var singleParamRegexp = regexp.MustCompile(`^\${[\w.-]+}$`)
type Rule struct {
Id int64
Description string `yaml:"description" json:"description"`
Param string `yaml:"param" json:"param"` // such as ${arg.name} or ${args}, can be composite as ${arg.firstName}${arg.lastName}
ParamFilters []*ParamFilter `yaml:"paramFilters" json:"paramFilters"`
Operator RuleOperator `yaml:"operator" json:"operator"` // such as contains, gt, ...
Value string `yaml:"value" json:"value"` // compared value
IsCaseInsensitive bool `yaml:"isCaseInsensitive" json:"isCaseInsensitive"`
CheckpointOptions map[string]interface{} `yaml:"checkpointOptions" json:"checkpointOptions"`
Priority int `yaml:"priority" json:"priority"`
Description string `yaml:"description" json:"description"`
Param string `yaml:"param" json:"param"` // such as ${arg.name} or ${args}, can be composite as ${arg.firstName}${arg.lastName}
ParamFilters []*ParamFilter `yaml:"paramFilters" json:"paramFilters"`
Operator RuleOperator `yaml:"operator" json:"operator"` // such as contains, gt, ...
Value string `yaml:"value" json:"value"` // compared value
IsCaseInsensitive bool `yaml:"isCaseInsensitive" json:"isCaseInsensitive"`
CheckpointOptions map[string]any `yaml:"checkpointOptions" json:"checkpointOptions"`
Priority int `yaml:"priority" json:"priority"`
checkpointFinder func(prefix string) checkpoints.CheckpointInterface
@@ -50,12 +53,13 @@ type Rule struct {
ipRangeListValue *values.IPRangeList
stringValues []string
stringValueRunes [][]rune
ipList *values.StringList
floatValue float64
reg *re.Regexp
regCacheLife utils.CacheLife
reg *re.Regexp
cacheLife utils.CacheLife
}
func NewRule() *Rule {
@@ -77,7 +81,7 @@ func (this *Rule) Init() error {
this.floatValue = types.Float64(this.Value)
case RuleOperatorNeq:
this.floatValue = types.Float64(this.Value)
case RuleOperatorContainsAny, RuleOperatorContainsAll:
case RuleOperatorContainsAny, RuleOperatorContainsAll, RuleOperatorContainsAnyWord, RuleOperatorContainsAllWords, RuleOperatorNotContainsAnyWord:
this.stringValues = []string{}
if len(this.Value) > 0 {
var lines = strings.Split(this.Value, "\n")
@@ -91,9 +95,17 @@ func (this *Rule) Init() error {
}
}
}
if this.Operator == RuleOperatorContainsAnyWord || this.Operator == RuleOperatorContainsAllWords || this.Operator == RuleOperatorNotContainsAnyWord {
sort.Strings(this.stringValues)
}
this.stringValueRunes = [][]rune{}
for _, line := range this.stringValues {
this.stringValueRunes = append(this.stringValueRunes, []rune(line))
}
}
case RuleOperatorMatch:
v := this.Value
var v = this.Value
if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") {
v = "(?i)" + v
}
@@ -106,7 +118,7 @@ func (this *Rule) Init() error {
}
this.reg = reg
case RuleOperatorNotMatch:
v := this.Value
var v = this.Value
if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") {
v = "(?i)" + v
}
@@ -164,7 +176,7 @@ func (this *Rule) Init() error {
this.singleCheckpoint = checkpoint
this.Priority = checkpoint.Priority()
this.regCacheLife = checkpoint.CacheLife()
this.cacheLife = checkpoint.CacheLife()
} else {
var checkpoint = checkpoints.FindCheckpoint(prefix)
if checkpoint == nil {
@@ -174,7 +186,7 @@ func (this *Rule) Init() error {
this.singleCheckpoint = checkpoint
this.Priority = checkpoint.Priority()
this.regCacheLife = checkpoint.CacheLife()
this.cacheLife = checkpoint.CacheLife()
}
return nil
@@ -193,8 +205,8 @@ func (this *Rule) Init() error {
this.multipleCheckpoints[prefix] = checkpoint
this.Priority = checkpoint.Priority()
if this.regCacheLife <= 0 || checkpoint.CacheLife() < this.regCacheLife {
this.regCacheLife = checkpoint.CacheLife()
if this.cacheLife <= 0 || checkpoint.CacheLife() < this.cacheLife {
this.cacheLife = checkpoint.CacheLife()
}
}
} else {
@@ -206,7 +218,7 @@ func (this *Rule) Init() error {
this.multipleCheckpoints[prefix] = checkpoint
this.Priority = checkpoint.Priority()
this.regCacheLife = checkpoint.CacheLife()
this.cacheLife = checkpoint.CacheLife()
}
}
@@ -239,9 +251,9 @@ func (this *Rule) MatchRequest(req requests.Request) (b bool, hasRequestBody boo
return this.Test(value), hasRequestBody, nil
}
value := configutils.ParseVariables(this.Param, func(varName string) (value string) {
pieces := strings.SplitN(varName, ".", 2)
prefix := pieces[0]
var value = configutils.ParseVariables(this.Param, func(varName string) (value string) {
var pieces = strings.SplitN(varName, ".", 2)
var prefix = pieces[0]
point, ok := this.multipleCheckpoints[prefix]
if !ok {
return ""
@@ -255,7 +267,7 @@ func (this *Rule) MatchRequest(req requests.Request) (b bool, hasRequestBody boo
if err1 != nil {
err = err1
}
return types.String(value1)
return this.stringifyValue(value1)
}
value1, hasCheckRequestBody, err1, _ := point.RequestValue(req, pieces[1], this.CheckpointOptions, this.Id)
@@ -265,7 +277,7 @@ func (this *Rule) MatchRequest(req requests.Request) (b bool, hasRequestBody boo
if err1 != nil {
err = err1
}
return types.String(value1)
return this.stringifyValue(value1)
})
if err != nil {
@@ -312,9 +324,9 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (
return this.Test(value), hasRequestBody, nil
}
value := configutils.ParseVariables(this.Param, func(varName string) (value string) {
pieces := strings.SplitN(varName, ".", 2)
prefix := pieces[0]
var value = configutils.ParseVariables(this.Param, func(varName string) (value string) {
var pieces = strings.SplitN(varName, ".", 2)
var prefix = pieces[0]
point, ok := this.multipleCheckpoints[prefix]
if !ok {
return ""
@@ -329,7 +341,7 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (
if err1 != nil {
err = err1
}
return types.String(value1)
return this.stringifyValue(value1)
} else {
value1, hasCheckRequestBody, err1, _ := point.ResponseValue(req, resp, "", this.CheckpointOptions, this.Id)
if hasCheckRequestBody {
@@ -338,7 +350,7 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (
if err1 != nil {
err = err1
}
return types.String(value1)
return this.stringifyValue(value1)
}
}
@@ -350,7 +362,7 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (
if err1 != nil {
err = err1
}
return types.String(value1)
return this.stringifyValue(value1)
} else {
value1, hasCheckRequestBody, err1, _ := point.ResponseValue(req, resp, pieces[1], this.CheckpointOptions, this.Id)
if hasCheckRequestBody {
@@ -359,7 +371,7 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (
if err1 != nil {
err = err1
}
return types.String(value1)
return this.stringifyValue(value1)
}
})
@@ -387,26 +399,37 @@ func (this *Rule) Test(value any) bool {
return types.Float64(value) != this.floatValue
case RuleOperatorEqString:
if this.IsCaseInsensitive {
return strings.EqualFold(types.String(value), this.Value)
return strings.EqualFold(this.stringifyValue(value), this.Value)
} else {
return types.String(value) == this.Value
return this.stringifyValue(value) == this.Value
}
case RuleOperatorNeqString:
if this.IsCaseInsensitive {
return !strings.EqualFold(types.String(value), this.Value)
return !strings.EqualFold(this.stringifyValue(value), this.Value)
} else {
return types.String(value) != this.Value
return this.stringifyValue(value) != this.Value
}
case RuleOperatorMatch, RuleOperatorWildcardMatch:
if value == nil {
return false
value = ""
}
// strings
stringList, ok := value.([]string)
if ok {
for _, s := range stringList {
if utils.MatchStringCache(this.reg, s, this.regCacheLife) {
if utils.MatchStringCache(this.reg, s, this.cacheLife) {
return true
}
}
return false
}
// bytes list
byteSlices, ok := value.([][]byte)
if ok {
for _, byteSlice := range byteSlices {
if utils.MatchBytesCache(this.reg, byteSlice, this.cacheLife) {
return true
}
}
@@ -416,19 +439,30 @@ func (this *Rule) Test(value any) bool {
// bytes
byteSlice, ok := value.([]byte)
if ok {
return utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife)
return utils.MatchBytesCache(this.reg, byteSlice, this.cacheLife)
}
// string
return utils.MatchStringCache(this.reg, types.String(value), this.regCacheLife)
return utils.MatchStringCache(this.reg, this.stringifyValue(value), this.cacheLife)
case RuleOperatorNotMatch, RuleOperatorWildcardNotMatch:
if value == nil {
return true
value = ""
}
stringList, ok := value.([]string)
if ok {
for _, s := range stringList {
if utils.MatchStringCache(this.reg, s, this.regCacheLife) {
if utils.MatchStringCache(this.reg, s, this.cacheLife) {
return false
}
}
return true
}
// bytes list
byteSlices, ok := value.([][]byte)
if ok {
for _, byteSlice := range byteSlices {
if utils.MatchBytesCache(this.reg, byteSlice, this.cacheLife) {
return false
}
}
@@ -438,17 +472,17 @@ func (this *Rule) Test(value any) bool {
// bytes
byteSlice, ok := value.([]byte)
if ok {
return !utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife)
return !utils.MatchBytesCache(this.reg, byteSlice, this.cacheLife)
}
return !utils.MatchStringCache(this.reg, types.String(value), this.regCacheLife)
return !utils.MatchStringCache(this.reg, this.stringifyValue(value), this.cacheLife)
case RuleOperatorContains:
if types.IsSlice(value) {
_, isBytes := value.([]byte)
if !isBytes {
ok := false
var ok = false
lists.Each(value, func(k int, v any) {
if types.String(v) == this.Value {
if this.stringifyValue(v) == this.Value {
ok = true
}
})
@@ -456,17 +490,17 @@ func (this *Rule) Test(value any) bool {
}
}
if types.IsMap(value) {
lowerValue := ""
var lowerValue = ""
if this.IsCaseInsensitive {
lowerValue = strings.ToLower(this.Value)
}
for _, v := range maps.NewMap(value) {
if this.IsCaseInsensitive {
if strings.ToLower(types.String(v)) == lowerValue {
if strings.ToLower(this.stringifyValue(v)) == lowerValue {
return true
}
} else {
if types.String(v) == this.Value {
if this.stringifyValue(v) == this.Value {
return true
}
}
@@ -475,30 +509,44 @@ func (this *Rule) Test(value any) bool {
}
if this.IsCaseInsensitive {
return strings.Contains(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
return strings.Contains(strings.ToLower(this.stringifyValue(value)), strings.ToLower(this.Value))
} else {
return strings.Contains(types.String(value), this.Value)
return strings.Contains(this.stringifyValue(value), this.Value)
}
case RuleOperatorNotContains:
if this.IsCaseInsensitive {
return !strings.Contains(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
return !strings.Contains(strings.ToLower(this.stringifyValue(value)), strings.ToLower(this.Value))
} else {
return !strings.Contains(types.String(value), this.Value)
return !strings.Contains(this.stringifyValue(value), this.Value)
}
case RuleOperatorPrefix:
if this.IsCaseInsensitive {
return strings.HasPrefix(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
var s = this.stringifyValue(value)
var sl = len(s)
var vl = len(this.Value)
if sl < vl {
return false
}
s = s[:vl]
return strings.HasPrefix(strings.ToLower(s), strings.ToLower(this.Value))
} else {
return strings.HasPrefix(types.String(value), this.Value)
return strings.HasPrefix(this.stringifyValue(value), this.Value)
}
case RuleOperatorSuffix:
if this.IsCaseInsensitive {
return strings.HasSuffix(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
var s = this.stringifyValue(value)
var sl = len(s)
var vl = len(this.Value)
if sl < vl {
return false
}
s = s[sl-vl:]
return strings.HasSuffix(strings.ToLower(s), strings.ToLower(this.Value))
} else {
return strings.HasSuffix(types.String(value), this.Value)
return strings.HasSuffix(this.stringifyValue(value), this.Value)
}
case RuleOperatorContainsAny:
var stringValue = types.String(value)
var stringValue = this.stringifyValue(value)
if this.IsCaseInsensitive {
stringValue = strings.ToLower(stringValue)
}
@@ -511,7 +559,7 @@ func (this *Rule) Test(value any) bool {
}
return false
case RuleOperatorContainsAll:
var stringValue = types.String(value)
var stringValue = this.stringifyValue(value)
if this.IsCaseInsensitive {
stringValue = strings.ToLower(stringValue)
}
@@ -524,31 +572,81 @@ func (this *Rule) Test(value any) bool {
return true
}
return false
case RuleOperatorContainsAnyWord:
return runes.ContainsAnyWordRunes(this.stringifyValue(value), this.stringValueRunes, this.IsCaseInsensitive)
case RuleOperatorContainsAllWords:
return runes.ContainsAllWords(this.stringifyValue(value), this.stringValues, this.IsCaseInsensitive)
case RuleOperatorNotContainsAnyWord:
return !runes.ContainsAnyWordRunes(this.stringifyValue(value), this.stringValueRunes, this.IsCaseInsensitive)
case RuleOperatorContainsSQLInjection:
if value == nil {
return false
}
switch xValue := value.(type) {
case []string:
for _, v := range xValue {
if injectionutils.DetectSQLInjectionCache(v, this.cacheLife) {
return true
}
}
return false
case [][]byte:
for _, v := range xValue {
if injectionutils.DetectSQLInjectionCache(string(v), this.cacheLife) {
return true
}
}
return false
default:
return injectionutils.DetectSQLInjectionCache(this.stringifyValue(value), this.cacheLife)
}
case RuleOperatorContainsXSS:
if value == nil {
return false
}
switch xValue := value.(type) {
case []string:
for _, v := range xValue {
if injectionutils.DetectXSSCache(v, this.cacheLife) {
return true
}
}
return false
case [][]byte:
for _, v := range xValue {
if injectionutils.DetectXSSCache(string(v), this.cacheLife) {
return true
}
}
return false
default:
return injectionutils.DetectXSSCache(this.stringifyValue(value), this.cacheLife)
}
case RuleOperatorContainsBinary:
data, _ := base64.StdEncoding.DecodeString(types.String(this.Value))
data, _ := base64.StdEncoding.DecodeString(this.stringifyValue(this.Value))
if this.IsCaseInsensitive {
return bytes.Contains(bytes.ToUpper([]byte(types.String(value))), bytes.ToUpper(data))
return bytes.Contains(bytes.ToUpper([]byte(this.stringifyValue(value))), bytes.ToUpper(data))
} else {
return bytes.Contains([]byte(types.String(value)), data)
return bytes.Contains([]byte(this.stringifyValue(value)), data)
}
case RuleOperatorNotContainsBinary:
data, _ := base64.StdEncoding.DecodeString(types.String(this.Value))
data, _ := base64.StdEncoding.DecodeString(this.stringifyValue(this.Value))
if this.IsCaseInsensitive {
return !bytes.Contains(bytes.ToUpper([]byte(types.String(value))), bytes.ToUpper(data))
return !bytes.Contains(bytes.ToUpper([]byte(this.stringifyValue(value))), bytes.ToUpper(data))
} else {
return !bytes.Contains([]byte(types.String(value)), data)
return !bytes.Contains([]byte(this.stringifyValue(value)), data)
}
case RuleOperatorHasKey:
if types.IsSlice(value) {
index := types.Int(this.Value)
var index = types.Int(this.Value)
if index < 0 {
return false
}
return reflect.ValueOf(value).Len() > index
} else if types.IsMap(value) {
m := maps.NewMap(value)
var m = maps.NewMap(value)
if this.IsCaseInsensitive {
lowerValue := strings.ToLower(this.Value)
var lowerValue = strings.ToLower(this.Value)
for k := range m {
if strings.ToLower(k) == lowerValue {
return true
@@ -567,9 +665,9 @@ func (this *Rule) Test(value any) bool {
return stringutil.VersionCompare(this.Value, types.String(value)) < 0
case RuleOperatorVersionRange:
if strings.Contains(this.Value, ",") {
versions := strings.SplitN(this.Value, ",", 2)
version1 := strings.TrimSpace(versions[0])
version2 := strings.TrimSpace(versions[1])
var versions = strings.SplitN(this.Value, ",", 2)
var version1 = strings.TrimSpace(versions[0])
var version2 = strings.TrimSpace(versions[1])
if len(version1) > 0 && stringutil.VersionCompare(types.String(value), version1) < 0 {
return false
}
@@ -587,25 +685,25 @@ func (this *Rule) Test(value any) bool {
}
return this.isIP && ip.Equal(this.ipValue)
case RuleOperatorGtIP:
ip := net.ParseIP(types.String(value))
var ip = net.ParseIP(types.String(value))
if ip == nil {
return false
}
return this.isIP && bytes.Compare(ip, this.ipValue) > 0
case RuleOperatorGteIP:
ip := net.ParseIP(types.String(value))
var ip = net.ParseIP(types.String(value))
if ip == nil {
return false
}
return this.isIP && bytes.Compare(ip, this.ipValue) >= 0
case RuleOperatorLtIP:
ip := net.ParseIP(types.String(value))
var ip = net.ParseIP(types.String(value))
if ip == nil {
return false
}
return this.isIP && bytes.Compare(ip, this.ipValue) < 0
case RuleOperatorLteIP:
ip := net.ParseIP(types.String(value))
var ip = net.ParseIP(types.String(value))
if ip == nil {
return false
}
@@ -624,7 +722,7 @@ func (this *Rule) Test(value any) bool {
if div == 0 {
return false
}
rem := types.Int64(pieces[1])
var rem = types.Int64(pieces[1])
return this.ipToInt64(net.ParseIP(types.String(value)))%div == rem
case RuleOperatorIPMod10:
return this.ipToInt64(net.ParseIP(types.String(value)))%10 == types.Int64(this.Value)
@@ -737,3 +835,25 @@ func (this *Rule) execFilter(value any) any {
}
return value
}
func (this *Rule) stringifyValue(value any) string {
if value == nil {
return ""
}
switch v := value.(type) {
case string:
return v
case []string:
return strings.Join(v, "")
case []byte:
return string(v)
case [][]byte:
var b = &bytes.Buffer{}
for _, vb := range v {
b.Write(vb)
}
return b.String()
default:
return types.String(v)
}
}

View File

@@ -74,7 +74,7 @@ func (this *RuleGroup) RemoveRuleSet(id int64) {
this.RuleSets = result
}
func (this *RuleGroup) MatchRequest(req requests.Request) (b bool, hasRequestBody bool, set *RuleSet, err error) {
func (this *RuleGroup) MatchRequest(req requests.Request) (b bool, hasRequestBody bool, resultSet *RuleSet, err error) {
if !this.hasRuleSets {
return
}
@@ -93,7 +93,7 @@ func (this *RuleGroup) MatchRequest(req requests.Request) (b bool, hasRequestBod
return
}
func (this *RuleGroup) MatchResponse(req requests.Request, resp *requests.Response) (b bool, hasRequestBody bool, set *RuleSet, err error) {
func (this *RuleGroup) MatchResponse(req requests.Request, resp *requests.Response) (b bool, hasRequestBody bool, resultSet *RuleSet, err error) {
if !this.hasRuleSets {
return
}

Some files were not shown because too many files have changed in this diff Show More