Compare commits

..

83 Commits

Author SHA1 Message Date
刘祥超
6c2d488c37 版本号修改为1.3.3 2024-01-21 16:58:24 +08:00
刘祥超
d84b844e53 优化测试用例 2024-01-21 11:13:30 +08:00
刘祥超
095c381ae5 WAF允许动作默认跳过所有规则 2024-01-20 20:54:41 +08:00
刘祥超
7d11b3c63b WAF策略增加显示页面动作默认设置 2024-01-20 16:18:07 +08:00
刘祥超
45ba4fe5f1 修复部分内置页面没有<head>标签的问题 2024-01-20 10:14:15 +08:00
刘祥超
ac341da05b 优化代码 2024-01-19 09:27:42 +08:00
刘祥超
12b9c37095 增加edge-node config命令 2024-01-18 10:54:43 +08:00
刘祥超
125b25ea27 如果源站请求响应中没有Content-Type,则不设置Content-Type 2024-01-17 15:17:19 +08:00
刘祥超
6b1d595d58 XSS检测增加测试用例 2024-01-16 21:13:10 +08:00
刘祥超
f26b80e9c1 修改版本号为1.3.2.2 2024-01-16 20:59:25 +08:00
刘祥超
6e1bb43a9e WAF操作符增加“包含SQL注入-严格模式” 2024-01-16 20:41:52 +08:00
刘祥超
35e7ce1435 优化XSS检测的模式 2024-01-16 19:56:04 +08:00
刘祥超
5eb247b999 修复5秒盾验证可能受WAF影响不能工作的问题 2024-01-16 12:03:06 +08:00
刘祥超
c94e93859f 修复Websocket连接无法报告连接关闭的问题 2024-01-16 09:24:51 +08:00
刘祥超
d694319191 优化缓存错误相关代码 2024-01-15 21:00:20 +08:00
刘祥超
84e61f7765 优化代码 2024-01-15 10:31:42 +08:00
刘祥超
196f0612dc 版本号修改为1.3.2.1 2024-01-15 08:40:31 +08:00
刘祥超
4cd5e12686 网站设置增加HLS加密功能(商业版本) 2024-01-14 20:34:55 +08:00
刘祥超
aca128c19d 优化代码 2024-01-13 19:33:29 +08:00
刘祥超
5052d20fbd 字符编码设置增加“强制替换”选项;修复字符编码大写选项不起作用的问题 2024-01-13 16:29:32 +08:00
刘祥超
74790bea4a 提升UA解析性能(2-4倍) 2024-01-12 16:30:32 +08:00
刘祥超
e922c12611 修复缓存策略无法切换文件和内存的问题 2024-01-12 14:17:12 +08:00
刘祥超
cc632d557b 自定义页面跳转支持使用变量 2024-01-11 15:37:56 +08:00
刘祥超
a7dd101dbf 套餐可以设置带宽限制 2024-01-11 15:25:47 +08:00
刘祥超
1e1cd5a643 鉴权访问日志标签增加"auth:"前缀 2024-01-11 10:48:53 +08:00
刘祥超
6888ccc350 优化brotli相关测试用例 2024-01-10 15:33:49 +08:00
刘祥超
8d9a4b6d4f 修复HTTP/3下无法传递ServerAddr的问题 2024-01-10 11:46:24 +08:00
刘祥超
470b33e90f 相关CFLAGS增加优化级别-O2 2024-01-09 19:15:51 +08:00
刘祥超
74a559b1a0 优化WAF验证码输入框和字体样式 2024-01-09 17:11:18 +08:00
刘祥超
9db86b011a WAF人机识别验证后返回后自动跳转到目标页面 2024-01-09 17:00:26 +08:00
刘祥超
c94a51f47b 优化WAF验证码输入界面 2024-01-09 15:39:52 +08:00
刘祥超
2a9ec05d45 在开发模式下重启时不重载临时IP白名单 2024-01-09 11:05:51 +08:00
刘祥超
c9811d78a9 WAF操作符增加“包含XSS注入-严格模式” 2024-01-04 14:54:17 +08:00
刘祥超
65435ab32d 写缓存失败时,允许继续读取源站内容 2023-12-27 20:55:12 +08:00
刘祥超
614c7a4687 优化测试用例 2023-12-25 16:57:25 +08:00
刘祥超
ef541a2d8f 优化计数器性能 2023-12-25 16:41:07 +08:00
刘祥超
9141a1434e edge-node gc命令中的pause时间改为最近一次回收的pause时间 2023-12-25 12:43:13 +08:00
刘祥超
17af07cce0 edge-node gc命令增加耗时和gc pause时长 2023-12-25 09:24:20 +08:00
刘祥超
cfa57fac66 优化计数器相关测试用例 2023-12-24 16:08:57 +08:00
刘祥超
47523eaa73 优化计数器性能 2023-12-24 15:11:09 +08:00
刘祥超
27a24c6a8a 版本号修改为1.3.2 2023-12-24 11:14:45 +08:00
刘祥超
9bc2b1a651 WAF参数中增加“请求来源” 2023-12-24 10:03:24 +08:00
刘祥超
4f24b7f39c 增加Websocket连接数统计 2023-12-20 11:43:00 +08:00
刘祥超
4607a1f4e7 版本号修改为1.3.1.2 2023-12-18 08:51:22 +08:00
刘祥超
0f2068b161 优化TCP源站错误提示 2023-12-15 18:38:09 +08:00
刘祥超
c039691a71 缓存设置中可以设置缓存主域名,用来复用多域名下的缓存 2023-12-13 18:41:51 +08:00
刘祥超
930ee44065 根据系统环境调整WebP转换线程数 2023-12-12 09:55:18 +08:00
刘祥超
8a9aac7d72 优化代码 2023-12-11 20:35:48 +08:00
刘祥超
e50bbb962d WebP策略变化时只更新相关配置 2023-12-11 11:09:12 +08:00
刘祥超
9ff936d0c1 WebP转换质量转移到WebP策略配置 2023-12-11 10:17:17 +08:00
刘祥超
f53727b09c WebP转换限制为单线程,防止占用系统资源过高 2023-12-11 09:33:04 +08:00
刘祥超
525ce1f923 优化WAF XSS检测,减少对图片内容的误判 2023-12-10 19:40:29 +08:00
刘祥超
16e7cd800c WAF SQL注入检测和XSS注入检测自动进行URL解码 2023-12-10 16:52:54 +08:00
刘祥超
3f34bfc0b0 节点进程停止时,自动保存WAF临时白名单,并在进程重新启动后恢复 2023-12-10 15:41:31 +08:00
刘祥超
548cd1002b 增加WAF相关测试用例 2023-12-10 09:27:29 +08:00
刘祥超
3423865868 优化测试用例 2023-12-10 08:54:39 +08:00
刘祥超
037bc8e0de 优化WAF单词匹配性能 2023-12-09 19:19:29 +08:00
刘祥超
e03292de28 WAF规则模板中XSS注入检测规则使用“包含XSS注入”操作符替代以往的正则表达式 2023-12-09 17:00:21 +08:00
刘祥超
ee2565905e 优化WAF动作“显示网页”显示 2023-12-09 15:55:40 +08:00
刘祥超
05881b457d WAF规则模板中SQL注入规则使用“包含SQL注入”操作符替代以往的正则表达式 2023-12-09 15:28:07 +08:00
刘祥超
b116effc6c WAF SQL注入和XSS检测增加缓存/优化部分WAF相关测试用例 2023-12-09 11:46:50 +08:00
刘祥超
536efeeb9c 提升单词匹配性能 2023-12-09 10:06:07 +08:00
刘祥超
e8638e4bec WAF检查项增加“所有报头名称” 2023-12-08 15:39:23 +08:00
刘祥超
c9db722129 WAF增加“包含XSS注入”操作符 2023-12-08 10:15:18 +08:00
刘祥超
90de472bd5 增加测试用例 2023-12-07 20:47:25 +08:00
刘祥超
50c6c60abf WAF SQL注入检测时支持 (http|https):// 开头的URL 2023-12-07 20:38:06 +08:00
刘祥超
cc10372fe1 WAF增加“包含SQL注入”操作符 2023-12-07 20:25:35 +08:00
刘祥超
05c98a0656 修复一处单词错误 2023-12-07 12:14:04 +08:00
刘祥超
1a790fe391 优化代码 2023-12-07 12:07:06 +08:00
刘祥超
7dbd73cb59 优化WAF中前缀和后缀相关操作符性能 2023-12-07 12:05:08 +08:00
刘祥超
4dfa571547 WAF操作符增加包含任一单词、包含所有单词、不包含任一单词 2023-12-07 11:42:59 +08:00
刘祥超
9f77f62308 WAF checkpoint返回值支持[][]byte 2023-12-05 17:18:53 +08:00
刘祥超
facea1ed96 优化代码 2023-12-05 16:28:10 +08:00
刘祥超
e367814db3 内容压缩级别允许为0 2023-12-05 10:48:17 +08:00
刘祥超
3a15408c98 修复缓存命中率统计测试用例 2023-12-03 14:55:09 +08:00
刘祥超
c504b37118 WAF相关跳转不计入统计 2023-12-03 14:41:11 +08:00
刘祥超
74708dc02f 默认不启用内存分片管理 2023-12-03 14:26:51 +08:00
刘祥超
0c097498bb 优化链表相关代码 2023-12-03 11:27:47 +08:00
刘祥超
981c063eff 优化验证码性能 2023-11-30 17:25:41 +08:00
刘祥超
5e35c50113 页面优化增加例外URL和限制URL 2023-11-30 15:48:50 +08:00
刘祥超
e6c2869ff2 增加“极验-行为验”验证码集成支持 2023-11-29 17:00:06 +08:00
刘祥超
358bec2e9b WAF验证码验证后返回时判断是否已通过验证 2023-11-28 20:39:42 +08:00
刘祥超
1cd644f2eb 优化验证码加载方式,减少不必要的图片生成 2023-11-28 18:07:27 +08:00
154 changed files with 19271 additions and 1531 deletions

View File

@@ -10,7 +10,7 @@ function build() {
# for macOS users: precompiled gcc can be downloaded from https://github.com/messense/homebrew-macos-cross-toolchains
GCC_X86_64_DIR="/usr/local/gcc/x86_64-unknown-linux-gnu/bin"
GCC_ARM64_DIR="//usr/local/gcc/aarch64-unknown-linux-gnu/bin"
GCC_ARM64_DIR="/usr/local/gcc/aarch64-unknown-linux-gnu/bin"
OS=${1}
ARCH=${2}

View File

@@ -6,4 +6,6 @@ if [ -z "$TAG" ]; then
TAG="community"
fi
go test -v ../... -tags=${TAG}
# reference: https://pkg.go.dev/cmd/go/internal/test
go clean -testcache
go test -timeout 10s -tags="${TAG}" -cover ../...

View File

@@ -16,6 +16,7 @@ import (
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"github.com/iwind/gosock/pkg/gosock"
"gopkg.in/yaml.v3"
"net"
"net/http"
_ "net/http/pprof"
@@ -30,7 +31,7 @@ func main() {
var app = apps.NewAppCmd().
Version(teaconst.Version).
Product(teaconst.ProductName).
Usage(teaconst.ProcessName + " [-v|start|stop|restart|status|quit|test|reload|service|daemon|pprof|accesslog|uninstall]").
Usage(teaconst.ProcessName + " [-v|start|stop|restart|status|quit|test|reload|service|daemon|config|pprof|accesslog|uninstall]").
Usage(teaconst.ProcessName + " [trackers|goman|conns|gc|bandwidth|disk|cache.garbage]").
Usage(teaconst.ProcessName + " [ip.drop|ip.reject|ip.remove|ip.close] IP")
@@ -228,11 +229,18 @@ func main() {
})
app.On("gc", func() {
var sock = gosock.NewTmpSock(teaconst.ProcessName)
_, err := sock.Send(&gosock.Command{Code: "gc"})
reply, err := sock.Send(&gosock.Command{Code: "gc"})
if err != nil {
fmt.Println("[ERROR]" + err.Error())
} else {
fmt.Println("ok")
if reply == nil {
fmt.Println("ok")
} else {
var paramMap = maps.NewMap(reply.Params)
var pauseMS = paramMap.GetFloat64("pauseMS")
var costMS = paramMap.GetFloat64("costMS")
fmt.Printf("ok, cost: %.4fms, pause: %.4fms", costMS, pauseMS)
}
}
})
app.On("ip.drop", func() {
@@ -517,6 +525,41 @@ func main() {
fmt.Println("[ERROR]" + params.GetString("error"))
}
})
app.On("config", func() {
var configString = os.Args[len(os.Args)-1]
if configString == "config" {
fmt.Println("Usage: edge-node config '\nrpc.endpoints: [\"...\"]\nnodeId: \"...\"\nsecret: \"...\"\n'")
return
}
var config = &configs.APIConfig{}
err := yaml.Unmarshal([]byte(configString), config)
if err != nil {
fmt.Println("[ERROR]decode config failed: " + err.Error())
return
}
err = config.Init()
if err != nil {
fmt.Println("[ERROR]validate config failed: " + err.Error())
return
}
// marshal again
configYAML, err := yaml.Marshal(config)
if err != nil {
fmt.Println("[ERROR]encode config failed: " + err.Error())
return
}
err = os.WriteFile(Tea.Root + "/configs/api_node.yaml", configYAML, 0666)
if err != nil {
fmt.Println("[ERROR]write config failed: " + err.Error())
return
}
fmt.Println("success")
})
app.Run(func() {
var node = nodes.NewNode()
node.Start()

View File

@@ -34,16 +34,14 @@ func CanIgnoreErr(err error) bool {
if err == nil {
return true
}
if err == ErrFileIsWriting ||
err == ErrEntityTooLarge ||
err == ErrWritingUnavailable ||
err == ErrWritingQueueFull ||
err == ErrServerIsBusy {
if errors.Is(err, ErrFileIsWriting) ||
errors.Is(err, ErrEntityTooLarge) ||
errors.Is(err, ErrWritingUnavailable) ||
errors.Is(err, ErrWritingQueueFull) ||
errors.Is(err, ErrServerIsBusy) {
return true
}
_, ok := err.(*CapacityError)
if ok {
return true
}
return false
var capacityErr *CapacityError
return errors.As(err, &capacityErr)
}

View File

@@ -1,16 +1,24 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package caches
package caches_test
import (
"errors"
"fmt"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/iwind/TeaGo/assert"
"testing"
)
func TestCanIgnoreErr(t *testing.T) {
a := assert.NewAssertion(t)
var a = assert.NewAssertion(t)
a.IsTrue(CanIgnoreErr(ErrFileIsWriting))
a.IsTrue(CanIgnoreErr(NewCapacityError("over capcity")))
a.IsFalse(CanIgnoreErr(ErrNotFound))
a.IsTrue(caches.CanIgnoreErr(caches.ErrFileIsWriting))
a.IsTrue(caches.CanIgnoreErr(fmt.Errorf("error: %w", caches.ErrFileIsWriting)))
a.IsTrue(errors.Is(fmt.Errorf("error: %w", caches.ErrFileIsWriting), caches.ErrFileIsWriting))
a.IsTrue(errors.Is(caches.ErrFileIsWriting, caches.ErrFileIsWriting))
a.IsTrue(caches.CanIgnoreErr(caches.NewCapacityError("over capacity")))
a.IsTrue(caches.CanIgnoreErr(fmt.Errorf("error: %w", caches.NewCapacityError("over capacity"))))
a.IsFalse(caches.CanIgnoreErr(caches.ErrNotFound))
a.IsFalse(caches.CanIgnoreErr(errors.New("test error")))
}

View File

@@ -4,6 +4,7 @@ package caches_test
import (
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/TeaOSLab/EdgeNode/internal/zero"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
@@ -18,7 +19,11 @@ func TestItems_Memory(t *testing.T) {
var memory1 = stat.HeapInuse
var items = []*caches.Item{}
for i := 0; i < 10_000_000; i++ {
var count = 100
if testutils.IsSingleTesting() {
count = 10_000_000
}
for i := 0; i < count; i++ {
items = append(items, &caches.Item{
Key: types.String(i),
})
@@ -33,7 +38,9 @@ func TestItems_Memory(t *testing.T) {
var memory3 = stat.HeapInuse
t.Log(memory2, memory3, (memory3-memory2)/1024/1024, "M")
time.Sleep(1 * time.Second)
if testutils.IsSingleTesting() {
time.Sleep(1 * time.Second)
}
}
func TestItems_Memory2(t *testing.T) {
@@ -42,7 +49,12 @@ func TestItems_Memory2(t *testing.T) {
var memory1 = stat.HeapInuse
var items = map[int32]map[string]zero.Zero{}
for i := 0; i < 10_000_000; i++ {
var count = 100
if testutils.IsSingleTesting() {
count = 10_000_000
}
for i := 0; i < count; i++ {
var week = int32((time.Now().Unix() - int64(86400*rands.Int(0, 300))) / (86400 * 7))
m, ok := items[week]
if !ok {
@@ -57,7 +69,9 @@ func TestItems_Memory2(t *testing.T) {
t.Log(memory1, memory2, (memory2-memory1)/1024/1024, "M")
time.Sleep(1 * time.Second)
if testutils.IsSingleTesting() {
time.Sleep(1 * time.Second)
}
for w, i := range items {
t.Log(w, len(i))
}

View File

@@ -4,6 +4,7 @@ package caches_test
import (
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/TeaOSLab/EdgeNode/internal/zero"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/assert"
@@ -86,6 +87,10 @@ func TestFileListHashMap_BigInt(t *testing.T) {
}
func TestFileListHashMap_Load(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1").(*caches.FileList)
defer func() {

View File

@@ -17,6 +17,10 @@ import (
)
func TestFileList_Init(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
defer func() {
@@ -34,6 +38,10 @@ func TestFileList_Init(t *testing.T) {
}
func TestFileList_Add(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1").(*caches.FileList)
defer func() {
@@ -107,6 +115,10 @@ func TestFileList_Add_Many(t *testing.T) {
}
func TestFileList_Exist(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1").(*caches.FileList)
defer func() {
_ = list.Close()
@@ -143,6 +155,10 @@ func TestFileList_Exist(t *testing.T) {
}
func TestFileList_Exist_Many_DB(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
// 测试在多个数据库下的性能
var listSlice = []caches.ListInterface{}
for i := 1; i <= 10; i++ {
@@ -202,6 +218,10 @@ func TestFileList_Exist_Many_DB(t *testing.T) {
}
func TestFileList_CleanPrefix(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
defer func() {
@@ -222,6 +242,10 @@ func TestFileList_CleanPrefix(t *testing.T) {
}
func TestFileList_Remove(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1").(*caches.FileList)
defer func() {
_ = list.Close()
@@ -246,6 +270,10 @@ func TestFileList_Remove(t *testing.T) {
}
func TestFileList_Purge(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
defer func() {
@@ -270,6 +298,10 @@ func TestFileList_Purge(t *testing.T) {
}
func TestFileList_PurgeLFU(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
defer func() {
@@ -294,6 +326,10 @@ func TestFileList_PurgeLFU(t *testing.T) {
}
func TestFileList_Stat(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
defer func() {
@@ -313,6 +349,10 @@ func TestFileList_Stat(t *testing.T) {
}
func TestFileList_Count(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var list = caches.NewFileList(Tea.Root + "/data")
defer func() {
@@ -333,6 +373,10 @@ func TestFileList_Count(t *testing.T) {
}
func TestFileList_CleanAll(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var list = caches.NewFileList(Tea.Root + "/data")
defer func() {
@@ -352,6 +396,10 @@ func TestFileList_CleanAll(t *testing.T) {
}
func TestFileList_UpgradeV3(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p43").(*caches.FileList)
defer func() {
@@ -376,6 +424,10 @@ func TestFileList_UpgradeV3(t *testing.T) {
}
func BenchmarkFileList_Exist(b *testing.B) {
if !testutils.IsSingleTesting() {
return
}
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
defer func() {

View File

@@ -400,7 +400,19 @@ func (this *MemoryList) IncreaseHit(hash string) error {
return nil
}
func (this *MemoryList) print(t *testing.T) {
func (this *MemoryList) Prefixes() []string {
return this.prefixes
}
func (this *MemoryList) ItemMaps() map[string]map[string]*Item {
return this.itemMaps
}
func (this *MemoryList) PurgeIndex() int {
return this.purgeIndex
}
func (this *MemoryList) Print(t *testing.T) {
this.locker.Lock()
for _, itemMap := range this.itemMaps {
if len(itemMap) > 0 {

View File

@@ -1,12 +1,14 @@
package caches
package caches_test
import (
"fmt"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/cespare/xxhash"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
"math/rand"
"sort"
"strconv"
@@ -15,65 +17,65 @@ import (
)
func TestMemoryList_Add(t *testing.T) {
list := NewMemoryList().(*MemoryList)
list := caches.NewMemoryList().(*caches.MemoryList)
_ = list.Init()
_ = list.Add("a", &Item{
_ = list.Add("a", &caches.Item{
Key: "a1",
ExpiredAt: time.Now().Unix() + 3600,
HeaderSize: 1024,
})
_ = list.Add("b", &Item{
_ = list.Add("b", &caches.Item{
Key: "b1",
ExpiredAt: time.Now().Unix() + 3600,
HeaderSize: 1024,
})
_ = list.Add("123456", &Item{
_ = list.Add("123456", &caches.Item{
Key: "c1",
ExpiredAt: time.Now().Unix() + 3600,
HeaderSize: 1024,
})
t.Log(list.prefixes)
logs.PrintAsJSON(list.itemMaps, t)
t.Log(list.Prefixes())
logs.PrintAsJSON(list.ItemMaps(), t)
t.Log(list.Count())
}
func TestMemoryList_Remove(t *testing.T) {
list := NewMemoryList().(*MemoryList)
list := caches.NewMemoryList().(*caches.MemoryList)
_ = list.Init()
_ = list.Add("a", &Item{
_ = list.Add("a", &caches.Item{
Key: "a1",
ExpiredAt: time.Now().Unix() + 3600,
HeaderSize: 1024,
})
_ = list.Add("b", &Item{
_ = list.Add("b", &caches.Item{
Key: "b1",
ExpiredAt: time.Now().Unix() + 3600,
HeaderSize: 1024,
})
_ = list.Remove("b")
list.print(t)
list.Print(t)
t.Log(list.Count())
}
func TestMemoryList_Purge(t *testing.T) {
list := NewMemoryList().(*MemoryList)
list := caches.NewMemoryList().(*caches.MemoryList)
_ = list.Init()
_ = list.Add("a", &Item{
_ = list.Add("a", &caches.Item{
Key: "a1",
ExpiredAt: time.Now().Unix() + 3600,
HeaderSize: 1024,
})
_ = list.Add("b", &Item{
_ = list.Add("b", &caches.Item{
Key: "b1",
ExpiredAt: time.Now().Unix() + 3600,
HeaderSize: 1024,
})
_ = list.Add("c", &Item{
_ = list.Add("c", &caches.Item{
Key: "c1",
ExpiredAt: time.Now().Unix() - 3600,
HeaderSize: 1024,
})
_ = list.Add("d", &Item{
_ = list.Add("d", &caches.Item{
Key: "d1",
ExpiredAt: time.Now().Unix() - 2,
HeaderSize: 1024,
@@ -82,25 +84,30 @@ func TestMemoryList_Purge(t *testing.T) {
t.Log("delete:", hash)
return nil
})
list.print(t)
list.Print(t)
for i := 0; i < 1000; i++ {
_, _ = list.Purge(100, func(hash string) error {
t.Log("delete:", hash)
return nil
})
t.Log(list.purgeIndex)
t.Log(list.PurgeIndex())
}
t.Log(list.Count())
}
func TestMemoryList_Purge_Large_List(t *testing.T) {
list := NewMemoryList().(*MemoryList)
var list = caches.NewMemoryList().(*caches.MemoryList)
_ = list.Init()
for i := 0; i < 1_000_000; i++ {
_ = list.Add("a"+strconv.Itoa(i), &Item{
var count = 100
if testutils.IsSingleTesting() {
count = 1_000_000
}
for i := 0; i < count; i++ {
_ = list.Add("a"+strconv.Itoa(i), &caches.Item{
Key: "a" + strconv.Itoa(i),
ExpiredAt: time.Now().Unix() + int64(rands.Int(0, 24*3600)),
HeaderSize: 1024,
@@ -113,43 +120,46 @@ func TestMemoryList_Purge_Large_List(t *testing.T) {
}
func TestMemoryList_Stat(t *testing.T) {
list := NewMemoryList()
list := caches.NewMemoryList()
_ = list.Init()
_ = list.Add("a", &Item{
_ = list.Add("a", &caches.Item{
Key: "a1",
ExpiredAt: time.Now().Unix() + 3600,
HeaderSize: 1024,
})
_ = list.Add("b", &Item{
_ = list.Add("b", &caches.Item{
Key: "b1",
ExpiredAt: time.Now().Unix() + 3600,
HeaderSize: 1024,
})
_ = list.Add("c", &Item{
_ = list.Add("c", &caches.Item{
Key: "c1",
ExpiredAt: time.Now().Unix(),
HeaderSize: 1024,
})
_ = list.Add("d", &Item{
_ = list.Add("d", &caches.Item{
Key: "d1",
ExpiredAt: time.Now().Unix() - 2,
HeaderSize: 1024,
})
result, _ := list.Stat(func(hash string) bool {
// 随机测试
rand.Seed(time.Now().UnixNano())
return rand.Int()%2 == 0
})
t.Log(result)
}
func TestMemoryList_CleanPrefix(t *testing.T) {
list := NewMemoryList()
list := caches.NewMemoryList()
_ = list.Init()
before := time.Now()
for i := 0; i < 1_000_000; i++ {
var count = 100
if testutils.IsSingleTesting() {
count = 1_000_000
}
for i := 0; i < count; i++ {
key := "https://www.teaos.cn/hello/" + strconv.Itoa(i/10000) + "/" + strconv.Itoa(i) + ".html"
_ = list.Add(fmt.Sprintf("%d", xxhash.Sum64String(key)), &Item{
_ = list.Add(fmt.Sprintf("%d", xxhash.Sum64String(key)), &caches.Item{
Key: key,
ExpiredAt: time.Now().Unix() + 3600,
BodySize: 0,
@@ -174,7 +184,12 @@ func TestMemoryList_CleanPrefix(t *testing.T) {
func TestMapRandomDelete(t *testing.T) {
var countMap = map[int]int{} // k => count
for j := 0; j < 1_000_000; j++ {
var count = 1000
if testutils.IsSingleTesting() {
count = 1_000_000
}
for j := 0; j < count; j++ {
var m = map[int]bool{}
for i := 0; i < 100; i++ {
m[i] = true
@@ -203,18 +218,18 @@ func TestMapRandomDelete(t *testing.T) {
}
func TestMemoryList_PurgeLFU(t *testing.T) {
var list = NewMemoryList().(*MemoryList)
var list = caches.NewMemoryList().(*caches.MemoryList)
var before = time.Now()
defer func() {
t.Log(time.Since(before).Seconds()*1000, "ms")
}()
_ = list.Add("1", &Item{})
_ = list.Add("2", &Item{})
_ = list.Add("3", &Item{})
_ = list.Add("4", &Item{})
_ = list.Add("5", &Item{})
_ = list.Add("1", &caches.Item{})
_ = list.Add("2", &caches.Item{})
_ = list.Add("3", &caches.Item{})
_ = list.Add("4", &caches.Item{})
_ = list.Add("5", &caches.Item{})
//_ = list.IncreaseHit("1")
//_ = list.IncreaseHit("2")
@@ -245,29 +260,32 @@ func TestMemoryList_PurgeLFU(t *testing.T) {
}
func TestMemoryList_CleanAll(t *testing.T) {
var list = NewMemoryList().(*MemoryList)
_ = list.Add("a", &Item{})
var list = caches.NewMemoryList().(*caches.MemoryList)
_ = list.Add("a", &caches.Item{})
_ = list.CleanAll()
logs.PrintAsJSON(list.itemMaps, t)
logs.PrintAsJSON(list.ItemMaps(), t)
t.Log(list.Count())
}
func TestMemoryList_GC(t *testing.T) {
list := NewMemoryList().(*MemoryList)
if !testutils.IsSingleTesting() {
return
}
list := caches.NewMemoryList().(*caches.MemoryList)
_ = list.Init()
for i := 0; i < 1_000_000; i++ {
key := "https://www.teaos.cn/hello" + strconv.Itoa(i/100000) + "/" + strconv.Itoa(i) + ".html"
_ = list.Add(fmt.Sprintf("%d", xxhash.Sum64String(key)), &Item{
_ = list.Add(fmt.Sprintf("%d", xxhash.Sum64String(key)), &caches.Item{
Key: key,
ExpiredAt: 0,
BodySize: 0,
HeaderSize: 0,
})
}
time.Sleep(10 * time.Second)
t.Log("clean...", len(list.itemMaps))
t.Log("clean...", len(list.ItemMaps()))
_ = list.CleanAll()
t.Log("cleanAll...", len(list.itemMaps))
t.Log("cleanAll...", len(list.ItemMaps()))
before := time.Now()
//runtime.GC()
t.Log("gc cost:", time.Since(before).Seconds()*1000, "ms")
@@ -280,3 +298,30 @@ func TestMemoryList_GC(t *testing.T) {
time.Sleep(30 * time.Minute)
}
}
func BenchmarkMemoryList(b *testing.B) {
var list = caches.NewMemoryList()
err := list.Init()
if err != nil {
b.Fatal(err)
}
for i := 0; i < 1_000_000; i++ {
_ = list.Add(stringutil.Md5(types.String(i)), &caches.Item{
Key: "a1",
ExpiredAt: time.Now().Unix() + 3600,
HeaderSize: 1024,
})
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, _ = list.Exist(types.String("a" + types.String(rands.Int(1, 10000))))
_ = list.Add("a"+types.String(rands.Int(1, 100000)), &caches.Item{})
_, _ = list.Purge(1000, func(hash string) error {
return nil
})
}
})
}

View File

@@ -15,6 +15,7 @@ import (
)
const (
enableFragmentPool = false
minMemoryFragmentPoolItemSize = 8 << 10
maxMemoryFragmentPoolItemSize = 128 << 20
maxItemsInMemoryFragmentPoolBucket = 1024

View File

@@ -111,6 +111,15 @@ func (this *FileStorage) Policy() *serverconfigs.HTTPCachePolicy {
// CanUpdatePolicy 检查策略是否可以更新
func (this *FileStorage) CanUpdatePolicy(newPolicy *serverconfigs.HTTPCachePolicy) bool {
if newPolicy == nil {
return false
}
// 检查类型
if newPolicy.Type != serverconfigs.CachePolicyStorageFile {
return false
}
// 检查路径是否有变化
oldOptionsJSON, err := json.Marshal(this.policy.Options)
if err != nil {
@@ -455,7 +464,7 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, hea
_, ok := sharedWritingFileKeyMap[key]
if ok {
sharedWritingFileKeyLocker.Unlock()
return nil, ErrFileIsWriting
return nil, fmt.Errorf("%w(001)", ErrFileIsWriting)
}
if !isFlushing && !fsutils.WriteReady() {
@@ -496,7 +505,7 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, hea
// 检查两次写入缓存的时间是否过于相近,分片内容不受此限制
if err == nil && !isPartial && time.Since(stat.ModTime()) <= 1*time.Second {
// 防止并发连续写入
return nil, ErrFileIsWriting
return nil, fmt.Errorf("%w(002)", ErrFileIsWriting)
}
// 构造文件名
@@ -597,7 +606,7 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, hea
err = syscall.Flock(int(writer.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
if err != nil {
removeOnFailure = false
return nil, ErrFileIsWriting
return nil, fmt.Errorf("%w (003)", ErrFileIsWriting)
}
var metaBodySize int64 = -1

View File

@@ -19,6 +19,10 @@ import (
)
func TestFileStorage_Init(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
@@ -49,6 +53,10 @@ func TestFileStorage_Init(t *testing.T) {
}
func TestFileStorage_OpenWriter(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
@@ -96,6 +104,10 @@ func TestFileStorage_OpenWriter(t *testing.T) {
}
func TestFileStorage_OpenWriter_Partial(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 2,
IsOn: true,
@@ -134,6 +146,10 @@ func TestFileStorage_OpenWriter_Partial(t *testing.T) {
}
func TestFileStorage_OpenWriter_HTTP(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
@@ -202,6 +218,10 @@ func TestFileStorage_OpenWriter_HTTP(t *testing.T) {
}
func TestFileStorage_Concurrent_Open_DifferentFile(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
@@ -231,7 +251,7 @@ func TestFileStorage_Concurrent_Open_DifferentFile(t *testing.T) {
writer, err := storage.OpenWriter("abc"+strconv.Itoa(i), time.Now().Unix()+3600, 200, -1, -1, -1, false)
if err != nil {
if err != ErrFileIsWriting {
if errors.Is(err, ErrFileIsWriting) {
t.Error(err)
return
}
@@ -260,6 +280,10 @@ func TestFileStorage_Concurrent_Open_DifferentFile(t *testing.T) {
}
func TestFileStorage_Concurrent_Open_SameFile(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
@@ -289,7 +313,7 @@ func TestFileStorage_Concurrent_Open_SameFile(t *testing.T) {
writer, err := storage.OpenWriter("abc"+strconv.Itoa(0), time.Now().Unix()+3600, 200, -1, -1, -1, false)
if err != nil {
if err != ErrFileIsWriting {
if errors.Is(err, ErrFileIsWriting) {
t.Error(err)
return
}
@@ -319,6 +343,10 @@ func TestFileStorage_Concurrent_Open_SameFile(t *testing.T) {
}
func TestFileStorage_Read(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
@@ -358,6 +386,10 @@ func TestFileStorage_Read(t *testing.T) {
}
func TestFileStorage_Read_HTTP_Response(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
@@ -414,6 +446,10 @@ func TestFileStorage_Read_HTTP_Response(t *testing.T) {
}
func TestFileStorage_Read_NotFound(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
@@ -450,6 +486,10 @@ func TestFileStorage_Read_NotFound(t *testing.T) {
}
func TestFileStorage_Delete(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
@@ -472,6 +512,10 @@ func TestFileStorage_Delete(t *testing.T) {
}
func TestFileStorage_Stat(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
@@ -500,6 +544,10 @@ func TestFileStorage_Stat(t *testing.T) {
}
func TestFileStorage_CleanAll(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
@@ -534,6 +582,10 @@ func TestFileStorage_CleanAll(t *testing.T) {
}
func TestFileStorage_Stop(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
@@ -552,6 +604,10 @@ func TestFileStorage_Stop(t *testing.T) {
}
func TestFileStorage_DecodeFile(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
@@ -571,6 +627,10 @@ func TestFileStorage_DecodeFile(t *testing.T) {
}
func TestFileStorage_RemoveCacheFile(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var storage = NewFileStorage(nil)
defer storage.Stop()

View File

@@ -1,6 +1,7 @@
package caches
import (
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
@@ -161,7 +162,7 @@ func (this *MemoryStorage) OpenWriter(key string, expiredAt int64, status int, h
// TODO 内存缓存暂时不支持分块内容存储
if isPartial {
return nil, ErrFileIsWriting
return nil, fmt.Errorf("%w (004)", ErrFileIsWriting)
}
return this.openWriter(key, expiredAt, status, headerSize, bodySize, maxSize, true)
}
@@ -187,7 +188,7 @@ func (this *MemoryStorage) openWriter(key string, expiresAt int64, status int, h
var isWriting = false
_, ok := this.writingKeyMap[key]
if ok {
return nil, ErrFileIsWriting
return nil, fmt.Errorf("%w (005)", ErrFileIsWriting)
}
this.writingKeyMap[key] = zero.New()
defer func() {
@@ -208,7 +209,7 @@ func (this *MemoryStorage) openWriter(key string, expiresAt int64, status int, h
_ = this.list.Remove(hashString)
item = nil
} else {
return nil, ErrFileIsWriting
return nil, fmt.Errorf("%w (006)", ErrFileIsWriting)
}
}
@@ -366,7 +367,7 @@ func (this *MemoryStorage) UpdatePolicy(newPolicy *serverconfigs.HTTPCachePolicy
// CanUpdatePolicy 检查策略是否可以更新
func (this *MemoryStorage) CanUpdatePolicy(newPolicy *serverconfigs.HTTPCachePolicy) bool {
return true
return newPolicy != nil && newPolicy.Type == serverconfigs.CachePolicyStorageMemory
}
// AddToList 将缓存添加到列表
@@ -517,7 +518,7 @@ func (this *MemoryStorage) flushItem(key string) {
_ = this.Delete(key)
// 重用内存,前提是确保内存不再被引用
if ok && item.IsDone && !item.isReferring && len(item.BodyValue) > 0 {
if enableFragmentPool && ok && item.IsDone && !item.isReferring && len(item.BodyValue) > 0 {
SharedFragmentMemoryPool.Put(item.BodyValue)
}
}()

View File

@@ -3,6 +3,7 @@ package caches
import (
"bytes"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/rands"
"runtime"
@@ -271,6 +272,10 @@ func TestMemoryStorage_Purge(t *testing.T) {
}
func TestMemoryStorage_Expire(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var storage = NewMemoryStorage(&serverconfigs.HTTPCachePolicy{
MemoryAutoPurgeInterval: 5,
}, nil)

View File

@@ -32,7 +32,9 @@ func NewMemoryWriter(memoryStorage *MemoryStorage, key string, expiredAt int64,
ModifiedAt: fasttime.Now().Unix(),
Status: status,
}
if expectedBodySize > 0 && expectedBodySize <= maxMemoryFragmentPoolItemSize {
if enableFragmentPool &&
expectedBodySize > 0 &&
expectedBodySize <= maxMemoryFragmentPoolItemSize {
bodyBytes, ok := SharedFragmentMemoryPool.Get(expectedBodySize) // try to reuse memory
if ok {
valueItem.BodyValue = bodyBytes
@@ -123,7 +125,6 @@ func (this *MemoryWriter) Close() error {
// 需要在Locker之外
defer this.once.Do(func() {
this.endFunc(this.item)
this.item = nil // free memory
})
if this.item == nil {
@@ -162,13 +163,13 @@ func (this *MemoryWriter) Discard() error {
// 需要在Locker之外
defer this.once.Do(func() {
this.endFunc(this.item)
this.item = nil // free memory
})
this.storage.locker.Lock()
delete(this.storage.valuesMap, this.hash)
if this.item != nil &&
if enableFragmentPool &&
this.item != nil &&
!this.item.isReferring &&
cap(this.item.BodyValue) >= minMemoryFragmentPoolItemSize {
SharedFragmentMemoryPool.Put(this.item.BodyValue)

View File

@@ -6,6 +6,7 @@ import (
"bytes"
"github.com/TeaOSLab/EdgeNode/internal/compressions"
"io"
"os"
"testing"
)
@@ -49,3 +50,45 @@ func TestBrotliReader(t *testing.T) {
}
}
}
func BenchmarkBrotliReader(b *testing.B) {
data, err := os.ReadFile("./reader_brotli.go")
if err != nil {
b.Fatal(err)
}
var buf = bytes.NewBuffer([]byte{})
writer, err := compressions.NewBrotliWriter(buf, 5)
if err != nil {
b.Fatal(err)
}
_, err = writer.Write(data)
err = writer.Close()
if err != nil {
b.Fatal(err)
}
var compressedData = buf.Bytes()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
reader, readerErr := compressions.NewBrotliReader(bytes.NewBuffer(compressedData))
if readerErr != nil {
b.Fatal(readerErr)
}
var readBuf = make([]byte, 1024)
for {
_, readErr := reader.Read(readBuf)
if readErr != nil {
if readErr != io.EOF {
b.Fatal(readErr)
}
break
}
}
closeErr := reader.Close()
if closeErr != nil {
b.Fatal(closeErr)
}
}
})
}

View File

@@ -6,6 +6,7 @@ import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"io"
"net/http"
)
type ContentEncoding = string
@@ -58,3 +59,32 @@ func NewWriter(writer io.Writer, compressType serverconfigs.HTTPCompressionType,
}
return nil, errors.New("invalid compression type '" + compressType + "'")
}
// SupportEncoding 检查是否支持某个编码
func SupportEncoding(encoding string) bool {
return encoding == ContentEncodingBr ||
encoding == ContentEncodingGzip ||
encoding == ContentEncodingDeflate ||
encoding == ContentEncodingZSTD
}
// WrapHTTPResponse 包装http.Response对象
func WrapHTTPResponse(resp *http.Response) {
if resp == nil {
return
}
var contentEncoding = resp.Header.Get("Content-Encoding")
if len(contentEncoding) == 0 || !SupportEncoding(contentEncoding) {
return
}
reader, err := NewReader(resp.Body, contentEncoding)
if err != nil {
// unable to decode, we ignore the error
return
}
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
resp.Body = reader
}

View File

@@ -19,6 +19,10 @@ func NewZSTDWriter(writer io.Writer, level int) (Writer, error) {
}
func newZSTDWriter(writer io.Writer, level int) (Writer, error) {
if level < 0 {
level = 0
}
var zstdLevel = zstd.EncoderLevelFromZstd(level)
zstdWriter, err := zstd.NewWriter(writer, zstd.WithEncoderLevel(zstdLevel))

View File

@@ -9,6 +9,24 @@ import (
"testing"
)
func TestNewZSTDWriter_Level0(t *testing.T) {
var buf = &bytes.Buffer{}
writer, err := compressions.NewZSTDWriter(buf, 0)
if err != nil {
t.Fatal(err)
}
var originData = []byte(strings.Repeat("Hello", 1024))
_, err = writer.Write(originData)
if err != nil {
t.Fatal(err)
}
err = writer.Close()
if err != nil {
t.Fatal(err)
}
t.Log("origin data:", len(originData), "result:", buf.Len())
}
func TestNewZSTDWriter(t *testing.T) {
var buf = &bytes.Buffer{}
writer, err := compressions.NewZSTDWriter(buf, 10)

View File

@@ -12,8 +12,8 @@ const oldConfigFileName = "api.yaml"
type APIConfig struct {
OldRPC struct {
Endpoints []string `yaml:"endpoints" json:"endpoints"`
DisableUpdate bool `yaml:"disableUpdate" json:"disableUpdate"`
Endpoints []string `yaml:"endpoints,omitempty" json:"endpoints"`
DisableUpdate bool `yaml:"disableUpdate,omitempty" json:"disableUpdate"`
} `yaml:"rpc,omitempty" json:"rpc"`
RPCEndpoints []string `yaml:"rpc.endpoints,flow" json:"rpc.endpoints"`

View File

@@ -4,11 +4,16 @@ package configs_test
import (
"github.com/TeaOSLab/EdgeNode/internal/configs"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"gopkg.in/yaml.v3"
"testing"
)
func TestLoadClusterConfig(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
config, err := configs.LoadClusterConfig()
if err != nil {
t.Fatal(err)

View File

@@ -1,7 +1,7 @@
package teaconst
const (
Version = "1.3.1"
Version = "1.3.3"
ProductName = "Edge Node"
ProcessName = "edge-node"

View File

@@ -3,11 +3,16 @@ package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"testing"
)
func TestHTTPAPIAction_AddItem(t *testing.T) {
action := NewHTTPAPIAction()
if !testutils.IsSingleTesting() {
return
}
var action = NewHTTPAPIAction()
action.config = &firewallconfigs.FirewallActionHTTPAPIConfig{
URL: "http://127.0.0.1:2345/post",
TimeoutSeconds: 0,
@@ -24,7 +29,11 @@ func TestHTTPAPIAction_AddItem(t *testing.T) {
}
func TestHTTPAPIAction_DeleteItem(t *testing.T) {
action := NewHTTPAPIAction()
if !testutils.IsSingleTesting() {
return
}
var action = NewHTTPAPIAction()
action.config = &firewallconfigs.FirewallActionHTTPAPIConfig{
URL: "http://127.0.0.1:2345/post",
TimeoutSeconds: 0,

View File

@@ -4,13 +4,19 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"github.com/iwind/TeaGo/maps"
"testing"
"time"
)
func TestIPSetAction_Init(t *testing.T) {
action := iplibrary.NewIPSetAction()
_, lookupErr := executils.LookPath("iptables")
if lookupErr != nil {
return
}
var action = iplibrary.NewIPSetAction()
err := action.Init(&firewallconfigs.FirewallActionConfig{
Params: maps.Map{
"path": "/usr/bin/iptables",
@@ -25,6 +31,11 @@ func TestIPSetAction_Init(t *testing.T) {
}
func TestIPSetAction_AddItem(t *testing.T) {
_, lookupErr := executils.LookPath("iptables")
if lookupErr != nil {
return
}
var action = iplibrary.NewIPSetAction()
action.SetConfig(&firewallconfigs.FirewallActionIPSetConfig{
Path: "/usr/bin/iptables",
@@ -84,7 +95,12 @@ func TestIPSetAction_AddItem(t *testing.T) {
}
func TestIPSetAction_DeleteItem(t *testing.T) {
action := iplibrary.NewIPSetAction()
_, lookupErr := executils.LookPath("firewalld")
if lookupErr != nil {
return
}
var action = iplibrary.NewIPSetAction()
err := action.Init(&firewallconfigs.FirewallActionConfig{
Params: maps.Map{
"path": "/usr/bin/firewalld",

View File

@@ -3,12 +3,18 @@ package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"testing"
"time"
)
func TestIPTablesAction_AddItem(t *testing.T) {
action := NewIPTablesAction()
_, lookupErr := executils.LookPath("iptables")
if lookupErr != nil {
return
}
var action = NewIPTablesAction()
action.config = &firewallconfigs.FirewallActionIPTablesConfig{
Path: "/usr/bin/iptables",
}
@@ -40,7 +46,12 @@ func TestIPTablesAction_AddItem(t *testing.T) {
}
func TestIPTablesAction_DeleteItem(t *testing.T) {
action := NewIPTablesAction()
_, lookupErr := executils.LookPath("firewalld")
if lookupErr != nil {
return
}
var action = NewIPTablesAction()
action.config = &firewallconfigs.FirewallActionIPTablesConfig{
Path: "/usr/bin/firewalld",
}

View File

@@ -7,7 +7,7 @@ import (
)
func TestActionManager_UpdateActions(t *testing.T) {
manager := NewActionManager()
var manager = NewActionManager()
manager.UpdateActions([]*firewallconfigs.FirewallActionConfig{
{
Id: 1,

View File

@@ -3,11 +3,16 @@ package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"testing"
"time"
)
func TestScriptAction_AddItem(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
action := NewScriptAction()
action.config = &firewallconfigs.FirewallActionScriptConfig{
Path: "/tmp/ip-item.sh",
@@ -27,6 +32,10 @@ func TestScriptAction_AddItem(t *testing.T) {
}
func TestScriptAction_DeleteItem(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
action := NewScriptAction()
action.config = &firewallconfigs.FirewallActionScriptConfig{
Path: "/tmp/ip-item.sh",

View File

@@ -2,6 +2,7 @@ package iplibrary
import (
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/iwind/TeaGo/assert"
"runtime"
"testing"
@@ -75,8 +76,14 @@ func TestIPItem_Contains(t *testing.T) {
}
func TestIPItem_Memory(t *testing.T) {
var isSingleTest = testutils.IsSingleTesting()
var list = NewIPList()
for i := 0; i < 2_000_000; i ++ {
var count = 100
if isSingleTest {
count = 2_000_000
}
for i := 0; i < count; i++ {
list.Add(&IPItem{
Type: "ip",
Id: uint64(i),
@@ -87,7 +94,9 @@ func TestIPItem_Memory(t *testing.T) {
})
}
t.Log("waiting")
time.Sleep(10 * time.Second)
if isSingleTest {
time.Sleep(10 * time.Second)
}
}
func BenchmarkIPItem_Contains(b *testing.B) {
@@ -105,4 +114,3 @@ func BenchmarkIPItem_Contains(b *testing.B) {
}
}
}

View File

@@ -16,6 +16,9 @@ func TestIPListDB_AddItem(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = db.AddItem(&pb.IPItem{
Id: 1,
@@ -60,6 +63,9 @@ func TestIPListDB_ReadItems(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
defer func() {
_ = db.Close()
@@ -77,6 +83,9 @@ func TestIPListDB_ReadMaxVersion(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
t.Log(db.ReadMaxVersion())
}
@@ -85,6 +94,10 @@ func TestIPListDB_UpdateMaxVersion(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = db.UpdateMaxVersion(1027)
if err != nil {
t.Fatal(err)

View File

@@ -3,12 +3,17 @@
package iplibrary
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"testing"
"time"
)
func TestIPIsAllowed(t *testing.T) {
manager := NewIPListManager()
if !testutils.IsSingleTesting() {
return
}
var manager = NewIPListManager()
manager.init()
var before = time.Now()

View File

@@ -2,13 +2,18 @@ package iplibrary
import (
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/iwind/TeaGo/logs"
"testing"
"time"
)
func TestIPListManager_init(t *testing.T) {
manager := NewIPListManager()
if !testutils.IsSingleTesting() {
return
}
var manager = NewIPListManager()
manager.init()
t.Log(manager.listMap)
t.Log(SharedServerListManager.blackMap)
@@ -16,7 +21,11 @@ func TestIPListManager_init(t *testing.T) {
}
func TestIPListManager_check(t *testing.T) {
manager := NewIPListManager()
if !testutils.IsSingleTesting() {
return
}
var manager = NewIPListManager()
manager.init()
var before = time.Now()
@@ -28,7 +37,11 @@ func TestIPListManager_check(t *testing.T) {
}
func TestIPListManager_loop(t *testing.T) {
manager := NewIPListManager()
if !testutils.IsSingleTesting() {
return
}
var manager = NewIPListManager()
manager.Start()
err := manager.loop()
if err != nil {

View File

@@ -5,6 +5,7 @@ package monitor
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/logs"
"google.golang.org/grpc/status"
@@ -12,6 +13,10 @@ import (
)
func TestValueQueue_RPC(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
t.Fatal(err)

View File

@@ -90,13 +90,13 @@ func (this *APIStream) loop() error {
break
}
message, err := nodeStream.Recv()
if err != nil {
message, streamErr := nodeStream.Recv()
if streamErr != nil {
if this.isQuiting {
remotelogs.Println("API_STREAM", "quit")
return nil
}
return err
return streamErr
}
// 处理消息

View File

@@ -1,8 +1,15 @@
package nodes
import "testing"
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"testing"
)
func TestAPIStream_Start(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
apiStream := NewAPIStream()
apiStream.Start()
}

View File

@@ -110,6 +110,10 @@ func TestHTTPAccessLogQueue_Push2(t *testing.T) {
}
func TestHTTPAccessLogQueue_Memory(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
testutils.StartMemoryStats(t)
debug.SetGCPercent(10)

View File

@@ -3,13 +3,14 @@ package nodes
import (
"context"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"runtime"
"testing"
"time"
)
func TestHTTPClientPool_Client(t *testing.T) {
pool := NewHTTPClientPool()
var pool = NewHTTPClientPool()
{
var origin = &serverconfigs.OriginConfig{
@@ -54,7 +55,10 @@ func TestHTTPClientPool_cleanClients(t *testing.T) {
for i := 0; i < 10; i++ {
t.Log("get", i)
_, _ = pool.Client(nil, origin, origin.Addr.PickAddress(), nil, false)
time.Sleep(1 * time.Second)
if testutils.IsSingleTesting() {
time.Sleep(1 * time.Second)
}
}
}

View File

@@ -85,6 +85,8 @@ type HTTPRequest struct {
isAttack bool // 是否是攻击请求
requestBodyData []byte // 读取的Body内容
isWebsocketResponse bool // 是否为Websocket响应非请求
// WAF相关
firewallPolicyId int64
firewallRuleGroupId int64
@@ -100,6 +102,8 @@ type HTTPRequest struct {
disableLog bool // 是否在当前请求中关闭Log
forceLog bool // 是否强制记录日志
isHijacked bool
// script相关操作
isDone bool
}
@@ -191,17 +195,35 @@ func (this *HTTPRequest) Do() {
}
// 套餐
if this.ReqServer.UserPlan != nil && !this.ReqServer.UserPlan.IsAvailable() {
this.doPlanExpires()
this.doEnd()
return
if this.ReqServer.UserPlan != nil {
if this.doPlanBefore() {
this.doEnd()
return
}
}
// 流量限制
if this.ReqServer.TrafficLimitStatus != nil && this.ReqServer.TrafficLimitStatus.IsValid() {
this.doTrafficLimit()
this.doEnd()
return
if this.doTrafficLimit(this.ReqServer.TrafficLimitStatus) {
this.doEnd()
return
}
}
// UAM
var uamIsCalled = false
if !this.isHealthCheck {
if this.web.UAM == nil && this.ReqServer.UAM != nil && this.ReqServer.UAM.IsOn {
this.web.UAM = this.ReqServer.UAM
}
if this.web.UAM != nil && this.web.UAM.IsOn && this.isUAMRequest() {
uamIsCalled = true
if this.doUAM() {
this.doEnd()
return
}
}
}
// WAF
@@ -213,16 +235,8 @@ func (this *HTTPRequest) Do() {
}
// UAM
if !this.isHealthCheck {
if this.web.UAM != nil {
if this.web.UAM.IsOn {
if this.doUAM() {
this.doEnd()
return
}
}
} else if this.ReqServer.UAM != nil && this.ReqServer.UAM.IsOn {
this.web.UAM = this.ReqServer.UAM
if !this.isHealthCheck && !uamIsCalled {
if this.web.UAM != nil && this.web.UAM.IsOn {
if this.doUAM() {
this.doEnd()
return
@@ -278,6 +292,16 @@ func (this *HTTPRequest) Do() {
if this.web.Compression != nil && this.web.Compression.IsOn && this.web.Compression.Level > 0 {
this.writer.SetCompression(this.web.Compression)
}
// HLS
if this.web.HLS != nil &&
this.web.HLS.Encrypting != nil &&
this.web.HLS.Encrypting.IsOn {
if this.processHLSBefore() {
this.doEnd()
return
}
}
}
// 开始调用
@@ -410,6 +434,8 @@ func (this *HTTPRequest) doEnd() {
var countAttacks int64 = 0
var attackBytes int64 = 0
var countWebsocketConnections int64 = 0
if this.isCached {
countCached = 1
cachedBytes = totalBytes
@@ -421,8 +447,11 @@ func (this *HTTPRequest) doEnd() {
attackBytes = totalBytes
}
}
if this.isWebsocketResponse {
countWebsocketConnections = 1
}
stats.SharedTrafficStatManager.Add(this.ReqServer.UserId, this.ReqServer.Id, this.ReqHost, totalBytes, cachedBytes, 1, countCached, countAttacks, attackBytes, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId())
stats.SharedTrafficStatManager.Add(this.ReqServer.UserId, this.ReqServer.Id, this.ReqHost, totalBytes, cachedBytes, 1, countCached, countAttacks, attackBytes, countWebsocketConnections, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId())
// 指标
if metrics.SharedManager.HasHTTPMetrics() {
@@ -624,6 +653,11 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
this.web.CC = web.CC
}
// HLS
if web.HLS != nil && (web.HLS.IsPrior || isTop) {
this.web.HLS = web.HLS
}
// 重写规则
if len(web.RewriteRefs) > 0 {
for index, ref := range web.RewriteRefs {
@@ -1423,11 +1457,25 @@ func (this *HTTPRequest) requestScheme() string {
// 请求的服务器地址中的端口
func (this *HTTPRequest) requestServerPort() int {
_, port, err := net.SplitHostPort(this.ServerAddr)
if err == nil {
return types.Int(port)
if len(this.ServerAddr) > 0 {
_, port, err := net.SplitHostPort(this.ServerAddr)
if err == nil && len(port) > 0 {
return types.Int(port)
}
}
return 0
var host = this.RawReq.Host
if len(host) > 0 {
_, port, err := net.SplitHostPort(host)
if err == nil && len(port) > 0 {
return types.Int(port)
}
}
if this.IsHTTP {
return 80
}
return 443
}
func (this *HTTPRequest) Id() string {

View File

@@ -43,7 +43,7 @@ func (this *HTTPRequest) doAuth() (shouldStop bool) {
if uriChanged {
this.uri = newURI
}
this.tags = append(this.tags, ref.AuthPolicy.Type)
this.tags = append(this.tags, "auth:"+ref.AuthPolicy.Type)
return
} else {
// Basic Auth比较特殊
@@ -64,7 +64,7 @@ func (this *HTTPRequest) doAuth() (shouldStop bool) {
}
}
this.writer.WriteHeader(http.StatusUnauthorized)
this.tags = append(this.tags, ref.AuthPolicy.Type)
this.tags = append(this.tags, "auth:"+ref.AuthPolicy.Type)
return true
}
}

View File

@@ -3,6 +3,7 @@ package nodes
import (
"bytes"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/compressions"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
@@ -130,7 +131,22 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
var tags = []string{}
// 检查是否有缓存
var key = this.Format(this.cacheRef.Key)
var key string
if this.web.Cache.Key != nil && this.web.Cache.Key.IsOn && len(this.web.Cache.Key.Host) > 0 {
key = configutils.ParseVariables(this.cacheRef.Key, func(varName string) (value string) {
switch varName {
case "scheme":
return this.web.Cache.Key.Scheme
case "host":
return this.web.Cache.Key.Host
default:
return this.Format("${" + varName + "}")
}
})
} else {
key = this.Format(this.cacheRef.Key)
}
if len(key) == 0 {
this.cacheRef = nil
cacheBypassDescription = "BYPASS, empty key"

View File

@@ -0,0 +1,16 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package nodes
import "net/http"
func (this *HTTPRequest) processHLSBefore() (blocked bool) {
// stub
return false
}
func (this *HTTPRequest) processM3u8Response(resp *http.Response) error {
// stub
return nil
}

View File

@@ -34,60 +34,63 @@ func (this *HTTPRequest) doMismatch() {
}
// 根据配置进行相应的处理
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
if globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly {
var statusCode = 404
var httpAllConfig = globalServerConfig.HTTPAll
var mismatchAction = httpAllConfig.DomainMismatchAction
var nodeConfig = sharedNodeConfig // copy
if nodeConfig != nil {
var globalServerConfig = nodeConfig.GlobalServerConfig
if globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly {
var statusCode = 404
var httpAllConfig = globalServerConfig.HTTPAll
var mismatchAction = httpAllConfig.DomainMismatchAction
if mismatchAction != nil && mismatchAction.Options != nil {
var mismatchStatusCode = mismatchAction.Options.GetInt("statusCode")
if mismatchStatusCode > 0 && mismatchStatusCode >= 100 && mismatchStatusCode < 1000 {
statusCode = mismatchStatusCode
}
}
// 是否正在访问IP
if globalServerConfig.HTTPAll.NodeIPShowPage && utils.IsWildIP(this.ReqHost) {
this.writer.statusCode = statusCode
var contentHTML = this.Format(globalServerConfig.HTTPAll.NodeIPPageHTML)
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
this.writer.Header().Set("Content-Length", types.String(len(contentHTML)))
this.writer.WriteHeader(statusCode)
_, _ = this.writer.WriteString(contentHTML)
return
}
// 检查cc
// TODO 可以在管理端配置是否开启以及最多尝试次数
// 要考虑到服务在切换集群时,域名未生效状态时,用户访问的仍然是老集群中的节点,就会产生找不到域名的情况
if len(remoteIP) > 0 {
const maxAttempts = 100
if ttlcache.SharedInt64Cache.IncreaseInt64("MISMATCH_DOMAIN:"+remoteIP, int64(1), time.Now().Unix()+60, false) > maxAttempts {
// 在加入之前再次检查黑名单
if !waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteIP) {
waf.SharedIPBlackList.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteIP, time.Now().Unix()+3600)
if mismatchAction != nil && mismatchAction.Options != nil {
var mismatchStatusCode = mismatchAction.Options.GetInt("statusCode")
if mismatchStatusCode > 0 && mismatchStatusCode >= 100 && mismatchStatusCode < 1000 {
statusCode = mismatchStatusCode
}
}
}
// 处理当前连接
if mismatchAction != nil && mismatchAction.Code == serverconfigs.DomainMismatchActionPage {
if mismatchAction.Options != nil {
// 是否正在访问IP
if globalServerConfig.HTTPAll.NodeIPShowPage && utils.IsWildIP(this.ReqHost) {
this.writer.statusCode = statusCode
var contentHTML = this.Format(mismatchAction.Options.GetString("contentHTML"))
var contentHTML = this.Format(globalServerConfig.HTTPAll.NodeIPPageHTML)
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
this.writer.Header().Set("Content-Length", types.String(len(contentHTML)))
this.writer.WriteHeader(statusCode)
_, _ = this.writer.Write([]byte(contentHTML))
_, _ = this.writer.WriteString(contentHTML)
return
}
// 检查cc
// TODO 可以在管理端配置是否开启以及最多尝试次数
// 要考虑到服务在切换集群时,域名未生效状态时,用户访问的仍然是老集群中的节点,就会产生找不到域名的情况
if len(remoteIP) > 0 {
const maxAttempts = 100
if ttlcache.SharedInt64Cache.IncreaseInt64("MISMATCH_DOMAIN:"+remoteIP, int64(1), time.Now().Unix()+60, false) > maxAttempts {
// 在加入之前再次检查黑名单
if !waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteIP) {
waf.SharedIPBlackList.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteIP, time.Now().Unix()+3600)
}
}
}
// 处理当前连接
if mismatchAction != nil && mismatchAction.Code == serverconfigs.DomainMismatchActionPage {
if mismatchAction.Options != nil {
this.writer.statusCode = statusCode
var contentHTML = this.Format(mismatchAction.Options.GetString("contentHTML"))
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
this.writer.Header().Set("Content-Length", types.String(len(contentHTML)))
this.writer.WriteHeader(statusCode)
_, _ = this.writer.Write([]byte(contentHTML))
} else {
http.Error(this.writer, "404 page not found: '"+this.URL()+"'", http.StatusNotFound)
}
return
} else {
http.Error(this.writer, "404 page not found: '"+this.URL()+"'", http.StatusNotFound)
this.Close()
return
}
return
} else {
http.Error(this.writer, "404 page not found: '"+this.URL()+"'", http.StatusNotFound)
this.Close()
return
}
}

View File

@@ -151,7 +151,7 @@ func (this *HTTPRequest) doPageLookup(pages []*serverconfigs.HTTPPageConfig, sta
}
return true
} else if page.BodyType == serverconfigs.HTTPPageBodyTypeRedirectURL {
var newURL = page.URL
var newURL = this.Format(page.URL)
if len(newURL) == 0 {
newURL = "/"
}

View File

@@ -0,0 +1,10 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !plus
package nodes
// 检查套餐
func (this *HTTPRequest) doPlanBefore() (blocked bool) {
// stub
return false
}

View File

@@ -1,19 +0,0 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"net/http"
)
// 套餐过期
func (this *HTTPRequest) doPlanExpires() {
this.tags = append(this.tags, "plan")
var statusCode = http.StatusNotFound
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
_, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultPlanExpireNoticePageBody))
}

View File

@@ -381,11 +381,20 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
return
}
if resp == nil {
this.write50x(requestErr, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
return
}
// fix Content-Type
if resp.Header["Content-Type"] == nil {
resp.Header["Content-Type"] = []string{}
}
// 40x && 50x
*failStatusCode = resp.StatusCode
if resp != nil &&
((resp.StatusCode >= 500 && resp.StatusCode < 510 && this.reverseProxy.Retry50X) ||
(resp.StatusCode >= 403 && resp.StatusCode <= 404 && this.reverseProxy.Retry40X)) &&
if ((resp.StatusCode >= 500 && resp.StatusCode < 510 && this.reverseProxy.Retry50X) ||
(resp.StatusCode >= 403 && resp.StatusCode <= 404 && this.reverseProxy.Retry40X)) &&
(originId > 0 || (lnNodeId > 0 && hasMultipleLnNodes)) &&
!isLastRetry {
if resp.Body != nil {
@@ -397,8 +406,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
}
// 尝试从缓存中恢复
if resp != nil &&
resp.StatusCode >= 500 && // support 50X only
if resp.StatusCode >= 500 && // support 50X only
resp.StatusCode < 510 &&
this.cacheCanTryStale &&
this.web.Cache.Stale != nil &&
@@ -434,21 +442,43 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
// Page optimization
if this.web.Optimization != nil && resp.Body != nil && this.cacheRef != nil /** must under cache **/ {
err := this.web.Optimization.FilterResponse(resp)
err := this.web.Optimization.FilterResponse(this.URL(), resp)
if err != nil {
this.write50x(err, http.StatusBadGateway, "Page Optimization: Fail to read content from origin", "内容优化:从源站读取内容失败", false)
this.write50x(err, http.StatusBadGateway, "Page Optimization: fail to read content from origin", "内容优化:从源站读取内容失败", false)
return
}
}
// HLS
if this.web.HLS != nil &&
this.web.HLS.Encrypting != nil &&
this.web.HLS.Encrypting.IsOn &&
resp.StatusCode == http.StatusOK {
m3u8Err := this.processM3u8Response(resp)
if m3u8Err != nil {
this.write50x(m3u8Err, http.StatusBadGateway, "m3u8 encrypt: fail to read content from origin", "m3u8加密从源站读取内容失败", false)
return
}
}
// 设置Charset
// TODO 这里应该可以设置文本类型的列表,以及是否强制覆盖所有文本类型的字符集
// TODO 这里应该可以设置文本类型的列表
if this.web.Charset != nil && this.web.Charset.IsOn && len(this.web.Charset.Charset) > 0 {
contentTypes, ok := resp.Header["Content-Type"]
if ok && len(contentTypes) > 0 {
var contentType = contentTypes[0]
if this.web.Charset.Force {
var semiIndex = strings.Index(contentType, ";")
if semiIndex > 0 {
contentType = contentType[:semiIndex]
}
}
if _, found := textMimeMap[contentType]; found {
resp.Header["Content-Type"][0] = contentType + "; charset=" + this.web.Charset.Charset
var newCharset = this.web.Charset.Charset
if this.web.Charset.IsUpper {
newCharset = strings.ToUpper(newCharset)
}
resp.Header["Content-Type"][0] = contentType + "; charset=" + newCharset
}
}
}

View File

@@ -3,6 +3,7 @@ package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/assert"
"net/http"
"runtime"
"testing"
)
@@ -10,26 +11,45 @@ import (
func TestHTTPRequest_RedirectToHTTPS(t *testing.T) {
var a = assert.NewAssertion(t)
{
req := &HTTPRequest{
rawReq, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
var req = &HTTPRequest{
RawReq: rawReq,
RawWriter: NewEmptyResponseWriter(nil),
ReqServer: &serverconfigs.ServerConfig{
IsOn: true,
Web: &serverconfigs.HTTPWebConfig{
IsOn: true,
RedirectToHttps: &serverconfigs.HTTPRedirectToHTTPSConfig{},
},
},
}
req.init()
req.Do()
a.IsBool(req.web.RedirectToHttps.IsOn == false)
}
{
req := &HTTPRequest{
rawReq, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
var req = &HTTPRequest{
RawReq: rawReq,
RawWriter: NewEmptyResponseWriter(nil),
ReqServer: &serverconfigs.ServerConfig{
IsOn: true,
Web: &serverconfigs.HTTPWebConfig{
IsOn: true,
RedirectToHttps: &serverconfigs.HTTPRedirectToHTTPSConfig{
IsOn: true,
},
},
},
}
req.init()
req.Do()
a.IsBool(req.web.RedirectToHttps.IsOn == true)
}

View File

@@ -7,7 +7,19 @@ import (
)
// 流量限制
func (this *HTTPRequest) doTrafficLimit() {
func (this *HTTPRequest) doTrafficLimit(status *serverconfigs.TrafficLimitStatus) (blocked bool) {
if status == nil {
return false
}
// 如果是网站单独设置的流量限制,则检查是否已关闭
var config = this.ReqServer.TrafficLimit
if (config == nil || !config.IsOn) && status.PlanId == 0 {
return false
}
// 如果是套餐设置的流量限制,即使套餐变更了(变更套餐或者变更套餐的限制),仍然会提示流量超限
this.tags = append(this.tags, "trafficLimit")
var statusCode = 509
@@ -17,10 +29,19 @@ func (this *HTTPRequest) doTrafficLimit() {
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
this.writer.WriteHeader(statusCode)
var config = this.ReqServer.TrafficLimit
// check plan traffic limit
if (config == nil || !config.IsOn) && this.ReqServer.PlanId() > 0 && this.nodeConfig != nil {
var planConfig = this.nodeConfig.FindPlan(this.ReqServer.PlanId())
if planConfig != nil && planConfig.TrafficLimit != nil && planConfig.TrafficLimit.IsOn {
config = planConfig.TrafficLimit
}
}
if config != nil && len(config.NoticePageBody) != 0 {
_, _ = this.writer.WriteString(this.Format(config.NoticePageBody))
} else {
_, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultTrafficLimitNoticePageBody))
}
return true
}

View File

@@ -4,7 +4,13 @@
package nodes
// UAM
func (this *HTTPRequest) doUAM() (block bool) {
func (this *HTTPRequest) isUAMRequest() bool {
// stub
return false
}
// UAM
func (this *HTTPRequest) doUAM() (block bool) {
// stub
return false
}

View File

@@ -67,8 +67,8 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
// 当前服务的独立设置
if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
blocked, breakChecking := this.checkWAFRequest(this.web.FirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, false)
if blocked {
blockedRequest, breakChecking := this.checkWAFRequest(this.web.FirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, false)
if blockedRequest {
return true
}
if breakChecking {
@@ -78,8 +78,8 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
// 公用的防火墙设置
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
blocked, breakChecking := this.checkWAFRequest(this.ReqServer.HTTPFirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, this.web.FirewallRef.IgnoreGlobalRules)
if blocked {
blockedRequest, breakChecking := this.checkWAFRequest(this.ReqServer.HTTPFirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, this.web.FirewallRef.IgnoreGlobalRules)
if blockedRequest {
return true
}
if breakChecking {
@@ -266,8 +266,11 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
return
}
goNext, hasRequestBody, ruleGroup, ruleSet, err := w.MatchRequest(this, this.writer, this.web.FirewallRef.DefaultCaptchaType)
if forceLog && logRequestBody && hasRequestBody && ruleSet != nil && ruleSet.HasAttackActions() {
result, err := w.MatchRequest(this, this.writer, this.web.FirewallRef.DefaultCaptchaType)
if result.IsAllowed && (len(result.AllowScope) == 0 || result.AllowScope == waf.AllowScopeGlobal) {
breakChecking = true
}
if forceLog && logRequestBody && result.HasRequestBody && result.Set != nil && result.Set.HasAttackActions() {
this.wafHasRequestBody = true
}
if err != nil {
@@ -277,28 +280,28 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
return
}
if ruleSet != nil {
if result.Set != nil {
if forceLog {
this.forceLog = true
}
if ruleSet.HasSpecialActions() {
if result.Set.HasSpecialActions() {
this.firewallPolicyId = firewallPolicy.Id
this.firewallRuleGroupId = types.Int64(ruleGroup.Id)
this.firewallRuleSetId = types.Int64(ruleSet.Id)
this.firewallRuleGroupId = types.Int64(result.Group.Id)
this.firewallRuleSetId = types.Int64(result.Set.Id)
if ruleSet.HasAttackActions() {
if result.Set.HasAttackActions() {
this.isAttack = true
}
// 添加统计
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, ruleSet.Actions)
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, result.Set.Actions)
}
this.firewallActions = append(ruleSet.ActionCodes(), firewallPolicy.Mode)
this.firewallActions = append(result.Set.ActionCodes(), firewallPolicy.Mode)
}
return !goNext, false
return !result.GoNext, breakChecking
}
// call response waf
@@ -316,23 +319,26 @@ func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) {
}
if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
blocked = this.checkWAFResponse(this.web.FirewallPolicy, resp, forceLog, forceLogRequestBody, false)
if blocked {
blockedRequest, breakChecking := this.checkWAFResponse(this.web.FirewallPolicy, resp, forceLog, forceLogRequestBody, false)
if blockedRequest {
return true
}
if breakChecking {
return
}
}
// 公用的防火墙设置
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
blocked = this.checkWAFResponse(this.ReqServer.HTTPFirewallPolicy, resp, forceLog, forceLogRequestBody, this.web.FirewallRef.IgnoreGlobalRules)
if blocked {
blockedRequest, _ := this.checkWAFResponse(this.ReqServer.HTTPFirewallPolicy, resp, forceLog, forceLogRequestBody, this.web.FirewallRef.IgnoreGlobalRules)
if blockedRequest {
return true
}
}
return
}
func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFirewallPolicy, resp *http.Response, forceLog bool, logRequestBody bool, ignoreRules bool) (blocked bool) {
func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFirewallPolicy, resp *http.Response, forceLog bool, logRequestBody bool, ignoreRules bool) (blocked bool, breakChecking bool) {
if firewallPolicy == nil || !firewallPolicy.IsOn || !firewallPolicy.Outbound.IsOn || firewallPolicy.Mode == firewallconfigs.FirewallModeBypass {
return
}
@@ -347,8 +353,11 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi
return
}
goNext, hasRequestBody, ruleGroup, ruleSet, err := w.MatchResponse(this, resp, this.writer)
if forceLog && logRequestBody && hasRequestBody && ruleSet != nil && ruleSet.HasAttackActions() {
result, err := w.MatchResponse(this, resp, this.writer)
if result.IsAllowed && (len(result.AllowScope) == 0 || result.AllowScope == waf.AllowScopeGlobal) {
breakChecking = true
}
if forceLog && logRequestBody && result.HasRequestBody && result.Set != nil && result.Set.HasAttackActions() {
this.wafHasRequestBody = true
}
if err != nil {
@@ -358,28 +367,28 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi
return
}
if ruleSet != nil {
if result.Set != nil {
if forceLog {
this.forceLog = true
}
if ruleSet.HasSpecialActions() {
if result.Set.HasSpecialActions() {
this.firewallPolicyId = firewallPolicy.Id
this.firewallRuleGroupId = types.Int64(ruleGroup.Id)
this.firewallRuleSetId = types.Int64(ruleSet.Id)
this.firewallRuleGroupId = types.Int64(result.Group.Id)
this.firewallRuleSetId = types.Int64(result.Set.Id)
if ruleSet.HasAttackActions() {
if result.Set.HasAttackActions() {
this.isAttack = true
}
// 添加统计
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, ruleSet.Actions)
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, result.Set.Actions)
}
this.firewallActions = append(ruleSet.ActionCodes(), firewallPolicy.Mode)
this.firewallActions = append(result.Set.ActionCodes(), firewallPolicy.Mode)
}
return !goNext
return !result.GoNext, breakChecking
}
// WAFRaw 原始请求

View File

@@ -61,6 +61,9 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
}
}
// 标记
this.isWebsocketResponse = true
// 设置指定的来源域
if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
var newRequestOrigin = this.web.Websocket.RequestOrigin
@@ -77,7 +80,6 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
}
// 连接源站
// TODO 增加N次错误重试重试的时候需要尝试不同的源站
originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost)
if err != nil {
if isLastRetry {

View File

@@ -11,7 +11,6 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/compressions"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
@@ -19,7 +18,6 @@ import (
setutils "github.com/TeaOSLab/EdgeNode/internal/utils/sets"
"github.com/TeaOSLab/EdgeNode/internal/utils/writers"
_ "github.com/biessek/golang-ico"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
"github.com/iwind/gowebp"
_ "golang.org/x/image/bmp"
@@ -34,22 +32,19 @@ import (
"net/textproto"
"os"
"path/filepath"
"runtime"
"strings"
"sync/atomic"
)
var webpMaxBufferSize int64 = 1_000_000_000
var webpTotalBufferSize int64 = 0
var webpIgnoreURLSet = setutils.NewFixedSet(131072)
var webPThreads int32
var webPMaxThreads int32 = 1
var webPIgnoreURLSet = setutils.NewFixedSet(131072)
func init() {
if !teaconst.IsMain {
return
}
var systemMemory = utils.SystemMemoryGB() / 8
if systemMemory > 0 {
webpMaxBufferSize = int64(systemMemory) << 30
webPMaxThreads = int32(runtime.NumCPU() / 4)
if webPMaxThreads < 1 {
webPMaxThreads = 1
}
}
@@ -80,6 +75,7 @@ type HTTPWriter struct {
// WebP
webpIsEncoding bool
webpOriginContentType string
webpQuality int
// Compression
compressionConfig *serverconfigs.HTTPCompressionConfig
@@ -343,7 +339,7 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
cacheWriter, err := storage.OpenWriter(cacheKey, expiresAt, this.StatusCode(), this.calculateHeaderLength(), totalSize, cacheRef.MaxSizeBytes(), this.isPartial)
if err != nil {
if err == caches.ErrEntityTooLarge && addStatusHeader {
if errors.Is(err, caches.ErrEntityTooLarge) && addStatusHeader {
this.Header().Set("X-Cache", "BYPASS, entity too large")
}
@@ -483,8 +479,8 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
contentTypeWritten = true
}
err := cacheWriter.WriteAt(start, data)
if err != nil {
writeErr := cacheWriter.WriteAt(start, data)
if writeErr != nil {
hasError = true
this.cacheIsFinished = false
}
@@ -497,7 +493,7 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
return
}
var cacheReader = readers.NewTeeReaderCloser(resp.Body, this.cacheWriter)
var cacheReader = readers.NewTeeReaderCloser(resp.Body, this.cacheWriter, false)
resp.Body = cacheReader
this.rawReader = cacheReader
@@ -531,6 +527,7 @@ func (this *HTTPWriter) PrepareWebP(resp *http.Response, size int64) {
if policy.RequireCache && this.req.cacheRef == nil {
return
}
this.webpQuality = policy.Quality
// 限制最小和最大尺寸
// TODO 需要将reader修改为LimitReader
@@ -550,7 +547,7 @@ func (this *HTTPWriter) PrepareWebP(resp *http.Response, size int64) {
this.req.web.WebP.MatchResponse(contentType, size, filepath.Ext(this.req.Path()), this.req.Format) &&
this.req.web.WebP.MatchAccept(this.req.requestHeader("Accept")) {
// 检查是否已经因为尺寸过大而忽略
if webpIgnoreURLSet.Has(this.req.URL()) {
if webPIgnoreURLSet.Has(this.req.URL()) {
return
}
@@ -560,24 +557,24 @@ func (this *HTTPWriter) PrepareWebP(resp *http.Response, size int64) {
return
}
// 检查内存
if atomic.LoadInt64(&webpTotalBufferSize) >= webpMaxBufferSize {
// 检查当前是否正在转换
if atomic.LoadInt32(&webPThreads) >= webPMaxThreads {
return
}
var contentEncoding = this.GetHeader("Content-Encoding")
switch contentEncoding {
case "gzip", "deflate", "br", "zstd":
reader, err := compressions.NewReader(resp.Body, contentEncoding)
if err != nil {
if len(contentEncoding) > 0 {
if compressions.SupportEncoding(contentEncoding) {
reader, err := compressions.NewReader(resp.Body, contentEncoding)
if err != nil {
return
}
this.Header().Del("Content-Encoding")
this.Header().Del("Content-Length")
this.rawReader = reader
} else {
return
}
this.Header().Del("Content-Encoding")
this.Header().Del("Content-Length")
this.rawReader = reader
case "": // 空
default:
return
}
this.webpOriginContentType = contentType
@@ -605,7 +602,7 @@ func (this *HTTPWriter) PrepareCompression(resp *http.Response, size int64) {
var contentEncoding = this.GetHeader("Content-Encoding")
if this.compressionConfig == nil || !this.compressionConfig.IsOn {
if lists.ContainsString([]string{"gzip", "deflate", "br", "zstd"}, contentEncoding) && !httpAcceptEncoding(acceptEncodings, contentEncoding) {
if compressions.SupportEncoding(contentEncoding) && !httpAcceptEncoding(acceptEncodings, contentEncoding) {
reader, err := compressions.NewReader(resp.Body, contentEncoding)
if err != nil {
return
@@ -622,12 +619,12 @@ func (this *HTTPWriter) PrepareCompression(resp *http.Response, size int64) {
return
}
if this.compressionConfig.Level <= 0 {
if this.compressionConfig.Level < 0 {
return
}
// 如果已经有编码则不处理
if len(contentEncoding) > 0 && (!this.compressionConfig.DecompressData || !lists.ContainsString([]string{"gzip", "deflate", "br", "zstd"}, contentEncoding)) {
if len(contentEncoding) > 0 && (!this.compressionConfig.DecompressData || !compressions.SupportEncoding(contentEncoding)) {
return
}
@@ -809,6 +806,8 @@ func (this *HTTPWriter) Write(data []byte) (n int, err error) {
}
n, err = this.writer.Write(data)
this.checkPlanBandwidth(n)
return
}
@@ -969,6 +968,7 @@ func (this *HTTPWriter) Close() {
func (this *HTTPWriter) Hijack() (conn net.Conn, buf *bufio.ReadWriter, err error) {
hijack, ok := this.rawWriter.(http.Hijacker)
if ok {
this.req.isHijacked = true
return hijack.Hijack()
}
return
@@ -1020,6 +1020,11 @@ func (this *HTTPWriter) calculateStaleLife() int {
func (this *HTTPWriter) finishWebP() {
// 处理WebP
if this.webpIsEncoding {
atomic.AddInt32(&webPThreads, 1)
defer func() {
atomic.AddInt32(&webPThreads, -1)
}()
var webpCacheWriter caches.Writer
// 准备WebP Cache
@@ -1080,7 +1085,7 @@ func (this *HTTPWriter) finishWebP() {
if isGif {
gifImage, err = gif.DecodeAll(reader)
if gifImage != nil && (gifImage.Config.Width > gowebp.WebPMaxDimension || gifImage.Config.Height > gowebp.WebPMaxDimension) {
webpIgnoreURLSet.Push(this.req.URL())
webPIgnoreURLSet.Push(this.req.URL())
return
}
} else {
@@ -1088,7 +1093,7 @@ func (this *HTTPWriter) finishWebP() {
if imageData != nil {
var bound = imageData.Bounds()
if bound.Max.X > gowebp.WebPMaxDimension || bound.Max.Y > gowebp.WebPMaxDimension {
webpIgnoreURLSet.Push(this.req.URL())
webPIgnoreURLSet.Push(this.req.URL())
return
}
}
@@ -1096,19 +1101,21 @@ func (this *HTTPWriter) finishWebP() {
if err != nil {
// 发生了错误终止处理
webpIgnoreURLSet.Push(this.req.URL())
webPIgnoreURLSet.Push(this.req.URL())
return
}
var totalBytes = reader.TotalBytes()
atomic.AddInt64(&webpTotalBufferSize, totalBytes)
defer func() {
atomic.AddInt64(&webpTotalBufferSize, -totalBytes)
}()
var f = types.Float32(this.req.web.WebP.Quality)
if f > 100 {
f = 100
var f = types.Float32(this.webpQuality)
if f <= 0 || f > 100 {
if this.size > (8<<20) || this.size <= 0 {
f = 30
} else if this.size > (1 << 20) {
f = 50
} else if this.size > (128 << 10) {
f = 60
} else {
f = 75
}
}
if imageData != nil {

View File

@@ -11,3 +11,7 @@ import (
func (this *HTTPWriter) canSendfile() (*os.File, bool) {
return nil, false
}
func (this *HTTPWriter) checkPlanBandwidth(n int) {
// stub
}

View File

@@ -3,6 +3,7 @@ package nodes
import (
"context"
"crypto/tls"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/Tea"
"io"
@@ -52,6 +53,8 @@ func (this *HTTPListener) Serve() error {
atomic.AddInt64(&this.countActiveConnections, 1)
case http.StateClosed:
atomic.AddInt64(&this.countActiveConnections, -1)
default:
// do nothing
}
},
ConnContext: func(ctx context.Context, conn net.Conn) context.Context {
@@ -74,7 +77,7 @@ func (this *HTTPListener) Serve() error {
// HTTP协议
if this.isHTTP {
err := this.httpServer.Serve(this.Listener)
if err != nil && err != http.ErrServerClosed {
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
}
@@ -84,7 +87,7 @@ func (this *HTTPListener) Serve() error {
this.httpServer.TLSConfig = this.buildTLSConfig()
err := this.httpServer.ServeTLS(this.Listener, "", "")
if err != nil && err != http.ErrServerClosed {
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
}
@@ -105,8 +108,13 @@ func (this *HTTPListener) Reload(group *serverconfigs.ServerAddressGroup) {
this.Reset()
}
// ServerHTTP 处理HTTP请求
// ServeHTTPWithAddr 处理HTTP请求
func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.Request) {
this.ServeHTTPWithAddr(rawWriter, rawReq, this.addr)
}
// ServeHTTPWithAddr 处理HTTP请求并指定服务地址
func (this *HTTPListener) ServeHTTPWithAddr(rawWriter http.ResponseWriter, rawReq *http.Request, serverAddr string) {
if len(rawReq.Host) > 253 {
http.Error(rawWriter, "Host too long.", http.StatusBadRequest)
time.Sleep(1 * time.Second) // make connection slow down
@@ -175,10 +183,12 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
}
// 绑定连接
var clientConn ClientConnInterface
if server != nil && server.Id > 0 {
var requestConn = rawReq.Context().Value(HTTPConnContextKey)
if requestConn != nil {
clientConn, ok := requestConn.(ClientConnInterface)
var ok bool
clientConn, ok = requestConn.(ClientConnInterface)
if ok {
var goNext = clientConn.SetServerId(server.Id)
if !goNext {
@@ -211,7 +221,7 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
ReqServer: server,
ReqHost: reqHost,
ServerName: serverName,
ServerAddr: this.addr,
ServerAddr: serverAddr,
IsHTTP: this.isHTTP,
IsHTTPS: this.isHTTPS,
IsHTTP3: this.isHTTP3,
@@ -219,6 +229,14 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
nodeConfig: sharedNodeConfig,
}
req.Do()
// fix hijacked connection state
if req.isHijacked && clientConn != nil && this.httpServer.ConnState != nil {
netConn, ok := clientConn.(net.Conn)
if ok {
this.httpServer.ConnState(netConn, http.StateClosed)
}
}
}
// 检查host是否为IP

View File

@@ -47,9 +47,13 @@ func (this *TCPListener) Serve() error {
atomic.AddInt64(&this.countActiveConnections, 1)
go func(conn net.Conn) {
err = this.handleConn(conn)
var server = this.Group.FirstServer()
if server == nil {
return
}
err = this.handleConn(server, conn)
if err != nil {
remotelogs.Error("TCP_LISTENER", err.Error())
remotelogs.ServerError(server.Id, "TCP_LISTENER", err.Error(), "", nil)
}
atomic.AddInt64(&this.countActiveConnections, -1)
}(conn)
@@ -63,8 +67,7 @@ func (this *TCPListener) Reload(group *serverconfigs.ServerAddressGroup) {
this.Reset()
}
func (this *TCPListener) handleConn(conn net.Conn) error {
var server = this.Group.FirstServer()
func (this *TCPListener) handleConn(server *serverconfigs.ServerConfig, conn net.Conn) error {
if server == nil {
return errors.New("no server available")
}
@@ -132,14 +135,14 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
serverName = tlsConn.ConnectionState().ServerName
if len(serverName) > 0 {
// 统计
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, serverName, 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, serverName, 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
recordStat = true
}
}
// 统计
if !recordStat {
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
}
originConn, err := this.connectOrigin(server.Id, serverName, server.ReverseProxy, conn.RemoteAddr().String())
@@ -194,7 +197,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
// 记录流量
if server != nil {
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
}
}
if err != nil {

View File

@@ -370,7 +370,7 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyListener
// 统计
if server != nil {
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
}
// 处理ControlMessage
@@ -401,7 +401,7 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyListener
// 记录流量和带宽
if server != nil {
// 流量
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
// 带宽
var userPlanId int64

View File

@@ -777,9 +777,23 @@ func (this *Node) listenSock() error {
_ = cmd.ReplyOk()
}
case "gc":
var before = time.Now()
runtime.GC()
debug.FreeOSMemory()
_ = cmd.ReplyOk()
var costSeconds = time.Since(before).Seconds()
var gcStats = &debug.GCStats{}
debug.ReadGCStats(gcStats)
var pauseMS float64
if len(gcStats.Pause) > 0 {
pauseMS = gcStats.Pause[0].Seconds() * 1000
}
_ = cmd.Reply(&gosock.Command{
Params: map[string]any{
"pauseMS": pauseMS,
"costMS": costSeconds * 1000,
},
})
case "reload":
err := this.syncConfig(0)
if err != nil {

View File

@@ -100,6 +100,10 @@ func (this *Node) execTask(rpcClient *rpc.RPCClient, task *pb.NodeTask) error {
err = this.execTOAChangedTask()
case "networkSecurityPolicyChanged":
err = this.execNetworkSecurityPolicyChangedTask(rpcClient)
case "webPPolicyChanged":
err = this.execWebPPolicyChangedTask(rpcClient)
case "planChanged":
err = this.execPlanChangedTask(rpcClient)
default:
// 特殊任务
if strings.HasPrefix(task.Type, "ipListDeleted") { // 删除IP名单
@@ -325,6 +329,34 @@ func (this *Node) execDeleteIPList(taskType string) error {
return nil
}
// WebP策略变更
func (this *Node) execWebPPolicyChangedTask(rpcClient *rpc.RPCClient) error {
remotelogs.Println("NODE", "updating webp policies ...")
resp, err := rpcClient.NodeRPC.FindNodeWebPPolicies(rpcClient.Context(), &pb.FindNodeWebPPoliciesRequest{})
if err != nil {
return err
}
var webPPolicyMap = map[int64]*nodeconfigs.WebPImagePolicy{}
for _, policy := range resp.WebPPolicies {
if len(policy.WebPPolicyJSON) > 0 {
var webPPolicy = nodeconfigs.NewWebPImagePolicy()
err = json.Unmarshal(policy.WebPPolicyJSON, webPPolicy)
if err != nil {
remotelogs.Error("NODE", "decode webp policy failed: "+err.Error())
continue
}
err = webPPolicy.Init()
if err != nil {
remotelogs.Error("NODE", "initialize webp policy failed: "+err.Error())
continue
}
webPPolicyMap[policy.NodeClusterId] = webPPolicy
}
}
sharedNodeConfig.UpdateWebPImagePolicies(webPPolicyMap)
return nil
}
// 标记任务完成
func (this *Node) finishTask(taskId int64, taskVersion int64, taskErr error) (success bool) {
if taskId <= 0 {

View File

@@ -34,3 +34,7 @@ func (this *Node) execNetworkSecurityPolicyChangedTask(rpcClient *rpc.RPCClient)
// stub
return nil
}
func (this *Node) execPlanChangedTask(rpcClient *rpc.RPCClient) error {
return nil
}

View File

@@ -1,17 +1,26 @@
package nodes
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
_ "github.com/iwind/TeaGo/bootstrap"
"testing"
)
func TestNode_Start(t *testing.T) {
node := NewNode()
if !testutils.IsSingleTesting() {
return
}
var node = NewNode()
node.Start()
}
func TestNode_Test(t *testing.T) {
node := NewNode()
if !testutils.IsSingleTesting() {
return
}
var node = NewNode()
err := node.Test()
if err != nil {
t.Fatal(err)

View File

@@ -3,11 +3,16 @@
package nodes
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
_ "github.com/iwind/TeaGo/bootstrap"
"testing"
)
func TestUpgradeManager_install(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
err := NewUpgradeManager().install()
if err != nil {
t.Fatal(err)

View File

@@ -4,7 +4,7 @@ package re
type RuneMap map[rune]*RuneTree
func (this *RuneMap) Lookup(s string, caseInsensitive bool) bool {
func (this RuneMap) Lookup(s string, caseInsensitive bool) bool {
return this.lookup([]rune(s), caseInsensitive, 0)
}

View File

@@ -55,6 +55,7 @@ type RPCClient struct {
ClientAgentIPRPC pb.ClientAgentIPServiceClient
AuthorityKeyRPC pb.AuthorityKeyServiceClient
UpdatingServerListRPC pb.UpdatingServerListServiceClient
PlanRPC pb.PlanServiceClient
}
func NewRPCClient(apiConfig *configs.APIConfig) (*RPCClient, error) {
@@ -91,6 +92,7 @@ func NewRPCClient(apiConfig *configs.APIConfig) (*RPCClient, error) {
client.ClientAgentIPRPC = pb.NewClientAgentIPServiceClient(client)
client.AuthorityKeyRPC = pb.NewAuthorityKeyServiceClient(client)
client.UpdatingServerListRPC = pb.NewUpdatingServerListServiceClient(client)
client.PlanRPC = pb.NewPlanServiceClient(client)
err := client.init()
if err != nil {

View File

@@ -5,6 +5,7 @@ package rpc_test
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
_ "github.com/iwind/TeaGo/bootstrap"
timeutil "github.com/iwind/TeaGo/utils/time"
"sync"
@@ -13,6 +14,10 @@ import (
)
func TestRPCConcurrentCall(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
t.Fatal(err)
@@ -43,6 +48,10 @@ func TestRPCConcurrentCall(t *testing.T) {
}
func TestRPC_Retry(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
t.Fatal(err)

View File

@@ -57,12 +57,13 @@ type BandwidthStat struct {
MaxBytes int64 `json:"maxBytes"`
TotalBytes int64 `json:"totalBytes"`
CachedBytes int64 `json:"cachedBytes"`
AttackBytes int64 `json:"attackBytes"`
CountRequests int64 `json:"countRequests"`
CountCachedRequests int64 `json:"countCachedRequests"`
CountAttackRequests int64 `json:"countAttackRequests"`
UserPlanId int64 `json:"userPlanId"`
CachedBytes int64 `json:"cachedBytes"`
AttackBytes int64 `json:"attackBytes"`
CountRequests int64 `json:"countRequests"`
CountCachedRequests int64 `json:"countCachedRequests"`
CountAttackRequests int64 `json:"countAttackRequests"`
CountWebsocketConnections int64 `json:"countWebsocketConnections"`
UserPlanId int64 `json:"userPlanId"`
}
// BandwidthStatManager 服务带宽统计
@@ -142,20 +143,21 @@ func (this *BandwidthStatManager) Loop() error {
}
pbStats = append(pbStats, &pb.ServerBandwidthStat{
Id: 0,
UserId: stat.UserId,
ServerId: stat.ServerId,
Day: stat.Day,
TimeAt: stat.TimeAt,
Bytes: stat.MaxBytes / bandwidthTimestampDelim,
TotalBytes: stat.TotalBytes,
CachedBytes: stat.CachedBytes,
AttackBytes: stat.AttackBytes,
CountRequests: stat.CountRequests,
CountCachedRequests: stat.CountCachedRequests,
CountAttackRequests: stat.CountAttackRequests,
UserPlanId: stat.UserPlanId,
NodeRegionId: regionId,
Id: 0,
UserId: stat.UserId,
ServerId: stat.ServerId,
Day: stat.Day,
TimeAt: stat.TimeAt,
Bytes: stat.MaxBytes / bandwidthTimestampDelim,
TotalBytes: stat.TotalBytes,
CachedBytes: stat.CachedBytes,
AttackBytes: stat.AttackBytes,
CountRequests: stat.CountRequests,
CountCachedRequests: stat.CountCachedRequests,
CountAttackRequests: stat.CountAttackRequests,
CountWebsocketConnections: stat.CountWebsocketConnections,
UserPlanId: stat.UserPlanId,
NodeRegionId: regionId,
})
delete(this.m, key)
}
@@ -231,7 +233,7 @@ func (this *BandwidthStatManager) AddBandwidth(userId int64, userPlanId int64, s
}
// AddTraffic 添加请求数据
func (this *BandwidthStatManager) AddTraffic(serverId int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64) {
func (this *BandwidthStatManager) AddTraffic(serverId int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64, countWebsocketConnections int64) {
var now = fasttime.Now()
var day = now.Ymd()
var timeAt = now.Round5Hi()
@@ -245,6 +247,7 @@ func (this *BandwidthStatManager) AddTraffic(serverId int64, cachedBytes int64,
stat.CountCachedRequests += countCachedRequests
stat.CountAttackRequests += countAttacks
stat.AttackBytes += attackBytes
stat.CountWebsocketConnections += countWebsocketConnections
}
this.locker.Unlock()
}

View File

@@ -53,19 +53,20 @@ func BenchmarkBandwidthStatManager_Slice(b *testing.B) {
for j := 0; j < 100; j++ {
var stat = &stats.BandwidthStat{}
pbStats = append(pbStats, &pb.ServerBandwidthStat{
Id: 0,
UserId: stat.UserId,
ServerId: stat.ServerId,
Day: stat.Day,
TimeAt: stat.TimeAt,
Bytes: stat.MaxBytes / 2,
TotalBytes: stat.TotalBytes,
CachedBytes: stat.CachedBytes,
AttackBytes: stat.AttackBytes,
CountRequests: stat.CountRequests,
CountCachedRequests: stat.CountCachedRequests,
CountAttackRequests: stat.CountAttackRequests,
NodeRegionId: 1,
Id: 0,
UserId: stat.UserId,
ServerId: stat.ServerId,
Day: stat.Day,
TimeAt: stat.TimeAt,
Bytes: stat.MaxBytes / 2,
TotalBytes: stat.TotalBytes,
CachedBytes: stat.CachedBytes,
AttackBytes: stat.AttackBytes,
CountRequests: stat.CountRequests,
CountCachedRequests: stat.CountCachedRequests,
CountAttackRequests: stat.CountAttackRequests,
CountWebsocketConnections: stat.CountWebsocketConnections,
NodeRegionId: 1,
})
}
_ = pbStats

View File

@@ -106,13 +106,13 @@ func (this *TrafficStatManager) Start() {
}
// Add 添加流量
func (this *TrafficStatManager) Add(userId int64, serverId int64, domain string, bytes int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64, checkingTrafficLimit bool, planId int64) {
func (this *TrafficStatManager) Add(userId int64, serverId int64, domain string, bytes int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64, countWebsocketConnections int64, checkingTrafficLimit bool, planId int64) {
if serverId == 0 {
return
}
// 添加到带宽
SharedBandwidthStatManager.AddTraffic(serverId, cachedBytes, countRequests, countCachedRequests, countAttacks, attackBytes)
SharedBandwidthStatManager.AddTraffic(serverId, cachedBytes, countRequests, countCachedRequests, countAttacks, attackBytes, countWebsocketConnections)
if bytes == 0 && countRequests == 0 {
return

View File

@@ -11,7 +11,7 @@ import (
func TestTrafficStatManager_Add(t *testing.T) {
manager := NewTrafficStatManager()
for i := 0; i < 100; i++ {
manager.Add(1, 1, "goedge.cn", 1, 0, 0, 0, 0, 0, false, 0)
manager.Add(1, 1, "goedge.cn", 1, 0, 0, 0, 0, 0, 0, false, 0)
}
t.Log(manager.itemMap)
}
@@ -19,7 +19,7 @@ func TestTrafficStatManager_Add(t *testing.T) {
func TestTrafficStatManager_Upload(t *testing.T) {
manager := NewTrafficStatManager()
for i := 0; i < 100; i++ {
manager.Add(1, 1, "goedge.cn"+types.String(rands.Int(0, 10)), 1, 0, 1, 0, 0, 0, false, 0)
manager.Add(1, 1, "goedge.cn"+types.String(rands.Int(0, 10)), 1, 0, 1, 0, 0, 0, 0, false, 0)
}
err := manager.Upload()
if err != nil {
@@ -36,7 +36,7 @@ func BenchmarkTrafficStatManager_Add(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
manager.Add(1, 1, "goedge.cn"+types.String(rand.Int63()%10), 1024, 1, 0, 0, 0, 0, false, 0)
manager.Add(1, 1, "goedge.cn"+types.String(rand.Int63()%10), 1024, 1, 0, 0, 0, 0, 0, false, 0)
}
})
}

View File

@@ -3,38 +3,51 @@
package stats
import (
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/fnv"
syncutils "github.com/TeaOSLab/EdgeNode/internal/utils/sync"
"github.com/mssola/useragent"
"sync"
"time"
)
var SharedUserAgentParser = NewUserAgentParser()
const userAgentShardingCount = 8
// UserAgentParser UserAgent解析器
type UserAgentParser struct {
parser *useragent.UserAgent
cacheMaps [userAgentShardingCount]map[uint64]UserAgentParserResult
pool *sync.Pool
mu *syncutils.RWMutex
cacheMap1 map[uint64]UserAgentParserResult
cacheMap2 map[uint64]UserAgentParserResult
maxCacheItems int
cacheCursor int
locker sync.RWMutex
gcTicker *time.Ticker
gcIndex int
}
// NewUserAgentParser 获取新解析器
func NewUserAgentParser() *UserAgentParser {
var parser = &UserAgentParser{
parser: &useragent.UserAgent{},
cacheMap1: map[uint64]UserAgentParserResult{},
cacheMap2: map[uint64]UserAgentParserResult{},
cacheCursor: 0,
pool: &sync.Pool{
New: func() any {
return &useragent.UserAgent{}
},
},
cacheMaps: [userAgentShardingCount]map[uint64]UserAgentParserResult{},
mu: syncutils.NewRWMutex(userAgentShardingCount),
}
for i := 0; i < userAgentShardingCount; i++ {
parser.cacheMaps[i] = map[uint64]UserAgentParserResult{}
}
parser.init()
return parser
}
// 初始化
func (this *UserAgentParser) init() {
var maxCacheItems = 10_000
var systemMemory = utils.SystemMemoryGB()
@@ -46,8 +59,16 @@ func (this *UserAgentParser) init() {
maxCacheItems = 20_000
}
this.maxCacheItems = maxCacheItems
this.gcTicker = time.NewTicker(5 * time.Second)
goman.New(func() {
for range this.gcTicker.C {
this.GC()
}
})
}
// Parse 解析UserAgent
func (this *UserAgentParser) Parse(userAgent string) (result UserAgentParserResult) {
// 限制长度
if len(userAgent) == 0 || len(userAgent) > 256 {
@@ -55,28 +76,22 @@ func (this *UserAgentParser) Parse(userAgent string) (result UserAgentParserResu
}
var userAgentKey = fnv.HashString(userAgent)
var shardingIndex = int(userAgentKey % userAgentShardingCount)
this.locker.RLock()
cacheResult, ok := this.cacheMap1[userAgentKey]
this.mu.RLock(shardingIndex)
cacheResult, ok := this.cacheMaps[shardingIndex][userAgentKey]
if ok {
this.locker.RUnlock()
this.mu.RUnlock(shardingIndex)
return cacheResult
}
this.mu.RUnlock(shardingIndex)
cacheResult, ok = this.cacheMap2[userAgentKey]
if ok {
this.locker.RUnlock()
return cacheResult
}
this.locker.RUnlock()
this.locker.Lock()
defer this.locker.Unlock()
this.parser.Parse(userAgent)
result.OS = this.parser.OSInfo()
result.BrowserName, result.BrowserVersion = this.parser.Browser()
result.IsMobile = this.parser.Mobile()
var parser = this.pool.Get().(*useragent.UserAgent)
parser.Parse(userAgent)
result.OS = parser.OSInfo()
result.BrowserName, result.BrowserVersion = parser.Browser()
result.IsMobile = parser.Mobile()
this.pool.Put(parser)
// 忽略特殊字符
if len(result.BrowserName) > 0 {
@@ -87,19 +102,45 @@ func (this *UserAgentParser) Parse(userAgent string) (result UserAgentParserResu
}
}
if this.cacheCursor == 0 {
this.cacheMap1[userAgentKey] = result
if len(this.cacheMap1) >= this.maxCacheItems {
this.cacheCursor = 1
this.cacheMap2 = map[uint64]UserAgentParserResult{}
}
} else {
this.cacheMap2[userAgentKey] = result
if len(this.cacheMap2) >= this.maxCacheItems {
this.cacheCursor = 0
this.cacheMap1 = map[uint64]UserAgentParserResult{}
}
}
this.mu.Lock(shardingIndex)
this.cacheMaps[shardingIndex][userAgentKey] = result
this.mu.Unlock(shardingIndex)
return
}
// MaxCacheItems 读取能容纳的缓存最大数量
func (this *UserAgentParser) MaxCacheItems() int {
return this.maxCacheItems
}
// Len 读取当前缓存数量
func (this *UserAgentParser) Len() int {
var total = 0
for i := 0; i < userAgentShardingCount; i++ {
this.mu.RLock(i)
total += len(this.cacheMaps[i])
this.mu.RUnlock(i)
}
return total
}
// GC 回收多余的缓存
func (this *UserAgentParser) GC() {
var total = this.Len()
if total > this.maxCacheItems {
for {
var shardingIndex = this.gcIndex
this.mu.Lock(shardingIndex)
total -= len(this.cacheMaps[shardingIndex])
this.cacheMaps[shardingIndex] = map[uint64]UserAgentParserResult{}
this.gcIndex = (this.gcIndex + 1) % userAgentShardingCount
this.mu.Unlock(shardingIndex)
if total <= this.maxCacheItems {
break
}
}
}
}

View File

@@ -1,17 +1,21 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package stats
package stats_test
import (
"github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
"runtime"
"runtime/debug"
"testing"
"time"
)
func TestUserAgentParser_Parse(t *testing.T) {
var parser = NewUserAgentParser()
var parser = stats.NewUserAgentParser()
for i := 0; i < 4; i++ {
t.Log(parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/1"))
t.Log(parser.Parse("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36"))
@@ -19,7 +23,7 @@ func TestUserAgentParser_Parse(t *testing.T) {
}
func TestUserAgentParser_Parse_Unknown(t *testing.T) {
var parser = NewUserAgentParser()
var parser = stats.NewUserAgentParser()
t.Log(parser.Parse("Mozilla/5.0 (Wind 10.0; WOW64; rv:49.0) Apple/537.36 (KHTML, like Gecko) Chr/88.0.4324.96 Sa/537.36 Test/1"))
t.Log(parser.Parse(""))
}
@@ -28,10 +32,10 @@ func TestUserAgentParser_Memory(t *testing.T) {
var stat1 = &runtime.MemStats{}
runtime.ReadMemStats(stat1)
var parser = NewUserAgentParser()
var parser = stats.NewUserAgentParser()
for i := 0; i < 1_000_000; i++ {
parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/" + types.String(rands.Int(0, 100_000)))
parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/" + types.String(rands.Int(0, 1_000_000)))
}
runtime.GC()
@@ -40,32 +44,76 @@ func TestUserAgentParser_Memory(t *testing.T) {
var stat2 = &runtime.MemStats{}
runtime.ReadMemStats(stat2)
t.Log("max cache items:", parser.maxCacheItems)
t.Log("cache1:", len(parser.cacheMap1), "cache2:", len(parser.cacheMap2), "cache3:", (stat2.HeapInuse-stat1.HeapInuse)/1024/1024, "MB")
t.Log("max cache items:", parser.MaxCacheItems())
t.Log("cache:", parser.Len(), "usage:", (stat2.HeapInuse-stat1.HeapInuse)>>20, "MB")
}
func BenchmarkUserAgentParser_Parse(b *testing.B) {
var parser = NewUserAgentParser()
for i := 0; i < b.N; i++ {
parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/" + types.String(rands.Int(0, 1_000_000)))
func TestNewUserAgentParser_GC(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
b.Log(len(parser.cacheMap1), len(parser.cacheMap2))
}
func BenchmarkUserAgentParser_Parse2(b *testing.B) {
var parser = NewUserAgentParser()
for i := 0; i < b.N; i++ {
var parser = stats.NewUserAgentParser()
for i := 0; i < 1_000_000; i++ {
parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/" + types.String(rands.Int(0, 100_000)))
}
b.Log(len(parser.cacheMap1), len(parser.cacheMap2))
time.Sleep(60 * time.Second) // wait to gc
t.Log(parser.Len(), "cache items")
}
func BenchmarkUserAgentParser_Parse3(b *testing.B) {
var parser = NewUserAgentParser()
func TestNewUserAgentParser_Mobile(t *testing.T) {
var a = assert.NewAssertion(t)
var parser = stats.NewUserAgentParser()
for _, userAgent := range []string{
"Mozilla/5.0 (iPhone; CPU iPhone OS 12_2 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Mobile/15E148",
"Mozilla/5.0 (Linux; U; Android 2.2.1; en-us; Nexus One Build/FRG83) AppleWebKit/533.1 (KHTML, like Gecko) Version/4.0 Mobile Safari/533.1",
} {
a.IsTrue(parser.Parse(userAgent).IsMobile)
}
}
func BenchmarkUserAgentParser_Parse_Many_LimitCPU(b *testing.B) {
runtime.GOMAXPROCS(4)
var parser = stats.NewUserAgentParser()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/" + types.String(rands.Int(0, 1_000_000)))
}
})
b.Log(parser.Len())
}
func BenchmarkUserAgentParser_Parse_Many(b *testing.B) {
var parser = stats.NewUserAgentParser()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/" + types.String(rands.Int(0, 1_000_000)))
}
})
b.Log(parser.Len())
}
func BenchmarkUserAgentParser_Parse_Few_LimitCPU(b *testing.B) {
runtime.GOMAXPROCS(4)
var parser = stats.NewUserAgentParser()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/" + types.String(rands.Int(0, 100_000)))
}
})
b.Log(len(parser.cacheMap1), len(parser.cacheMap2))
b.Log(parser.Len())
}
func BenchmarkUserAgentParser_Parse_Few(b *testing.B) {
var parser = stats.NewUserAgentParser()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/" + types.String(rands.Int(0, 100_000)))
}
})
b.Log(parser.Len())
}

View File

@@ -1,6 +1,7 @@
package ttlcache
import (
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/iwind/TeaGo/assert"
@@ -8,6 +9,7 @@ import (
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
"runtime"
"runtime/debug"
"strconv"
"sync/atomic"
"testing"
@@ -48,16 +50,34 @@ func TestCache_Memory(t *testing.T) {
}
var cache = NewCache[int]()
var isReady bool
testutils.StartMemoryStats(t, func() {
if !isReady {
return
}
t.Log(cache.Count(), "items")
})
var count = 20_000_000
var count = 1_000_000
if utils.SystemMemoryGB() > 4 {
count = 20_000_000
}
for i := 0; i < count; i++ {
cache.Write("a"+strconv.Itoa(i), 1, time.Now().Unix()+int64(rands.Int(0, 300)))
}
func() {
var before = time.Now()
runtime.GC()
var costSeconds = time.Since(before).Seconds()
var stats = &debug.GCStats{}
debug.ReadGCStats(stats)
t.Log("GC pause:", stats.Pause[0].Seconds()*1000, "ms", "cost:", costSeconds*1000, "ms")
}()
isReady = true
t.Log(cache.Count())
time.Sleep(10 * time.Second)
@@ -105,6 +125,10 @@ func TestCache_IncreaseInt64(t *testing.T) {
}
func TestCache_Read(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
runtime.GOMAXPROCS(1)
var cache = NewCache[int](PiecesOption{Count: 32})

View File

@@ -4,12 +4,17 @@ package agents_test
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/agents"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/iwind/TeaGo/Tea"
_ "github.com/iwind/TeaGo/bootstrap"
"testing"
)
func TestNewManager(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var db = agents.NewDB(Tea.Root + "/data/agents.db")
err := db.Init()
if err != nil {

View File

@@ -4,6 +4,7 @@ package agents_test
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/agents"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/iwind/TeaGo/assert"
_ "github.com/iwind/TeaGo/bootstrap"
"testing"
@@ -11,6 +12,10 @@ import (
)
func TestParseQueue_Process(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var queue = agents.NewQueue()
go queue.Start()
time.Sleep(1 * time.Second)
@@ -19,6 +24,10 @@ func TestParseQueue_Process(t *testing.T) {
}
func TestParseQueue_ParseIP(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var queue = agents.NewQueue()
for _, ip := range []string{
"192.168.1.100",

View File

@@ -6,6 +6,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/iwind/TeaGo/Tea"
"sync"
"sync/atomic"
"time"
@@ -138,7 +139,7 @@ func (this *Stat) IsGood(category string) bool {
return true
}
if item.countCached > countSamples && item.timestamp < fasttime.Now().Unix()-600 /** 10 minutes ago **/ {
if item.countCached > countSamples && (Tea.IsTesting() || item.timestamp < fasttime.Now().Unix()-600) /** 10 minutes ago **/ {
var isGood = item.countHits*100/item.countCached >= this.goodRatio
if isGood {
item.isGood = true

View File

@@ -58,10 +58,10 @@ func TestNewStat(t *testing.T) {
{
var stat = cachehits.NewStat(5)
for i := 0; i < 10001; i++ {
for i := 0; i < 100001; i++ {
stat.IncreaseCached("a")
}
for i := 0; i < 499; i++ {
for i := 0; i < 4999; i++ {
stat.IncreaseHit("a")
}

View File

@@ -4,9 +4,14 @@ package clock_test
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/clock"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"testing"
)
func TestReadServer(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
t.Log(clock.NewClockManager().ReadServer("pool.ntp.org"))
}

View File

@@ -22,7 +22,7 @@ type SupportedUIntType interface {
type Counter[T SupportedUIntType] struct {
countMaps uint64
locker *syncutils.RWMutex
itemMaps []map[uint64]*Item[T]
itemMaps []map[uint64]Item[T]
gcTicker *time.Ticker
gcIndex int
@@ -36,9 +36,9 @@ func NewCounter[T SupportedUIntType]() *Counter[T] {
count = 8
}
var itemMaps = []map[uint64]*Item[T]{}
var itemMaps = []map[uint64]Item[T]{}
for i := 0; i < count; i++ {
itemMaps = append(itemMaps, map[uint64]*Item[T]{})
itemMaps = append(itemMaps, map[uint64]Item[T]{})
}
var counter = &Counter[T]{
@@ -69,19 +69,22 @@ func (this *Counter[T]) WithGC() *Counter[T] {
func (this *Counter[T]) Increase(key uint64, lifeSeconds int) T {
var index = int(key % this.countMaps)
this.locker.RLock(index)
var item = this.itemMaps[index][key]
var item = this.itemMaps[index][key] // item MUST NOT be pointer
this.locker.RUnlock(index)
if item == nil {
if !item.IsOk() {
// no need to care about duplication
// always insert new item even when itemMap is full
item = NewItem[T](lifeSeconds)
var result = item.Increase()
this.locker.Lock(index)
this.itemMaps[index][key] = item
this.locker.Unlock(index)
return result
}
this.locker.Lock(index)
var result = item.Increase()
this.itemMaps[index][key] = item // overwrite
this.locker.Unlock(index)
return result
}
@@ -97,7 +100,7 @@ func (this *Counter[T]) Get(key uint64) T {
this.locker.RLock(index)
defer this.locker.RUnlock(index)
var item = this.itemMaps[index][key]
if item != nil {
if item.IsOk() {
return item.Sum()
}
return 0
@@ -115,7 +118,7 @@ func (this *Counter[T]) Reset(key uint64) {
var item = this.itemMaps[index][key]
this.locker.RUnlock(index)
if item != nil {
if item.IsOk() {
this.locker.Lock(index)
delete(this.itemMaps[index], key)
this.locker.Unlock(index)

View File

@@ -54,6 +54,7 @@ func TestCounter_GC(t *testing.T) {
time.Sleep(1 * time.Second)
counter.Increase(1, 20)
counter.GC()
t.Log(counter.Get(1))
}
func TestCounter_GC2(t *testing.T) {
@@ -62,7 +63,7 @@ func TestCounter_GC2(t *testing.T) {
}
var counter = counters.NewCounter[uint32]().WithGC()
for i := 0; i < 1e5; i++ {
for i := 0; i < 100_000; i++ {
counter.Increase(uint64(i), rands.Int(10, 300))
}
@@ -90,21 +91,36 @@ func TestCounterMemory(t *testing.T) {
var stat1 = &runtime.MemStats{}
runtime.ReadMemStats(stat1)
t.Log((stat1.TotalAlloc-stat.TotalAlloc)/(1<<20), "MB")
t.Log((stat1.HeapInuse-stat.HeapInuse)/(1<<20), "MB")
t.Log(counter.TotalItems())
var gcPause = func() {
var before = time.Now()
runtime.GC()
var costSeconds = time.Since(before).Seconds()
var stats = &debug.GCStats{}
debug.ReadGCStats(stats)
t.Log("GC pause:", stats.Pause[0].Seconds()*1000, "ms", "cost:", costSeconds*1000, "ms")
}
gcPause()
_ = counter.TotalItems()
}
func BenchmarkCounter_Increase(b *testing.B) {
runtime.GOMAXPROCS(4)
var counter = counters.NewCounter[uint32]()
b.ReportAllocs()
b.ResetTimer()
var i uint64
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
counter.Increase(atomic.AddUint64(&i, 1)%1e6, 20)
counter.Increase(atomic.AddUint64(&i, 1)%1_000_000, 20)
}
})
@@ -124,11 +140,12 @@ func BenchmarkCounter_IncreaseKey(b *testing.B) {
}()
b.ResetTimer()
b.ReportAllocs()
var i uint64
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
counter.IncreaseKey(types.String(atomic.AddUint64(&i, 1)%1e6), 20)
counter.IncreaseKey(types.String(atomic.AddUint64(&i, 1)%1_000_000), 20)
}
})

View File

@@ -7,28 +7,29 @@ import (
)
const spanMaxValue = 10_000_000
const maxSpans = 10
type Item[T SupportedUIntType] struct {
spans []T
spans [maxSpans + 1]T
lastUpdateTime int64
lifeSeconds int64
spanSeconds int64
}
func NewItem[T SupportedUIntType](lifeSeconds int) *Item[T] {
func NewItem[T SupportedUIntType](lifeSeconds int) Item[T] {
if lifeSeconds <= 0 {
lifeSeconds = 60
}
var spanSeconds = lifeSeconds / 10
var spanSeconds = lifeSeconds / maxSpans
if spanSeconds < 1 {
spanSeconds = 1
} else if lifeSeconds > maxSpans && lifeSeconds%maxSpans != 0 {
spanSeconds++
}
var countSpans = lifeSeconds/spanSeconds + 1 /** prevent index out of bounds **/
return &Item[T]{
return Item[T]{
lifeSeconds: int64(lifeSeconds),
spanSeconds: int64(spanSeconds),
spans: make([]T, countSpans),
lastUpdateTime: fasttime.Now().Unix(),
}
}
@@ -119,5 +120,13 @@ func (this *Item[T]) IsExpired(currentTime int64) bool {
}
func (this *Item[T]) calculateSpanIndex(timestamp int64) int {
return int(timestamp % this.lifeSeconds / this.spanSeconds)
var index = int(timestamp % this.lifeSeconds / this.spanSeconds)
if index > maxSpans-1 {
return maxSpans - 1
}
return index
}
func (this *Item[T]) IsOk() bool {
return this.lifeSeconds > 0
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/utils/counters"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
"runtime"
"testing"
@@ -41,9 +42,9 @@ func TestItem_Increase2(t *testing.T) {
var a = assert.NewAssertion(t)
var item = counters.NewItem[uint32](20)
var item = counters.NewItem[uint32](23)
for i := 0; i < 100; i++ {
t.Log(item.Increase(), item.Sum(), timeutil.Format("H:i:s"))
t.Log("round "+types.String(i)+":", item.Increase(), item.Sum(), timeutil.Format("H:i:s"))
time.Sleep(2 * time.Second)
}
@@ -56,14 +57,14 @@ func TestItem_IsExpired(t *testing.T) {
return
}
var currentTime = time.Now().Unix()
var item = counters.NewItem[uint32](10)
t.Log(item.IsExpired(currentTime))
t.Log(item.IsExpired(time.Now().Unix()))
time.Sleep(10 * time.Second)
t.Log(item.IsExpired(currentTime))
t.Log(item.IsExpired(time.Now().Unix()))
time.Sleep(2 * time.Second)
t.Log(item.IsExpired(currentTime))
t.Log(item.IsExpired(time.Now().Unix()))
time.Sleep(2 * time.Second)
t.Log(item.IsExpired(time.Now().Unix()))
}
func BenchmarkItem_Increase(b *testing.B) {

View File

@@ -42,12 +42,12 @@ func TestIsIPv4(t *testing.T) {
func TestIsIPv6(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsFalse(utils.IsIPv6("192.168.1.1"))
a.IsFloat32(utils.IsIPv6("0.0.0.0"))
a.IsFalse(utils.IsIPv6("0.0.0.0"))
a.IsFalse(utils.IsIPv6("192.168.1.256"))
a.IsFalse(utils.IsIPv6("192.168.1"))
a.IsTrue(utils.IsIPv6("::1"))
a.IsTrue(utils.IsIPv6("2001:0db8:85a3:0000:0000:8a2e:0370:7334"))
a.IsTrue(utils.IsIPv4("::ffff:192.168.0.1"))
a.IsFalse(utils.IsIPv4("::ffff:192.168.0.1"))
a.IsTrue(utils.IsIPv6("::ffff:192.168.0.1"))
}

View File

@@ -2,7 +2,7 @@
package linkedlist
type List[T any] struct {
type List[T any] struct {
head *Item[T]
end *Item[T]
count int
@@ -36,6 +36,15 @@ func (this *List[T]) Push(item *Item[T]) {
this.add(item)
}
func (this *List[T]) Shift() *Item[T] {
if this.head != nil {
var old = this.head
this.Remove(this.head)
return old
}
return nil
}
func (this *List[T]) Remove(item *Item[T]) {
if item == nil {
return
@@ -71,6 +80,15 @@ func (this *List[T]) Range(f func(item *Item[T]) (goNext bool)) {
}
}
func (this *List[T]) RangeReverse(f func(item *Item[T]) (goNext bool)) {
for e := this.end; e != nil; e = e.prev {
goNext := f(e)
if !goNext {
break
}
}
}
func (this *List[T]) Reset() {
this.head = nil
this.end = nil

View File

@@ -4,6 +4,7 @@ package linkedlist_test
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/linkedlist"
"github.com/iwind/TeaGo/types"
"runtime"
"strconv"
"testing"
@@ -95,6 +96,48 @@ func TestList_Push(t *testing.T) {
})
}
func TestList_Shift(t *testing.T) {
var list = linkedlist.NewList[int]()
list.Push(linkedlist.NewItem(1))
list.Push(linkedlist.NewItem(2))
list.Push(linkedlist.NewItem(3))
list.Push(linkedlist.NewItem(4))
for i := 0; i < 10; i++ {
t.Log("=== before shift " + types.String(i) + " ===")
list.Range(func(item *linkedlist.Item[int]) (goNext bool) {
t.Log(item.Value)
return true
})
t.Logf("shift: %+v", list.Shift())
t.Log("=== after shift " + types.String(i) + " ===")
list.Range(func(item *linkedlist.Item[int]) (goNext bool) {
t.Log(item.Value)
return true
})
}
}
func TestList_RangeReverse(t *testing.T) {
var list = linkedlist.NewList[int]()
list.Push(linkedlist.NewItem(1))
list.Push(linkedlist.NewItem(2))
var item3 = linkedlist.NewItem(3)
list.Push(item3)
list.Push(linkedlist.NewItem(4))
//list.Push(item3)
//list.Remove(item3)
list.RangeReverse(func(item *linkedlist.Item[int]) (goNext bool) {
t.Log(item.Value)
return true
})
}
func BenchmarkList_Add(b *testing.B) {
var list = linkedlist.NewList[int]()
for i := 0; i < b.N; i++ {

View File

@@ -0,0 +1,61 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package ratelimit
import (
"context"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"sync/atomic"
"time"
)
// Bandwidth lossy bandwidth limiter
type Bandwidth struct {
totalBytes int64
currentTimestamp int64
currentBytes int64
}
// NewBandwidth create new bandwidth limiter
func NewBandwidth(totalBytes int64) *Bandwidth {
return &Bandwidth{totalBytes: totalBytes}
}
// Ack acquire next chance to send data
func (this *Bandwidth) Ack(ctx context.Context, newBytes int) {
if newBytes <= 0 {
return
}
if this.totalBytes <= 0 {
return
}
var timestamp = fasttime.Now().Unix()
if this.currentTimestamp != 0 && this.currentTimestamp != timestamp {
this.currentTimestamp = timestamp
this.currentBytes = int64(newBytes)
// 第一次发送直接放行,不需要判断
return
}
if this.currentTimestamp == 0 {
this.currentTimestamp = timestamp
}
if atomic.AddInt64(&this.currentBytes, int64(newBytes)) <= this.totalBytes {
return
}
var timeout = time.NewTimer(1 * time.Second)
if ctx != nil {
select {
case <-timeout.C:
case <-ctx.Done():
}
} else {
select {
case <-timeout.C:
}
}
}

View File

@@ -0,0 +1,27 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package ratelimit_test
import (
"context"
"github.com/TeaOSLab/EdgeNode/internal/utils/ratelimit"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"testing"
)
func TestBandwidth(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var bandwidth = ratelimit.NewBandwidth(32 << 10)
bandwidth.Ack(context.Background(), 123)
bandwidth.Ack(context.Background(), 16 << 10)
bandwidth.Ack(context.Background(), 32 << 10)
}
func TestBandwidth_0(t *testing.T) {
var bandwidth = ratelimit.NewBandwidth(0)
bandwidth.Ack(context.Background(), 123)
bandwidth.Ack(context.Background(), 123456)
}

View File

@@ -1,14 +1,20 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package ratelimit
package ratelimit_test
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/ratelimit"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"testing"
"time"
)
func TestCounter_ACK(t *testing.T) {
var counter = NewCounter(10)
if !testutils.IsSingleTesting() {
return
}
var counter = ratelimit.NewCounter(10)
go func() {
for i := 0; i < 10; i++ {
@@ -26,7 +32,7 @@ func TestCounter_ACK(t *testing.T) {
}
func TestCounter_Release(t *testing.T) {
var counter = NewCounter(10)
var counter = ratelimit.NewCounter(10)
for i := 0; i < 10; i++ {
counter.Ack()
@@ -34,5 +40,5 @@ func TestCounter_Release(t *testing.T) {
for i := 0; i < 10; i++ {
counter.Release()
}
t.Log(len(counter.sem))
t.Log(counter.Len())
}

View File

@@ -12,12 +12,17 @@ type TeeReaderCloser struct {
onFail func(err error)
onEOF func()
mustWrite bool
}
func NewTeeReaderCloser(reader io.Reader, writer io.Writer) *TeeReaderCloser {
// NewTeeReaderCloser
// mustWrite - ensure writing MUST be successfully
func NewTeeReaderCloser(reader io.Reader, writer io.Writer, mustWrite bool) *TeeReaderCloser {
return &TeeReaderCloser{
r: reader,
w: writer,
r: reader,
w: writer,
mustWrite: mustWrite,
}
}
@@ -26,7 +31,13 @@ func (this *TeeReaderCloser) Read(p []byte) (n int, err error) {
if n > 0 {
_, wErr := this.w.Write(p[:n])
if (err == nil || err == io.EOF) && wErr != nil {
err = wErr
if this.mustWrite {
err = wErr
} else {
if this.onFail != nil {
this.onFail(wErr)
}
}
}
}
if err != nil {

View File

@@ -0,0 +1,170 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package runes
// ContainsAnyWordRunes 直接使用rune检查字符串是否包含任一单词
func ContainsAnyWordRunes(s string, words [][]rune, isCaseInsensitive bool) bool {
var allRunes = []rune(s)
if len(allRunes) == 0 || len(words) == 0 {
return false
}
var lastRune rune // last searching rune in s
var lastIndex = -2 // -2: not started, -1: not found, >=0: rune index
for _, wordRunes := range words {
if len(wordRunes) == 0 {
continue
}
if lastIndex > -2 && lastRune == wordRunes[0] {
if lastIndex >= 0 {
result, _ := ContainsWordRunes(allRunes[lastIndex:], wordRunes, isCaseInsensitive)
if result {
return true
}
}
continue
} else {
result, firstIndex := ContainsWordRunes(allRunes, wordRunes, isCaseInsensitive)
lastIndex = firstIndex
if result {
return true
}
}
lastRune = wordRunes[0]
}
return false
}
// ContainsAnyWord 检查字符串是否包含任一单词
func ContainsAnyWord(s string, words []string, isCaseInsensitive bool) bool {
var allRunes = []rune(s)
if len(allRunes) == 0 || len(words) == 0 {
return false
}
var lastRune rune // last searching rune in s
var lastIndex = -2 // -2: not started, -1: not found, >=0: rune index
for _, word := range words {
var wordRunes = []rune(word)
if len(wordRunes) == 0 {
continue
}
if lastIndex > -2 && lastRune == wordRunes[0] {
if lastIndex >= 0 {
result, _ := ContainsWordRunes(allRunes[lastIndex:], wordRunes, isCaseInsensitive)
if result {
return true
}
}
continue
} else {
result, firstIndex := ContainsWordRunes(allRunes, wordRunes, isCaseInsensitive)
lastIndex = firstIndex
if result {
return true
}
}
lastRune = wordRunes[0]
}
return false
}
// ContainsAllWords 检查字符串是否包含所有单词
func ContainsAllWords(s string, words []string, isCaseInsensitive bool) bool {
var allRunes = []rune(s)
if len(allRunes) == 0 || len(words) == 0 {
return false
}
for _, word := range words {
if result, _ := ContainsWordRunes(allRunes, []rune(word), isCaseInsensitive); !result {
return false
}
}
return true
}
// ContainsWordRunes 检查字符列表是否包含某个单词子字符列表
func ContainsWordRunes(allRunes []rune, subRunes []rune, isCaseInsensitive bool) (result bool, firstIndex int) {
firstIndex = -1
var l = len(subRunes)
if l == 0 {
return false, 0
}
var al = len(allRunes)
for index, r := range allRunes {
if EqualRune(r, subRunes[0], isCaseInsensitive) && (index == 0 || !isChar(allRunes[index-1]) /**boundary check **/) {
if firstIndex < 0 {
firstIndex = index
}
var found = true
if l > 1 {
for i := 1; i < l; i++ {
var subIndex = index + i
if subIndex > al-1 || !EqualRune(allRunes[subIndex], subRunes[i], isCaseInsensitive) {
found = false
break
}
}
}
// check after charset
if found && (al <= index+l || !isChar(allRunes[index+l]) /**boundary check **/) {
return true, firstIndex
}
}
}
return false, firstIndex
}
// ContainsSubRunes 检查字符列表是否包含某个子子字符列表
// 与 ContainsWordRunes 不同,这里不需要检查边界符号
func ContainsSubRunes(allRunes []rune, subRunes []rune, isCaseInsensitive bool) bool {
var l = len(subRunes)
if l == 0 {
return false
}
var al = len(allRunes)
for index, r := range allRunes {
if EqualRune(r, subRunes[0], isCaseInsensitive) {
var found = true
if l > 1 {
for i := 1; i < l; i++ {
var subIndex = index + i
if subIndex > al-1 || !EqualRune(allRunes[subIndex], subRunes[i], isCaseInsensitive) {
found = false
break
}
}
}
// check after charset
if found {
return true
}
}
}
return false
}
// EqualRune 判断两个rune是否相同
func EqualRune(r1 rune, r2 rune, isCaseInsensitive bool) bool {
const d = 'a' - 'A'
return r1 == r2 ||
(isCaseInsensitive && r1 >= 'a' && r1 <= 'z' && r1-r2 == d) ||
(isCaseInsensitive && r1 >= 'A' && r1 <= 'Z' && r1-r2 == -d)
}
func isChar(r rune) bool {
return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9'
}

View File

@@ -0,0 +1,172 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package runes_test
import (
"github.com/TeaOSLab/EdgeNode/internal/re"
"github.com/TeaOSLab/EdgeNode/internal/utils/runes"
"github.com/iwind/TeaGo/assert"
"regexp"
"runtime"
"sort"
"strings"
"testing"
)
func TestContainsAllWords(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(runes.ContainsAllWords("How are you?", []string{"are", "you"}, false))
a.IsFalse(runes.ContainsAllWords("How are you?", []string{"how", "are", "you"}, false))
a.IsTrue(runes.ContainsAllWords("How are you?", []string{"how", "are", "you"}, true))
}
func TestContainsAnyWord(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"are", "you"}, false))
a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"are", "you", "ok"}, false))
a.IsFalse(runes.ContainsAnyWord("How are you?", []string{"how", "ok"}, false))
a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"how"}, true))
a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"how", "ok"}, true))
a.IsTrue(runes.ContainsAnyWord("How-are you?", []string{"how", "ok"}, true))
}
func TestContainsAnyWord_Sort(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"abc", "ant", "arm", "Hit", "Hi", "Pet", "pie", "are"}, false))
}
func TestContainsWordRunes(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsFalse(runes.ContainsWordRunes([]rune(""), []rune("How"), true))
a.IsFalse(runes.ContainsWordRunes([]rune("How are you?"), []rune(""), true))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("How"), true))
a.IsFalse(runes.ContainsWordRunes([]rune("How are you?"), []rune("how"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("you"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("are"), false))
a.IsFalse(runes.ContainsWordRunes([]rune("How are you?"), []rune("re"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you w?"), []rune("w"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("w How are you?"), []rune("w"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are w you?"), []rune("w"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are how you?"), []rune("how"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("how"), true))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("ARE"), true))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you"), []rune("you"), false))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you"), []rune("YOU"), true))
a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("YOU"), true))
a.IsFalse(runes.ContainsWordRunes([]rune("How are you1?"), []rune("YOU"), true))
a.IsFalse(runes.ContainsWordRunes([]rune("How are you1?"), []rune("YOU YOU YOU YOU YOU YOU YOU"), true))
}
func TestContainsSubRunes(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsFalse(runes.ContainsSubRunes([]rune(""), []rune("How"), true))
a.IsFalse(runes.ContainsSubRunes([]rune("How are you?"), []rune(""), true))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you1?"), []rune("YOU"), true))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you1?"), []rune("ow"), false))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you1?"), []rune("H"), false))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you1?"), []rune("How"), false))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you doing"), []rune("oi"), false))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you doing"), []rune("g"), false))
a.IsTrue(runes.ContainsSubRunes([]rune("How are you doing"), []rune("ing"), false))
a.IsFalse(runes.ContainsSubRunes([]rune("How are you doing"), []rune("int"), false))
}
func TestEqualRune(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(runes.EqualRune('a', 'a', false))
a.IsTrue(runes.EqualRune('a', 'a', true))
a.IsFalse(runes.EqualRune('a', 'A', false))
a.IsTrue(runes.EqualRune('a', 'A', true))
a.IsFalse(runes.EqualRune('c', 'C', false))
a.IsTrue(runes.EqualRune('c', 'C', true))
a.IsTrue(runes.EqualRune('C', 'C', true))
a.IsTrue(runes.EqualRune('C', 'c', true))
a.IsTrue(runes.EqualRune('Z', 'z', true))
a.IsTrue(runes.EqualRune('z', 'Z', true))
a.IsFalse(runes.EqualRune('z', 'z'+('a'-'A'), true))
}
func BenchmarkContainsWordRunes(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, _ = runes.ContainsWordRunes([]rune("How are you"), []rune("YOU"), true)
}
})
}
func BenchmarkContainsAnyWord(b *testing.B) {
runtime.GOMAXPROCS(4)
var words = strings.Split("python\npycurl\nhttp-client\nhttpclient\napachebench\nnethttp\nhttp_request\njava\nperl\nruby\nscrapy\nphp\nrust", "\n")
sort.Strings(words)
var wordRunes = [][]rune{}
for _, word := range words {
wordRunes = append(wordRunes, []rune(word))
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = runes.ContainsAnyWord("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_0_0) AppleWebKit/500.00 (KHTML, like Gecko) Chrome/100.0.0.0", words, true)
}
})
}
func BenchmarkContainsAnyWordRunes(b *testing.B) {
runtime.GOMAXPROCS(4)
var words = strings.Split("python\npycurl\nhttp-client\nhttpclient\napachebench\nnethttp\nhttp_request\njava\nperl\nruby\nscrapy\nphp\nrust", "\n")
sort.Strings(words)
var wordRunes = [][]rune{}
for _, word := range words {
wordRunes = append(wordRunes, []rune(word))
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = runes.ContainsAnyWordRunes("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_0_0) AppleWebKit/500.00 (KHTML, like Gecko) Chrome/100.0.0.0", wordRunes, true)
}
})
}
func BenchmarkContainsAnyWord_Regexp(b *testing.B) {
runtime.GOMAXPROCS(4)
var reg = regexp.MustCompile("(?i)" + strings.ReplaceAll("python\npycurl\nhttp-client\nhttpclient\napachebench\nnethttp\nhttp_request\njava\nperl\nruby\nscrapy\nphp\nrust", "\n", "|"))
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = reg.MatchString("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_0_0) AppleWebKit/500.00 (KHTML, like Gecko) Chrome/100.0.0.0")
}
})
}
func BenchmarkContainsAnyWord_Re(b *testing.B) {
runtime.GOMAXPROCS(4)
var reg = re.MustCompile("(?i)" + strings.ReplaceAll("python\npycurl\nhttp-client\nhttpclient\napachebench\nnethttp\nhttp_request\njava\nperl\nruby\nscrapy\nphp\nrust", "\n", "|"))
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = reg.MatchString("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_0_0) AppleWebKit/500.00 (KHTML, like Gecko) Chrome/100.0.0.0")
}
})
}
func BenchmarkContainsSubRunes(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = runes.ContainsSubRunes([]rune("How are you"), []rune("YOU"), true)
}
})
}

View File

@@ -1,12 +1,17 @@
package utils
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"sync"
"testing"
"time"
)
func TestRawTicker(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var ticker = time.NewTicker(2 * time.Second)
go func() {
for range ticker.C {
@@ -21,6 +26,10 @@ func TestRawTicker(t *testing.T) {
}
func TestTicker(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
ticker := NewTicker(3 * time.Second)
go func() {
time.Sleep(10 * time.Second)
@@ -33,6 +42,10 @@ func TestTicker(t *testing.T) {
}
func TestTicker2(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
ticker := NewTicker(1 * time.Second)
go func() {
time.Sleep(5 * time.Second)
@@ -50,6 +63,10 @@ func TestTicker2(t *testing.T) {
}
func TestTickerEvery(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
i := 0
wg := &sync.WaitGroup{}
wg.Add(1)
@@ -64,8 +81,11 @@ func TestTickerEvery(t *testing.T) {
wg.Wait()
}
func TestTicker_StopTwice(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
ticker := NewTicker(3 * time.Second)
go func() {
time.Sleep(10 * time.Second)

View File

@@ -5,8 +5,18 @@ import (
"net/http"
)
type AllowScope = string
const (
AllowScopeGroup AllowScope = "group"
AllowScopeServer AllowScope = "server"
AllowScopeGlobal AllowScope = "global"
)
type AllowAction struct {
BaseAction
Scope AllowScope `yaml:"scope" json:"scope"`
}
func (this *AllowAction) Init(waf *WAF) error {
@@ -25,7 +35,12 @@ func (this *AllowAction) WillChange() bool {
return true
}
func (this *AllowAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
func (this *AllowAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
// do nothing
return true, false
return PerformResult{
ContinueRequest: true,
GoNextGroup: this.Scope == AllowScopeGroup,
IsAllowed: true,
AllowScope: this.Scope,
}
}

View File

@@ -61,7 +61,7 @@ func (this *BlockAction) WillChange() bool {
return true
}
func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
// 加入到黑名单
var timeout = this.Timeout
if timeout <= 0 {
@@ -93,14 +93,14 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
req, err := http.NewRequest(http.MethodGet, this.URL, nil)
if err != nil {
logs.Error(err)
return false, false
return PerformResult{}
}
req.Header.Set("User-Agent", teaconst.GlobalProductName+"/"+teaconst.Version)
resp, err := httpClient.Do(req)
if err != nil {
logs.Error(err)
return false, false
return PerformResult{}
}
defer func() {
_ = resp.Body.Close()
@@ -124,11 +124,11 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
data, err := os.ReadFile(path)
if err != nil {
logs.Error(err)
return false, false
return PerformResult{}
}
_, _ = writer.Write(data)
}
return false, false
return PerformResult{}
}
if len(this.Body) > 0 {
_, _ = writer.Write([]byte(this.Body))
@@ -137,5 +137,5 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
}
}
return false, false
return PerformResult{}
}

View File

@@ -55,6 +55,8 @@ type CaptchaAction struct {
SlideUIFooter string `yaml:"slideUIFooter" json:"slideUIFooter"` // 页脚
SlideUIBody string `yaml:"slideUIBody" json:"slideUIBody"` // 内容轮廓
GeeTestConfig *firewallconfigs.GeeTestConfig `yaml:"geeTestConfig" json:"geeTestConfig"` // 极验设置 MUST be struct
Lang string `yaml:"lang" json:"lang"` // 语言zh-CN, en-US ...
AddToWhiteList bool `yaml:"addToWhiteList" json:"addToWhiteList"` // 是否加入到白名单
Scope string `yaml:"scope" json:"scope"`
@@ -121,10 +123,12 @@ func (this *CaptchaAction) WillChange() bool {
return true
}
func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) PerformResult {
// 是否在白名单中
if SharedIPWhiteList.Contains(wafutils.ComposeIPType(set.Id, req), this.Scope, req.WAFServerId(), req.WAFRemoteIP()) {
return true, false
return PerformResult{
ContinueRequest: true,
}
}
var refURL = req.WAFRaw().URL.String()
@@ -151,14 +155,17 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
info, err := utils.SimpleEncryptMap(captchaConfig)
if err != nil {
remotelogs.Error("WAF_CAPTCHA_ACTION", "encode captcha config failed: "+err.Error())
return true, false
return PerformResult{
ContinueRequest: true,
}
}
// 占用一次失败次数
CaptchaIncreaseFails(req, this, waf.Id, group.Id, set.Id, CaptchaPageCodeInit)
req.DisableStat()
req.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(writer, req.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect)
return false, false
return PerformResult{}
}

View File

@@ -41,15 +41,19 @@ func (this *Get302Action) WillChange() bool {
return true
}
func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
// 仅限于Get
if request.WAFRaw().Method != http.MethodGet {
return true, false
return PerformResult{
ContinueRequest: true,
}
}
// 是否已经在白名单中
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
return true, false
return PerformResult{
ContinueRequest: true,
}
}
var m = maps.Map{
@@ -64,9 +68,12 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
info, err := utils.SimpleEncryptMap(m)
if err != nil {
remotelogs.Error("WAF_GET_302_ACTION", "encode info failed: "+err.Error())
return true, false
return PerformResult{
ContinueRequest: true,
}
}
request.DisableStat()
request.ProcessResponseHeaders(writer.Header(), http.StatusFound)
http.Redirect(writer, request.WAFRaw(), Get302Path+"?info="+url.QueryEscape(info), http.StatusFound)
@@ -74,6 +81,6 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
if ok {
flusher.Flush()
}
return false, false
return PerformResult{}
}

View File

@@ -29,20 +29,29 @@ func (this *GoGroupAction) WillChange() bool {
return true
}
func (this *GoGroupAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
nextGroup := waf.FindRuleGroup(types.Int64(this.GroupId))
func (this *GoGroupAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
var nextGroup = waf.FindRuleGroup(types.Int64(this.GroupId))
if nextGroup == nil || !nextGroup.IsOn {
return true, true
return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
}
b, _, nextSet, err := nextGroup.MatchRequest(request)
if err != nil {
remotelogs.Error("WAF", "GO_GROUP_ACTION: "+err.Error())
return true, false
return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
}
if !b {
return true, false
return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
}
return nextSet.PerformActions(waf, nextGroup, request, writer)

View File

@@ -30,23 +30,35 @@ func (this *GoSetAction) WillChange() bool {
return true
}
func (this *GoSetAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
nextGroup := waf.FindRuleGroup(types.Int64(this.GroupId))
func (this *GoSetAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
var nextGroup = waf.FindRuleGroup(types.Int64(this.GroupId))
if nextGroup == nil || !nextGroup.IsOn {
return true, true
return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
}
nextSet := nextGroup.FindRuleSet(types.Int64(this.SetId))
var nextSet = nextGroup.FindRuleSet(types.Int64(this.SetId))
if nextSet == nil || !nextSet.IsOn {
return true, true
return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
}
b, _, err := nextSet.MatchRequest(request)
if err != nil {
remotelogs.Error("WAF", "GO_GROUP_ACTION: "+err.Error())
return true, false
return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
}
if !b {
return true, false
return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
}
return nextSet.PerformActions(waf, nextGroup, request, writer)
}

View File

@@ -27,5 +27,5 @@ type ActionInterface interface {
WillChange() bool
// Perform the action
Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool)
Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult
}

View File

@@ -42,15 +42,19 @@ func (this *JSCookieAction) WillChange() bool {
return true
}
func (this *JSCookieAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
func (this *JSCookieAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) PerformResult {
// 是否在白名单中
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, req.WAFServerId(), req.WAFRemoteIP()) {
return true, false
return PerformResult{
ContinueRequest: true,
}
}
nodeConfig, err := nodeconfigs.SharedNodeConfig()
if err != nil {
return true, false
return PerformResult{
ContinueRequest: true,
}
}
var life = this.Life
@@ -69,7 +73,9 @@ func (this *JSCookieAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
var timestamp = pieces[0]
var sum = pieces[2]
if types.Int64(timestamp) >= time.Now().Unix()-int64(life) && fmt.Sprintf("%x", md5.Sum([]byte(timestamp+"@"+types.String(set.Id)+"@"+nodeConfig.NodeId))) == sum {
return true, false
return PerformResult{
ContinueRequest: true,
}
}
}
}
@@ -103,7 +109,7 @@ window.location.reload();
// 记录失败次数
this.increaseFails(req, waf.Id, group.Id, set.Id)
return false, false
return PerformResult{}
}
func (this *JSCookieAction) increaseFails(req requests.Request, policyId int64, groupId int64, setId int64) (goNext bool) {

View File

@@ -25,6 +25,8 @@ func (this *LogAction) WillChange() bool {
return false
}
func (this *LogAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
return true, false
func (this *LogAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
return PerformResult{
ContinueRequest: true,
}
}

Some files were not shown because too many files have changed in this diff Show More