Compare commits
37 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c2d488c37 | ||
|
|
d84b844e53 | ||
|
|
095c381ae5 | ||
|
|
7d11b3c63b | ||
|
|
45ba4fe5f1 | ||
|
|
ac341da05b | ||
|
|
12b9c37095 | ||
|
|
125b25ea27 | ||
|
|
6b1d595d58 | ||
|
|
f26b80e9c1 | ||
|
|
6e1bb43a9e | ||
|
|
35e7ce1435 | ||
|
|
5eb247b999 | ||
|
|
c94e93859f | ||
|
|
d694319191 | ||
|
|
84e61f7765 | ||
|
|
196f0612dc | ||
|
|
4cd5e12686 | ||
|
|
aca128c19d | ||
|
|
5052d20fbd | ||
|
|
74790bea4a | ||
|
|
e922c12611 | ||
|
|
cc632d557b | ||
|
|
a7dd101dbf | ||
|
|
1e1cd5a643 | ||
|
|
6888ccc350 | ||
|
|
8d9a4b6d4f | ||
|
|
470b33e90f | ||
|
|
74a559b1a0 | ||
|
|
9db86b011a | ||
|
|
c94a51f47b | ||
|
|
2a9ec05d45 | ||
|
|
c9811d78a9 | ||
|
|
65435ab32d | ||
|
|
614c7a4687 | ||
|
|
ef541a2d8f | ||
|
|
9141a1434e |
@@ -6,4 +6,6 @@ if [ -z "$TAG" ]; then
|
||||
TAG="community"
|
||||
fi
|
||||
|
||||
go test -v ../... -tags=${TAG}
|
||||
# reference: https://pkg.go.dev/cmd/go/internal/test
|
||||
go clean -testcache
|
||||
go test -timeout 10s -tags="${TAG}" -cover ../...
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/iwind/gosock/pkg/gosock"
|
||||
"gopkg.in/yaml.v3"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
@@ -30,7 +31,7 @@ func main() {
|
||||
var app = apps.NewAppCmd().
|
||||
Version(teaconst.Version).
|
||||
Product(teaconst.ProductName).
|
||||
Usage(teaconst.ProcessName + " [-v|start|stop|restart|status|quit|test|reload|service|daemon|pprof|accesslog|uninstall]").
|
||||
Usage(teaconst.ProcessName + " [-v|start|stop|restart|status|quit|test|reload|service|daemon|config|pprof|accesslog|uninstall]").
|
||||
Usage(teaconst.ProcessName + " [trackers|goman|conns|gc|bandwidth|disk|cache.garbage]").
|
||||
Usage(teaconst.ProcessName + " [ip.drop|ip.reject|ip.remove|ip.close] IP")
|
||||
|
||||
@@ -524,6 +525,41 @@ func main() {
|
||||
fmt.Println("[ERROR]" + params.GetString("error"))
|
||||
}
|
||||
})
|
||||
app.On("config", func() {
|
||||
var configString = os.Args[len(os.Args)-1]
|
||||
if configString == "config" {
|
||||
fmt.Println("Usage: edge-node config '\nrpc.endpoints: [\"...\"]\nnodeId: \"...\"\nsecret: \"...\"\n'")
|
||||
return
|
||||
}
|
||||
|
||||
var config = &configs.APIConfig{}
|
||||
err := yaml.Unmarshal([]byte(configString), config)
|
||||
if err != nil {
|
||||
fmt.Println("[ERROR]decode config failed: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = config.Init()
|
||||
if err != nil {
|
||||
fmt.Println("[ERROR]validate config failed: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// marshal again
|
||||
configYAML, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
fmt.Println("[ERROR]encode config failed: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = os.WriteFile(Tea.Root + "/configs/api_node.yaml", configYAML, 0666)
|
||||
if err != nil {
|
||||
fmt.Println("[ERROR]write config failed: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("success")
|
||||
})
|
||||
app.Run(func() {
|
||||
var node = nodes.NewNode()
|
||||
node.Start()
|
||||
|
||||
@@ -34,16 +34,14 @@ func CanIgnoreErr(err error) bool {
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
if err == ErrFileIsWriting ||
|
||||
err == ErrEntityTooLarge ||
|
||||
err == ErrWritingUnavailable ||
|
||||
err == ErrWritingQueueFull ||
|
||||
err == ErrServerIsBusy {
|
||||
if errors.Is(err, ErrFileIsWriting) ||
|
||||
errors.Is(err, ErrEntityTooLarge) ||
|
||||
errors.Is(err, ErrWritingUnavailable) ||
|
||||
errors.Is(err, ErrWritingQueueFull) ||
|
||||
errors.Is(err, ErrServerIsBusy) {
|
||||
return true
|
||||
}
|
||||
_, ok := err.(*CapacityError)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
var capacityErr *CapacityError
|
||||
return errors.As(err, &capacityErr)
|
||||
}
|
||||
|
||||
@@ -1,16 +1,24 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package caches
|
||||
package caches_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCanIgnoreErr(t *testing.T) {
|
||||
a := assert.NewAssertion(t)
|
||||
var a = assert.NewAssertion(t)
|
||||
|
||||
a.IsTrue(CanIgnoreErr(ErrFileIsWriting))
|
||||
a.IsTrue(CanIgnoreErr(NewCapacityError("over capcity")))
|
||||
a.IsFalse(CanIgnoreErr(ErrNotFound))
|
||||
a.IsTrue(caches.CanIgnoreErr(caches.ErrFileIsWriting))
|
||||
a.IsTrue(caches.CanIgnoreErr(fmt.Errorf("error: %w", caches.ErrFileIsWriting)))
|
||||
a.IsTrue(errors.Is(fmt.Errorf("error: %w", caches.ErrFileIsWriting), caches.ErrFileIsWriting))
|
||||
a.IsTrue(errors.Is(caches.ErrFileIsWriting, caches.ErrFileIsWriting))
|
||||
a.IsTrue(caches.CanIgnoreErr(caches.NewCapacityError("over capacity")))
|
||||
a.IsTrue(caches.CanIgnoreErr(fmt.Errorf("error: %w", caches.NewCapacityError("over capacity"))))
|
||||
a.IsFalse(caches.CanIgnoreErr(caches.ErrNotFound))
|
||||
a.IsFalse(caches.CanIgnoreErr(errors.New("test error")))
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ package caches_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/zero"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
@@ -18,7 +19,11 @@ func TestItems_Memory(t *testing.T) {
|
||||
var memory1 = stat.HeapInuse
|
||||
|
||||
var items = []*caches.Item{}
|
||||
for i := 0; i < 10_000_000; i++ {
|
||||
var count = 100
|
||||
if testutils.IsSingleTesting() {
|
||||
count = 10_000_000
|
||||
}
|
||||
for i := 0; i < count; i++ {
|
||||
items = append(items, &caches.Item{
|
||||
Key: types.String(i),
|
||||
})
|
||||
@@ -33,7 +38,9 @@ func TestItems_Memory(t *testing.T) {
|
||||
var memory3 = stat.HeapInuse
|
||||
t.Log(memory2, memory3, (memory3-memory2)/1024/1024, "M")
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
if testutils.IsSingleTesting() {
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestItems_Memory2(t *testing.T) {
|
||||
@@ -42,7 +49,12 @@ func TestItems_Memory2(t *testing.T) {
|
||||
var memory1 = stat.HeapInuse
|
||||
|
||||
var items = map[int32]map[string]zero.Zero{}
|
||||
for i := 0; i < 10_000_000; i++ {
|
||||
var count = 100
|
||||
if testutils.IsSingleTesting() {
|
||||
count = 10_000_000
|
||||
}
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
var week = int32((time.Now().Unix() - int64(86400*rands.Int(0, 300))) / (86400 * 7))
|
||||
m, ok := items[week]
|
||||
if !ok {
|
||||
@@ -57,7 +69,9 @@ func TestItems_Memory2(t *testing.T) {
|
||||
|
||||
t.Log(memory1, memory2, (memory2-memory1)/1024/1024, "M")
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
if testutils.IsSingleTesting() {
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
for w, i := range items {
|
||||
t.Log(w, len(i))
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ package caches_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/zero"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
@@ -86,6 +87,10 @@ func TestFileListHashMap_BigInt(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileListHashMap_Load(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1").(*caches.FileList)
|
||||
|
||||
defer func() {
|
||||
|
||||
@@ -17,6 +17,10 @@ import (
|
||||
)
|
||||
|
||||
func TestFileList_Init(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
|
||||
|
||||
defer func() {
|
||||
@@ -34,6 +38,10 @@ func TestFileList_Init(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileList_Add(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1").(*caches.FileList)
|
||||
|
||||
defer func() {
|
||||
@@ -107,6 +115,10 @@ func TestFileList_Add_Many(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileList_Exist(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1").(*caches.FileList)
|
||||
defer func() {
|
||||
_ = list.Close()
|
||||
@@ -143,6 +155,10 @@ func TestFileList_Exist(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileList_Exist_Many_DB(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
// 测试在多个数据库下的性能
|
||||
var listSlice = []caches.ListInterface{}
|
||||
for i := 1; i <= 10; i++ {
|
||||
@@ -202,6 +218,10 @@ func TestFileList_Exist_Many_DB(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileList_CleanPrefix(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
|
||||
|
||||
defer func() {
|
||||
@@ -222,6 +242,10 @@ func TestFileList_CleanPrefix(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileList_Remove(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1").(*caches.FileList)
|
||||
defer func() {
|
||||
_ = list.Close()
|
||||
@@ -246,6 +270,10 @@ func TestFileList_Remove(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileList_Purge(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
|
||||
|
||||
defer func() {
|
||||
@@ -270,6 +298,10 @@ func TestFileList_Purge(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileList_PurgeLFU(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
|
||||
|
||||
defer func() {
|
||||
@@ -294,6 +326,10 @@ func TestFileList_PurgeLFU(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileList_Stat(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
|
||||
|
||||
defer func() {
|
||||
@@ -313,6 +349,10 @@ func TestFileList_Stat(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileList_Count(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var list = caches.NewFileList(Tea.Root + "/data")
|
||||
|
||||
defer func() {
|
||||
@@ -333,6 +373,10 @@ func TestFileList_Count(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileList_CleanAll(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var list = caches.NewFileList(Tea.Root + "/data")
|
||||
|
||||
defer func() {
|
||||
@@ -352,6 +396,10 @@ func TestFileList_CleanAll(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileList_UpgradeV3(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p43").(*caches.FileList)
|
||||
|
||||
defer func() {
|
||||
@@ -376,6 +424,10 @@ func TestFileList_UpgradeV3(t *testing.T) {
|
||||
}
|
||||
|
||||
func BenchmarkFileList_Exist(b *testing.B) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
|
||||
|
||||
defer func() {
|
||||
|
||||
@@ -400,7 +400,19 @@ func (this *MemoryList) IncreaseHit(hash string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *MemoryList) print(t *testing.T) {
|
||||
func (this *MemoryList) Prefixes() []string {
|
||||
return this.prefixes
|
||||
}
|
||||
|
||||
func (this *MemoryList) ItemMaps() map[string]map[string]*Item {
|
||||
return this.itemMaps
|
||||
}
|
||||
|
||||
func (this *MemoryList) PurgeIndex() int {
|
||||
return this.purgeIndex
|
||||
}
|
||||
|
||||
func (this *MemoryList) Print(t *testing.T) {
|
||||
this.locker.Lock()
|
||||
for _, itemMap := range this.itemMaps {
|
||||
if len(itemMap) > 0 {
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package caches
|
||||
package caches_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"github.com/cespare/xxhash"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"strconv"
|
||||
@@ -15,65 +17,65 @@ import (
|
||||
)
|
||||
|
||||
func TestMemoryList_Add(t *testing.T) {
|
||||
list := NewMemoryList().(*MemoryList)
|
||||
list := caches.NewMemoryList().(*caches.MemoryList)
|
||||
_ = list.Init()
|
||||
_ = list.Add("a", &Item{
|
||||
_ = list.Add("a", &caches.Item{
|
||||
Key: "a1",
|
||||
ExpiredAt: time.Now().Unix() + 3600,
|
||||
HeaderSize: 1024,
|
||||
})
|
||||
_ = list.Add("b", &Item{
|
||||
_ = list.Add("b", &caches.Item{
|
||||
Key: "b1",
|
||||
ExpiredAt: time.Now().Unix() + 3600,
|
||||
HeaderSize: 1024,
|
||||
})
|
||||
_ = list.Add("123456", &Item{
|
||||
_ = list.Add("123456", &caches.Item{
|
||||
Key: "c1",
|
||||
ExpiredAt: time.Now().Unix() + 3600,
|
||||
HeaderSize: 1024,
|
||||
})
|
||||
t.Log(list.prefixes)
|
||||
logs.PrintAsJSON(list.itemMaps, t)
|
||||
t.Log(list.Prefixes())
|
||||
logs.PrintAsJSON(list.ItemMaps(), t)
|
||||
t.Log(list.Count())
|
||||
}
|
||||
|
||||
func TestMemoryList_Remove(t *testing.T) {
|
||||
list := NewMemoryList().(*MemoryList)
|
||||
list := caches.NewMemoryList().(*caches.MemoryList)
|
||||
_ = list.Init()
|
||||
_ = list.Add("a", &Item{
|
||||
_ = list.Add("a", &caches.Item{
|
||||
Key: "a1",
|
||||
ExpiredAt: time.Now().Unix() + 3600,
|
||||
HeaderSize: 1024,
|
||||
})
|
||||
_ = list.Add("b", &Item{
|
||||
_ = list.Add("b", &caches.Item{
|
||||
Key: "b1",
|
||||
ExpiredAt: time.Now().Unix() + 3600,
|
||||
HeaderSize: 1024,
|
||||
})
|
||||
_ = list.Remove("b")
|
||||
list.print(t)
|
||||
list.Print(t)
|
||||
t.Log(list.Count())
|
||||
}
|
||||
|
||||
func TestMemoryList_Purge(t *testing.T) {
|
||||
list := NewMemoryList().(*MemoryList)
|
||||
list := caches.NewMemoryList().(*caches.MemoryList)
|
||||
_ = list.Init()
|
||||
_ = list.Add("a", &Item{
|
||||
_ = list.Add("a", &caches.Item{
|
||||
Key: "a1",
|
||||
ExpiredAt: time.Now().Unix() + 3600,
|
||||
HeaderSize: 1024,
|
||||
})
|
||||
_ = list.Add("b", &Item{
|
||||
_ = list.Add("b", &caches.Item{
|
||||
Key: "b1",
|
||||
ExpiredAt: time.Now().Unix() + 3600,
|
||||
HeaderSize: 1024,
|
||||
})
|
||||
_ = list.Add("c", &Item{
|
||||
_ = list.Add("c", &caches.Item{
|
||||
Key: "c1",
|
||||
ExpiredAt: time.Now().Unix() - 3600,
|
||||
HeaderSize: 1024,
|
||||
})
|
||||
_ = list.Add("d", &Item{
|
||||
_ = list.Add("d", &caches.Item{
|
||||
Key: "d1",
|
||||
ExpiredAt: time.Now().Unix() - 2,
|
||||
HeaderSize: 1024,
|
||||
@@ -82,25 +84,30 @@ func TestMemoryList_Purge(t *testing.T) {
|
||||
t.Log("delete:", hash)
|
||||
return nil
|
||||
})
|
||||
list.print(t)
|
||||
list.Print(t)
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
_, _ = list.Purge(100, func(hash string) error {
|
||||
t.Log("delete:", hash)
|
||||
return nil
|
||||
})
|
||||
t.Log(list.purgeIndex)
|
||||
t.Log(list.PurgeIndex())
|
||||
}
|
||||
|
||||
t.Log(list.Count())
|
||||
}
|
||||
|
||||
func TestMemoryList_Purge_Large_List(t *testing.T) {
|
||||
list := NewMemoryList().(*MemoryList)
|
||||
var list = caches.NewMemoryList().(*caches.MemoryList)
|
||||
_ = list.Init()
|
||||
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
_ = list.Add("a"+strconv.Itoa(i), &Item{
|
||||
var count = 100
|
||||
if testutils.IsSingleTesting() {
|
||||
count = 1_000_000
|
||||
}
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
_ = list.Add("a"+strconv.Itoa(i), &caches.Item{
|
||||
Key: "a" + strconv.Itoa(i),
|
||||
ExpiredAt: time.Now().Unix() + int64(rands.Int(0, 24*3600)),
|
||||
HeaderSize: 1024,
|
||||
@@ -113,43 +120,46 @@ func TestMemoryList_Purge_Large_List(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMemoryList_Stat(t *testing.T) {
|
||||
list := NewMemoryList()
|
||||
list := caches.NewMemoryList()
|
||||
_ = list.Init()
|
||||
_ = list.Add("a", &Item{
|
||||
_ = list.Add("a", &caches.Item{
|
||||
Key: "a1",
|
||||
ExpiredAt: time.Now().Unix() + 3600,
|
||||
HeaderSize: 1024,
|
||||
})
|
||||
_ = list.Add("b", &Item{
|
||||
_ = list.Add("b", &caches.Item{
|
||||
Key: "b1",
|
||||
ExpiredAt: time.Now().Unix() + 3600,
|
||||
HeaderSize: 1024,
|
||||
})
|
||||
_ = list.Add("c", &Item{
|
||||
_ = list.Add("c", &caches.Item{
|
||||
Key: "c1",
|
||||
ExpiredAt: time.Now().Unix(),
|
||||
HeaderSize: 1024,
|
||||
})
|
||||
_ = list.Add("d", &Item{
|
||||
_ = list.Add("d", &caches.Item{
|
||||
Key: "d1",
|
||||
ExpiredAt: time.Now().Unix() - 2,
|
||||
HeaderSize: 1024,
|
||||
})
|
||||
result, _ := list.Stat(func(hash string) bool {
|
||||
// 随机测试
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
return rand.Int()%2 == 0
|
||||
})
|
||||
t.Log(result)
|
||||
}
|
||||
|
||||
func TestMemoryList_CleanPrefix(t *testing.T) {
|
||||
list := NewMemoryList()
|
||||
list := caches.NewMemoryList()
|
||||
_ = list.Init()
|
||||
before := time.Now()
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
var count = 100
|
||||
if testutils.IsSingleTesting() {
|
||||
count = 1_000_000
|
||||
}
|
||||
for i := 0; i < count; i++ {
|
||||
key := "https://www.teaos.cn/hello/" + strconv.Itoa(i/10000) + "/" + strconv.Itoa(i) + ".html"
|
||||
_ = list.Add(fmt.Sprintf("%d", xxhash.Sum64String(key)), &Item{
|
||||
_ = list.Add(fmt.Sprintf("%d", xxhash.Sum64String(key)), &caches.Item{
|
||||
Key: key,
|
||||
ExpiredAt: time.Now().Unix() + 3600,
|
||||
BodySize: 0,
|
||||
@@ -174,7 +184,12 @@ func TestMemoryList_CleanPrefix(t *testing.T) {
|
||||
func TestMapRandomDelete(t *testing.T) {
|
||||
var countMap = map[int]int{} // k => count
|
||||
|
||||
for j := 0; j < 1_000_000; j++ {
|
||||
var count = 1000
|
||||
if testutils.IsSingleTesting() {
|
||||
count = 1_000_000
|
||||
}
|
||||
|
||||
for j := 0; j < count; j++ {
|
||||
var m = map[int]bool{}
|
||||
for i := 0; i < 100; i++ {
|
||||
m[i] = true
|
||||
@@ -203,18 +218,18 @@ func TestMapRandomDelete(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMemoryList_PurgeLFU(t *testing.T) {
|
||||
var list = NewMemoryList().(*MemoryList)
|
||||
var list = caches.NewMemoryList().(*caches.MemoryList)
|
||||
|
||||
var before = time.Now()
|
||||
defer func() {
|
||||
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||
}()
|
||||
|
||||
_ = list.Add("1", &Item{})
|
||||
_ = list.Add("2", &Item{})
|
||||
_ = list.Add("3", &Item{})
|
||||
_ = list.Add("4", &Item{})
|
||||
_ = list.Add("5", &Item{})
|
||||
_ = list.Add("1", &caches.Item{})
|
||||
_ = list.Add("2", &caches.Item{})
|
||||
_ = list.Add("3", &caches.Item{})
|
||||
_ = list.Add("4", &caches.Item{})
|
||||
_ = list.Add("5", &caches.Item{})
|
||||
|
||||
//_ = list.IncreaseHit("1")
|
||||
//_ = list.IncreaseHit("2")
|
||||
@@ -245,29 +260,32 @@ func TestMemoryList_PurgeLFU(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMemoryList_CleanAll(t *testing.T) {
|
||||
var list = NewMemoryList().(*MemoryList)
|
||||
_ = list.Add("a", &Item{})
|
||||
var list = caches.NewMemoryList().(*caches.MemoryList)
|
||||
_ = list.Add("a", &caches.Item{})
|
||||
_ = list.CleanAll()
|
||||
logs.PrintAsJSON(list.itemMaps, t)
|
||||
logs.PrintAsJSON(list.ItemMaps(), t)
|
||||
t.Log(list.Count())
|
||||
}
|
||||
|
||||
func TestMemoryList_GC(t *testing.T) {
|
||||
list := NewMemoryList().(*MemoryList)
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
list := caches.NewMemoryList().(*caches.MemoryList)
|
||||
_ = list.Init()
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
key := "https://www.teaos.cn/hello" + strconv.Itoa(i/100000) + "/" + strconv.Itoa(i) + ".html"
|
||||
_ = list.Add(fmt.Sprintf("%d", xxhash.Sum64String(key)), &Item{
|
||||
_ = list.Add(fmt.Sprintf("%d", xxhash.Sum64String(key)), &caches.Item{
|
||||
Key: key,
|
||||
ExpiredAt: 0,
|
||||
BodySize: 0,
|
||||
HeaderSize: 0,
|
||||
})
|
||||
}
|
||||
time.Sleep(10 * time.Second)
|
||||
t.Log("clean...", len(list.itemMaps))
|
||||
t.Log("clean...", len(list.ItemMaps()))
|
||||
_ = list.CleanAll()
|
||||
t.Log("cleanAll...", len(list.itemMaps))
|
||||
t.Log("cleanAll...", len(list.ItemMaps()))
|
||||
before := time.Now()
|
||||
//runtime.GC()
|
||||
t.Log("gc cost:", time.Since(before).Seconds()*1000, "ms")
|
||||
@@ -280,3 +298,30 @@ func TestMemoryList_GC(t *testing.T) {
|
||||
time.Sleep(30 * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMemoryList(b *testing.B) {
|
||||
var list = caches.NewMemoryList()
|
||||
err := list.Init()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
_ = list.Add(stringutil.Md5(types.String(i)), &caches.Item{
|
||||
Key: "a1",
|
||||
ExpiredAt: time.Now().Unix() + 3600,
|
||||
HeaderSize: 1024,
|
||||
})
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_, _ = list.Exist(types.String("a" + types.String(rands.Int(1, 10000))))
|
||||
_ = list.Add("a"+types.String(rands.Int(1, 100000)), &caches.Item{})
|
||||
_, _ = list.Purge(1000, func(hash string) error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -111,6 +111,15 @@ func (this *FileStorage) Policy() *serverconfigs.HTTPCachePolicy {
|
||||
|
||||
// CanUpdatePolicy 检查策略是否可以更新
|
||||
func (this *FileStorage) CanUpdatePolicy(newPolicy *serverconfigs.HTTPCachePolicy) bool {
|
||||
if newPolicy == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查类型
|
||||
if newPolicy.Type != serverconfigs.CachePolicyStorageFile {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查路径是否有变化
|
||||
oldOptionsJSON, err := json.Marshal(this.policy.Options)
|
||||
if err != nil {
|
||||
@@ -455,7 +464,7 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, hea
|
||||
_, ok := sharedWritingFileKeyMap[key]
|
||||
if ok {
|
||||
sharedWritingFileKeyLocker.Unlock()
|
||||
return nil, ErrFileIsWriting
|
||||
return nil, fmt.Errorf("%w(001)", ErrFileIsWriting)
|
||||
}
|
||||
|
||||
if !isFlushing && !fsutils.WriteReady() {
|
||||
@@ -496,7 +505,7 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, hea
|
||||
// 检查两次写入缓存的时间是否过于相近,分片内容不受此限制
|
||||
if err == nil && !isPartial && time.Since(stat.ModTime()) <= 1*time.Second {
|
||||
// 防止并发连续写入
|
||||
return nil, ErrFileIsWriting
|
||||
return nil, fmt.Errorf("%w(002)", ErrFileIsWriting)
|
||||
}
|
||||
|
||||
// 构造文件名
|
||||
@@ -597,7 +606,7 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, hea
|
||||
err = syscall.Flock(int(writer.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
|
||||
if err != nil {
|
||||
removeOnFailure = false
|
||||
return nil, ErrFileIsWriting
|
||||
return nil, fmt.Errorf("%w (003)", ErrFileIsWriting)
|
||||
}
|
||||
|
||||
var metaBodySize int64 = -1
|
||||
|
||||
@@ -19,6 +19,10 @@ import (
|
||||
)
|
||||
|
||||
func TestFileStorage_Init(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
|
||||
Id: 1,
|
||||
IsOn: true,
|
||||
@@ -49,6 +53,10 @@ func TestFileStorage_Init(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileStorage_OpenWriter(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
|
||||
Id: 1,
|
||||
IsOn: true,
|
||||
@@ -96,6 +104,10 @@ func TestFileStorage_OpenWriter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileStorage_OpenWriter_Partial(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
|
||||
Id: 2,
|
||||
IsOn: true,
|
||||
@@ -134,6 +146,10 @@ func TestFileStorage_OpenWriter_Partial(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileStorage_OpenWriter_HTTP(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
|
||||
Id: 1,
|
||||
IsOn: true,
|
||||
@@ -202,6 +218,10 @@ func TestFileStorage_OpenWriter_HTTP(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileStorage_Concurrent_Open_DifferentFile(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
|
||||
Id: 1,
|
||||
IsOn: true,
|
||||
@@ -231,7 +251,7 @@ func TestFileStorage_Concurrent_Open_DifferentFile(t *testing.T) {
|
||||
|
||||
writer, err := storage.OpenWriter("abc"+strconv.Itoa(i), time.Now().Unix()+3600, 200, -1, -1, -1, false)
|
||||
if err != nil {
|
||||
if err != ErrFileIsWriting {
|
||||
if errors.Is(err, ErrFileIsWriting) {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -260,6 +280,10 @@ func TestFileStorage_Concurrent_Open_DifferentFile(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileStorage_Concurrent_Open_SameFile(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
|
||||
Id: 1,
|
||||
IsOn: true,
|
||||
@@ -289,7 +313,7 @@ func TestFileStorage_Concurrent_Open_SameFile(t *testing.T) {
|
||||
|
||||
writer, err := storage.OpenWriter("abc"+strconv.Itoa(0), time.Now().Unix()+3600, 200, -1, -1, -1, false)
|
||||
if err != nil {
|
||||
if err != ErrFileIsWriting {
|
||||
if errors.Is(err, ErrFileIsWriting) {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
@@ -319,6 +343,10 @@ func TestFileStorage_Concurrent_Open_SameFile(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileStorage_Read(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
|
||||
Id: 1,
|
||||
IsOn: true,
|
||||
@@ -358,6 +386,10 @@ func TestFileStorage_Read(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileStorage_Read_HTTP_Response(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
|
||||
Id: 1,
|
||||
IsOn: true,
|
||||
@@ -414,6 +446,10 @@ func TestFileStorage_Read_HTTP_Response(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileStorage_Read_NotFound(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
|
||||
Id: 1,
|
||||
IsOn: true,
|
||||
@@ -450,6 +486,10 @@ func TestFileStorage_Read_NotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileStorage_Delete(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
|
||||
Id: 1,
|
||||
IsOn: true,
|
||||
@@ -472,6 +512,10 @@ func TestFileStorage_Delete(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileStorage_Stat(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
|
||||
Id: 1,
|
||||
IsOn: true,
|
||||
@@ -500,6 +544,10 @@ func TestFileStorage_Stat(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileStorage_CleanAll(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
|
||||
Id: 1,
|
||||
IsOn: true,
|
||||
@@ -534,6 +582,10 @@ func TestFileStorage_CleanAll(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileStorage_Stop(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
|
||||
Id: 1,
|
||||
IsOn: true,
|
||||
@@ -552,6 +604,10 @@ func TestFileStorage_Stop(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileStorage_DecodeFile(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
|
||||
Id: 1,
|
||||
IsOn: true,
|
||||
@@ -571,6 +627,10 @@ func TestFileStorage_DecodeFile(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFileStorage_RemoveCacheFile(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var storage = NewFileStorage(nil)
|
||||
|
||||
defer storage.Stop()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package caches
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
@@ -161,7 +162,7 @@ func (this *MemoryStorage) OpenWriter(key string, expiredAt int64, status int, h
|
||||
|
||||
// TODO 内存缓存暂时不支持分块内容存储
|
||||
if isPartial {
|
||||
return nil, ErrFileIsWriting
|
||||
return nil, fmt.Errorf("%w (004)", ErrFileIsWriting)
|
||||
}
|
||||
return this.openWriter(key, expiredAt, status, headerSize, bodySize, maxSize, true)
|
||||
}
|
||||
@@ -187,7 +188,7 @@ func (this *MemoryStorage) openWriter(key string, expiresAt int64, status int, h
|
||||
var isWriting = false
|
||||
_, ok := this.writingKeyMap[key]
|
||||
if ok {
|
||||
return nil, ErrFileIsWriting
|
||||
return nil, fmt.Errorf("%w (005)", ErrFileIsWriting)
|
||||
}
|
||||
this.writingKeyMap[key] = zero.New()
|
||||
defer func() {
|
||||
@@ -208,7 +209,7 @@ func (this *MemoryStorage) openWriter(key string, expiresAt int64, status int, h
|
||||
_ = this.list.Remove(hashString)
|
||||
item = nil
|
||||
} else {
|
||||
return nil, ErrFileIsWriting
|
||||
return nil, fmt.Errorf("%w (006)", ErrFileIsWriting)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -366,7 +367,7 @@ func (this *MemoryStorage) UpdatePolicy(newPolicy *serverconfigs.HTTPCachePolicy
|
||||
|
||||
// CanUpdatePolicy 检查策略是否可以更新
|
||||
func (this *MemoryStorage) CanUpdatePolicy(newPolicy *serverconfigs.HTTPCachePolicy) bool {
|
||||
return true
|
||||
return newPolicy != nil && newPolicy.Type == serverconfigs.CachePolicyStorageMemory
|
||||
}
|
||||
|
||||
// AddToList 将缓存添加到列表
|
||||
|
||||
@@ -3,6 +3,7 @@ package caches
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"runtime"
|
||||
@@ -271,6 +272,10 @@ func TestMemoryStorage_Purge(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMemoryStorage_Expire(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var storage = NewMemoryStorage(&serverconfigs.HTTPCachePolicy{
|
||||
MemoryAutoPurgeInterval: 5,
|
||||
}, nil)
|
||||
|
||||
@@ -125,7 +125,6 @@ func (this *MemoryWriter) Close() error {
|
||||
// 需要在Locker之外
|
||||
defer this.once.Do(func() {
|
||||
this.endFunc(this.item)
|
||||
this.item = nil // free memory
|
||||
})
|
||||
|
||||
if this.item == nil {
|
||||
@@ -164,7 +163,6 @@ func (this *MemoryWriter) Discard() error {
|
||||
// 需要在Locker之外
|
||||
defer this.once.Do(func() {
|
||||
this.endFunc(this.item)
|
||||
this.item = nil // free memory
|
||||
})
|
||||
|
||||
this.storage.locker.Lock()
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"bytes"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/compressions"
|
||||
"io"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -49,3 +50,45 @@ func TestBrotliReader(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBrotliReader(b *testing.B) {
|
||||
data, err := os.ReadFile("./reader_brotli.go")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
var buf = bytes.NewBuffer([]byte{})
|
||||
writer, err := compressions.NewBrotliWriter(buf, 5)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
_, err = writer.Write(data)
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
var compressedData = buf.Bytes()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
reader, readerErr := compressions.NewBrotliReader(bytes.NewBuffer(compressedData))
|
||||
if readerErr != nil {
|
||||
b.Fatal(readerErr)
|
||||
}
|
||||
var readBuf = make([]byte, 1024)
|
||||
for {
|
||||
_, readErr := reader.Read(readBuf)
|
||||
if readErr != nil {
|
||||
if readErr != io.EOF {
|
||||
b.Fatal(readErr)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
closeErr := reader.Close()
|
||||
if closeErr != nil {
|
||||
b.Fatal(closeErr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ContentEncoding = string
|
||||
@@ -58,3 +59,32 @@ func NewWriter(writer io.Writer, compressType serverconfigs.HTTPCompressionType,
|
||||
}
|
||||
return nil, errors.New("invalid compression type '" + compressType + "'")
|
||||
}
|
||||
|
||||
// SupportEncoding 检查是否支持某个编码
|
||||
func SupportEncoding(encoding string) bool {
|
||||
return encoding == ContentEncodingBr ||
|
||||
encoding == ContentEncodingGzip ||
|
||||
encoding == ContentEncodingDeflate ||
|
||||
encoding == ContentEncodingZSTD
|
||||
}
|
||||
|
||||
// WrapHTTPResponse 包装http.Response对象
|
||||
func WrapHTTPResponse(resp *http.Response) {
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var contentEncoding = resp.Header.Get("Content-Encoding")
|
||||
if len(contentEncoding) == 0 || !SupportEncoding(contentEncoding) {
|
||||
return
|
||||
}
|
||||
|
||||
reader, err := NewReader(resp.Body, contentEncoding)
|
||||
if err != nil {
|
||||
// unable to decode, we ignore the error
|
||||
return
|
||||
}
|
||||
resp.Header.Del("Content-Encoding")
|
||||
resp.Header.Del("Content-Length")
|
||||
resp.Body = reader
|
||||
}
|
||||
|
||||
@@ -12,8 +12,8 @@ const oldConfigFileName = "api.yaml"
|
||||
|
||||
type APIConfig struct {
|
||||
OldRPC struct {
|
||||
Endpoints []string `yaml:"endpoints" json:"endpoints"`
|
||||
DisableUpdate bool `yaml:"disableUpdate" json:"disableUpdate"`
|
||||
Endpoints []string `yaml:"endpoints,omitempty" json:"endpoints"`
|
||||
DisableUpdate bool `yaml:"disableUpdate,omitempty" json:"disableUpdate"`
|
||||
} `yaml:"rpc,omitempty" json:"rpc"`
|
||||
|
||||
RPCEndpoints []string `yaml:"rpc.endpoints,flow" json:"rpc.endpoints"`
|
||||
|
||||
@@ -4,11 +4,16 @@ package configs_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"gopkg.in/yaml.v3"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadClusterConfig(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
config, err := configs.LoadClusterConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package teaconst
|
||||
|
||||
const (
|
||||
Version = "1.3.2"
|
||||
Version = "1.3.3"
|
||||
|
||||
ProductName = "Edge Node"
|
||||
ProcessName = "edge-node"
|
||||
|
||||
@@ -3,11 +3,16 @@ package iplibrary
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHTTPAPIAction_AddItem(t *testing.T) {
|
||||
action := NewHTTPAPIAction()
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var action = NewHTTPAPIAction()
|
||||
action.config = &firewallconfigs.FirewallActionHTTPAPIConfig{
|
||||
URL: "http://127.0.0.1:2345/post",
|
||||
TimeoutSeconds: 0,
|
||||
@@ -24,7 +29,11 @@ func TestHTTPAPIAction_AddItem(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHTTPAPIAction_DeleteItem(t *testing.T) {
|
||||
action := NewHTTPAPIAction()
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var action = NewHTTPAPIAction()
|
||||
action.config = &firewallconfigs.FirewallActionHTTPAPIConfig{
|
||||
URL: "http://127.0.0.1:2345/post",
|
||||
TimeoutSeconds: 0,
|
||||
|
||||
@@ -4,13 +4,19 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIPSetAction_Init(t *testing.T) {
|
||||
action := iplibrary.NewIPSetAction()
|
||||
_, lookupErr := executils.LookPath("iptables")
|
||||
if lookupErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var action = iplibrary.NewIPSetAction()
|
||||
err := action.Init(&firewallconfigs.FirewallActionConfig{
|
||||
Params: maps.Map{
|
||||
"path": "/usr/bin/iptables",
|
||||
@@ -25,6 +31,11 @@ func TestIPSetAction_Init(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIPSetAction_AddItem(t *testing.T) {
|
||||
_, lookupErr := executils.LookPath("iptables")
|
||||
if lookupErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var action = iplibrary.NewIPSetAction()
|
||||
action.SetConfig(&firewallconfigs.FirewallActionIPSetConfig{
|
||||
Path: "/usr/bin/iptables",
|
||||
@@ -84,7 +95,12 @@ func TestIPSetAction_AddItem(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIPSetAction_DeleteItem(t *testing.T) {
|
||||
action := iplibrary.NewIPSetAction()
|
||||
_, lookupErr := executils.LookPath("firewalld")
|
||||
if lookupErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var action = iplibrary.NewIPSetAction()
|
||||
err := action.Init(&firewallconfigs.FirewallActionConfig{
|
||||
Params: maps.Map{
|
||||
"path": "/usr/bin/firewalld",
|
||||
|
||||
@@ -3,12 +3,18 @@ package iplibrary
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIPTablesAction_AddItem(t *testing.T) {
|
||||
action := NewIPTablesAction()
|
||||
_, lookupErr := executils.LookPath("iptables")
|
||||
if lookupErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var action = NewIPTablesAction()
|
||||
action.config = &firewallconfigs.FirewallActionIPTablesConfig{
|
||||
Path: "/usr/bin/iptables",
|
||||
}
|
||||
@@ -40,7 +46,12 @@ func TestIPTablesAction_AddItem(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIPTablesAction_DeleteItem(t *testing.T) {
|
||||
action := NewIPTablesAction()
|
||||
_, lookupErr := executils.LookPath("firewalld")
|
||||
if lookupErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var action = NewIPTablesAction()
|
||||
action.config = &firewallconfigs.FirewallActionIPTablesConfig{
|
||||
Path: "/usr/bin/firewalld",
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
func TestActionManager_UpdateActions(t *testing.T) {
|
||||
manager := NewActionManager()
|
||||
var manager = NewActionManager()
|
||||
manager.UpdateActions([]*firewallconfigs.FirewallActionConfig{
|
||||
{
|
||||
Id: 1,
|
||||
|
||||
@@ -3,11 +3,16 @@ package iplibrary
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestScriptAction_AddItem(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
action := NewScriptAction()
|
||||
action.config = &firewallconfigs.FirewallActionScriptConfig{
|
||||
Path: "/tmp/ip-item.sh",
|
||||
@@ -27,6 +32,10 @@ func TestScriptAction_AddItem(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestScriptAction_DeleteItem(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
action := NewScriptAction()
|
||||
action.config = &firewallconfigs.FirewallActionScriptConfig{
|
||||
Path: "/tmp/ip-item.sh",
|
||||
|
||||
@@ -2,6 +2,7 @@ package iplibrary
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"runtime"
|
||||
"testing"
|
||||
@@ -75,8 +76,14 @@ func TestIPItem_Contains(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIPItem_Memory(t *testing.T) {
|
||||
var isSingleTest = testutils.IsSingleTesting()
|
||||
|
||||
var list = NewIPList()
|
||||
for i := 0; i < 2_000_000; i ++ {
|
||||
var count = 100
|
||||
if isSingleTest {
|
||||
count = 2_000_000
|
||||
}
|
||||
for i := 0; i < count; i++ {
|
||||
list.Add(&IPItem{
|
||||
Type: "ip",
|
||||
Id: uint64(i),
|
||||
@@ -87,7 +94,9 @@ func TestIPItem_Memory(t *testing.T) {
|
||||
})
|
||||
}
|
||||
t.Log("waiting")
|
||||
time.Sleep(10 * time.Second)
|
||||
if isSingleTest {
|
||||
time.Sleep(10 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIPItem_Contains(b *testing.B) {
|
||||
@@ -105,4 +114,3 @@ func BenchmarkIPItem_Contains(b *testing.B) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,9 @@ func TestIPListDB_AddItem(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
|
||||
err = db.AddItem(&pb.IPItem{
|
||||
Id: 1,
|
||||
@@ -60,6 +63,9 @@ func TestIPListDB_ReadItems(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
@@ -77,6 +83,9 @@ func TestIPListDB_ReadMaxVersion(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
t.Log(db.ReadMaxVersion())
|
||||
}
|
||||
|
||||
@@ -85,6 +94,10 @@ func TestIPListDB_UpdateMaxVersion(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
|
||||
err = db.UpdateMaxVersion(1027)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -3,12 +3,17 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIPIsAllowed(t *testing.T) {
|
||||
manager := NewIPListManager()
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var manager = NewIPListManager()
|
||||
manager.init()
|
||||
|
||||
var before = time.Now()
|
||||
|
||||
@@ -2,13 +2,18 @@ package iplibrary
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIPListManager_init(t *testing.T) {
|
||||
manager := NewIPListManager()
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var manager = NewIPListManager()
|
||||
manager.init()
|
||||
t.Log(manager.listMap)
|
||||
t.Log(SharedServerListManager.blackMap)
|
||||
@@ -16,7 +21,11 @@ func TestIPListManager_init(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIPListManager_check(t *testing.T) {
|
||||
manager := NewIPListManager()
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var manager = NewIPListManager()
|
||||
manager.init()
|
||||
|
||||
var before = time.Now()
|
||||
@@ -28,7 +37,11 @@ func TestIPListManager_check(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIPListManager_loop(t *testing.T) {
|
||||
manager := NewIPListManager()
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var manager = NewIPListManager()
|
||||
manager.Start()
|
||||
err := manager.loop()
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ package monitor
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"google.golang.org/grpc/status"
|
||||
@@ -12,6 +13,10 @@ import (
|
||||
)
|
||||
|
||||
func TestValueQueue_RPC(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
package nodes
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAPIStream_Start(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
apiStream := NewAPIStream()
|
||||
apiStream.Start()
|
||||
}
|
||||
|
||||
@@ -110,6 +110,10 @@ func TestHTTPAccessLogQueue_Push2(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHTTPAccessLogQueue_Memory(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
testutils.StartMemoryStats(t)
|
||||
|
||||
debug.SetGCPercent(10)
|
||||
|
||||
@@ -3,13 +3,14 @@ package nodes
|
||||
import (
|
||||
"context"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestHTTPClientPool_Client(t *testing.T) {
|
||||
pool := NewHTTPClientPool()
|
||||
var pool = NewHTTPClientPool()
|
||||
|
||||
{
|
||||
var origin = &serverconfigs.OriginConfig{
|
||||
@@ -54,7 +55,10 @@ func TestHTTPClientPool_cleanClients(t *testing.T) {
|
||||
for i := 0; i < 10; i++ {
|
||||
t.Log("get", i)
|
||||
_, _ = pool.Client(nil, origin, origin.Addr.PickAddress(), nil, false)
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
if testutils.IsSingleTesting() {
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -102,6 +102,8 @@ type HTTPRequest struct {
|
||||
disableLog bool // 是否在当前请求中关闭Log
|
||||
forceLog bool // 是否强制记录日志
|
||||
|
||||
isHijacked bool
|
||||
|
||||
// script相关操作
|
||||
isDone bool
|
||||
}
|
||||
@@ -193,17 +195,35 @@ func (this *HTTPRequest) Do() {
|
||||
}
|
||||
|
||||
// 套餐
|
||||
if this.ReqServer.UserPlan != nil && !this.ReqServer.UserPlan.IsAvailable() {
|
||||
this.doPlanExpires()
|
||||
this.doEnd()
|
||||
return
|
||||
if this.ReqServer.UserPlan != nil {
|
||||
if this.doPlanBefore() {
|
||||
this.doEnd()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 流量限制
|
||||
if this.ReqServer.TrafficLimitStatus != nil && this.ReqServer.TrafficLimitStatus.IsValid() {
|
||||
this.doTrafficLimit()
|
||||
this.doEnd()
|
||||
return
|
||||
if this.doTrafficLimit(this.ReqServer.TrafficLimitStatus) {
|
||||
this.doEnd()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// UAM
|
||||
var uamIsCalled = false
|
||||
if !this.isHealthCheck {
|
||||
if this.web.UAM == nil && this.ReqServer.UAM != nil && this.ReqServer.UAM.IsOn {
|
||||
this.web.UAM = this.ReqServer.UAM
|
||||
}
|
||||
|
||||
if this.web.UAM != nil && this.web.UAM.IsOn && this.isUAMRequest() {
|
||||
uamIsCalled = true
|
||||
if this.doUAM() {
|
||||
this.doEnd()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WAF
|
||||
@@ -215,16 +235,8 @@ func (this *HTTPRequest) Do() {
|
||||
}
|
||||
|
||||
// UAM
|
||||
if !this.isHealthCheck {
|
||||
if this.web.UAM != nil {
|
||||
if this.web.UAM.IsOn {
|
||||
if this.doUAM() {
|
||||
this.doEnd()
|
||||
return
|
||||
}
|
||||
}
|
||||
} else if this.ReqServer.UAM != nil && this.ReqServer.UAM.IsOn {
|
||||
this.web.UAM = this.ReqServer.UAM
|
||||
if !this.isHealthCheck && !uamIsCalled {
|
||||
if this.web.UAM != nil && this.web.UAM.IsOn {
|
||||
if this.doUAM() {
|
||||
this.doEnd()
|
||||
return
|
||||
@@ -280,6 +292,16 @@ func (this *HTTPRequest) Do() {
|
||||
if this.web.Compression != nil && this.web.Compression.IsOn && this.web.Compression.Level > 0 {
|
||||
this.writer.SetCompression(this.web.Compression)
|
||||
}
|
||||
|
||||
// HLS
|
||||
if this.web.HLS != nil &&
|
||||
this.web.HLS.Encrypting != nil &&
|
||||
this.web.HLS.Encrypting.IsOn {
|
||||
if this.processHLSBefore() {
|
||||
this.doEnd()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 开始调用
|
||||
@@ -631,6 +653,11 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
|
||||
this.web.CC = web.CC
|
||||
}
|
||||
|
||||
// HLS
|
||||
if web.HLS != nil && (web.HLS.IsPrior || isTop) {
|
||||
this.web.HLS = web.HLS
|
||||
}
|
||||
|
||||
// 重写规则
|
||||
if len(web.RewriteRefs) > 0 {
|
||||
for index, ref := range web.RewriteRefs {
|
||||
@@ -1430,11 +1457,25 @@ func (this *HTTPRequest) requestScheme() string {
|
||||
|
||||
// 请求的服务器地址中的端口
|
||||
func (this *HTTPRequest) requestServerPort() int {
|
||||
_, port, err := net.SplitHostPort(this.ServerAddr)
|
||||
if err == nil {
|
||||
return types.Int(port)
|
||||
if len(this.ServerAddr) > 0 {
|
||||
_, port, err := net.SplitHostPort(this.ServerAddr)
|
||||
if err == nil && len(port) > 0 {
|
||||
return types.Int(port)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
|
||||
var host = this.RawReq.Host
|
||||
if len(host) > 0 {
|
||||
_, port, err := net.SplitHostPort(host)
|
||||
if err == nil && len(port) > 0 {
|
||||
return types.Int(port)
|
||||
}
|
||||
}
|
||||
|
||||
if this.IsHTTP {
|
||||
return 80
|
||||
}
|
||||
return 443
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) Id() string {
|
||||
|
||||
@@ -43,7 +43,7 @@ func (this *HTTPRequest) doAuth() (shouldStop bool) {
|
||||
if uriChanged {
|
||||
this.uri = newURI
|
||||
}
|
||||
this.tags = append(this.tags, ref.AuthPolicy.Type)
|
||||
this.tags = append(this.tags, "auth:"+ref.AuthPolicy.Type)
|
||||
return
|
||||
} else {
|
||||
// Basic Auth比较特殊
|
||||
@@ -64,7 +64,7 @@ func (this *HTTPRequest) doAuth() (shouldStop bool) {
|
||||
}
|
||||
}
|
||||
this.writer.WriteHeader(http.StatusUnauthorized)
|
||||
this.tags = append(this.tags, ref.AuthPolicy.Type)
|
||||
this.tags = append(this.tags, "auth:"+ref.AuthPolicy.Type)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
16
internal/nodes/http_request_hls.go
Normal file
16
internal/nodes/http_request_hls.go
Normal file
@@ -0,0 +1,16 @@
|
||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build !plus
|
||||
|
||||
package nodes
|
||||
|
||||
import "net/http"
|
||||
|
||||
func (this *HTTPRequest) processHLSBefore() (blocked bool) {
|
||||
// stub
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) processM3u8Response(resp *http.Response) error {
|
||||
// stub
|
||||
return nil
|
||||
}
|
||||
@@ -34,60 +34,63 @@ func (this *HTTPRequest) doMismatch() {
|
||||
}
|
||||
|
||||
// 根据配置进行相应的处理
|
||||
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
|
||||
if globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly {
|
||||
var statusCode = 404
|
||||
var httpAllConfig = globalServerConfig.HTTPAll
|
||||
var mismatchAction = httpAllConfig.DomainMismatchAction
|
||||
var nodeConfig = sharedNodeConfig // copy
|
||||
if nodeConfig != nil {
|
||||
var globalServerConfig = nodeConfig.GlobalServerConfig
|
||||
if globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly {
|
||||
var statusCode = 404
|
||||
var httpAllConfig = globalServerConfig.HTTPAll
|
||||
var mismatchAction = httpAllConfig.DomainMismatchAction
|
||||
|
||||
if mismatchAction != nil && mismatchAction.Options != nil {
|
||||
var mismatchStatusCode = mismatchAction.Options.GetInt("statusCode")
|
||||
if mismatchStatusCode > 0 && mismatchStatusCode >= 100 && mismatchStatusCode < 1000 {
|
||||
statusCode = mismatchStatusCode
|
||||
}
|
||||
}
|
||||
|
||||
// 是否正在访问IP
|
||||
if globalServerConfig.HTTPAll.NodeIPShowPage && utils.IsWildIP(this.ReqHost) {
|
||||
this.writer.statusCode = statusCode
|
||||
var contentHTML = this.Format(globalServerConfig.HTTPAll.NodeIPPageHTML)
|
||||
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
this.writer.Header().Set("Content-Length", types.String(len(contentHTML)))
|
||||
this.writer.WriteHeader(statusCode)
|
||||
_, _ = this.writer.WriteString(contentHTML)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查cc
|
||||
// TODO 可以在管理端配置是否开启以及最多尝试次数
|
||||
// 要考虑到服务在切换集群时,域名未生效状态时,用户访问的仍然是老集群中的节点,就会产生找不到域名的情况
|
||||
if len(remoteIP) > 0 {
|
||||
const maxAttempts = 100
|
||||
if ttlcache.SharedInt64Cache.IncreaseInt64("MISMATCH_DOMAIN:"+remoteIP, int64(1), time.Now().Unix()+60, false) > maxAttempts {
|
||||
// 在加入之前再次检查黑名单
|
||||
if !waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteIP) {
|
||||
waf.SharedIPBlackList.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteIP, time.Now().Unix()+3600)
|
||||
if mismatchAction != nil && mismatchAction.Options != nil {
|
||||
var mismatchStatusCode = mismatchAction.Options.GetInt("statusCode")
|
||||
if mismatchStatusCode > 0 && mismatchStatusCode >= 100 && mismatchStatusCode < 1000 {
|
||||
statusCode = mismatchStatusCode
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理当前连接
|
||||
if mismatchAction != nil && mismatchAction.Code == serverconfigs.DomainMismatchActionPage {
|
||||
if mismatchAction.Options != nil {
|
||||
// 是否正在访问IP
|
||||
if globalServerConfig.HTTPAll.NodeIPShowPage && utils.IsWildIP(this.ReqHost) {
|
||||
this.writer.statusCode = statusCode
|
||||
var contentHTML = this.Format(mismatchAction.Options.GetString("contentHTML"))
|
||||
var contentHTML = this.Format(globalServerConfig.HTTPAll.NodeIPPageHTML)
|
||||
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
this.writer.Header().Set("Content-Length", types.String(len(contentHTML)))
|
||||
this.writer.WriteHeader(statusCode)
|
||||
_, _ = this.writer.Write([]byte(contentHTML))
|
||||
_, _ = this.writer.WriteString(contentHTML)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查cc
|
||||
// TODO 可以在管理端配置是否开启以及最多尝试次数
|
||||
// 要考虑到服务在切换集群时,域名未生效状态时,用户访问的仍然是老集群中的节点,就会产生找不到域名的情况
|
||||
if len(remoteIP) > 0 {
|
||||
const maxAttempts = 100
|
||||
if ttlcache.SharedInt64Cache.IncreaseInt64("MISMATCH_DOMAIN:"+remoteIP, int64(1), time.Now().Unix()+60, false) > maxAttempts {
|
||||
// 在加入之前再次检查黑名单
|
||||
if !waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteIP) {
|
||||
waf.SharedIPBlackList.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteIP, time.Now().Unix()+3600)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理当前连接
|
||||
if mismatchAction != nil && mismatchAction.Code == serverconfigs.DomainMismatchActionPage {
|
||||
if mismatchAction.Options != nil {
|
||||
this.writer.statusCode = statusCode
|
||||
var contentHTML = this.Format(mismatchAction.Options.GetString("contentHTML"))
|
||||
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
this.writer.Header().Set("Content-Length", types.String(len(contentHTML)))
|
||||
this.writer.WriteHeader(statusCode)
|
||||
_, _ = this.writer.Write([]byte(contentHTML))
|
||||
} else {
|
||||
http.Error(this.writer, "404 page not found: '"+this.URL()+"'", http.StatusNotFound)
|
||||
}
|
||||
return
|
||||
} else {
|
||||
http.Error(this.writer, "404 page not found: '"+this.URL()+"'", http.StatusNotFound)
|
||||
this.Close()
|
||||
return
|
||||
}
|
||||
return
|
||||
} else {
|
||||
http.Error(this.writer, "404 page not found: '"+this.URL()+"'", http.StatusNotFound)
|
||||
this.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -151,7 +151,7 @@ func (this *HTTPRequest) doPageLookup(pages []*serverconfigs.HTTPPageConfig, sta
|
||||
}
|
||||
return true
|
||||
} else if page.BodyType == serverconfigs.HTTPPageBodyTypeRedirectURL {
|
||||
var newURL = page.URL
|
||||
var newURL = this.Format(page.URL)
|
||||
if len(newURL) == 0 {
|
||||
newURL = "/"
|
||||
}
|
||||
|
||||
10
internal/nodes/http_request_plan_before.go
Normal file
10
internal/nodes/http_request_plan_before.go
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build !plus
|
||||
|
||||
package nodes
|
||||
|
||||
// 检查套餐
|
||||
func (this *HTTPRequest) doPlanBefore() (blocked bool) {
|
||||
// stub
|
||||
return false
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// 套餐过期
|
||||
func (this *HTTPRequest) doPlanExpires() {
|
||||
this.tags = append(this.tags, "plan")
|
||||
|
||||
var statusCode = http.StatusNotFound
|
||||
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
|
||||
|
||||
this.writer.WriteHeader(statusCode)
|
||||
_, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultPlanExpireNoticePageBody))
|
||||
}
|
||||
@@ -381,11 +381,20 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
|
||||
return
|
||||
}
|
||||
|
||||
if resp == nil {
|
||||
this.write50x(requestErr, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
|
||||
return
|
||||
}
|
||||
|
||||
// fix Content-Type
|
||||
if resp.Header["Content-Type"] == nil {
|
||||
resp.Header["Content-Type"] = []string{}
|
||||
}
|
||||
|
||||
// 40x && 50x
|
||||
*failStatusCode = resp.StatusCode
|
||||
if resp != nil &&
|
||||
((resp.StatusCode >= 500 && resp.StatusCode < 510 && this.reverseProxy.Retry50X) ||
|
||||
(resp.StatusCode >= 403 && resp.StatusCode <= 404 && this.reverseProxy.Retry40X)) &&
|
||||
if ((resp.StatusCode >= 500 && resp.StatusCode < 510 && this.reverseProxy.Retry50X) ||
|
||||
(resp.StatusCode >= 403 && resp.StatusCode <= 404 && this.reverseProxy.Retry40X)) &&
|
||||
(originId > 0 || (lnNodeId > 0 && hasMultipleLnNodes)) &&
|
||||
!isLastRetry {
|
||||
if resp.Body != nil {
|
||||
@@ -397,8 +406,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
|
||||
}
|
||||
|
||||
// 尝试从缓存中恢复
|
||||
if resp != nil &&
|
||||
resp.StatusCode >= 500 && // support 50X only
|
||||
if resp.StatusCode >= 500 && // support 50X only
|
||||
resp.StatusCode < 510 &&
|
||||
this.cacheCanTryStale &&
|
||||
this.web.Cache.Stale != nil &&
|
||||
@@ -436,19 +444,41 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
|
||||
if this.web.Optimization != nil && resp.Body != nil && this.cacheRef != nil /** must under cache **/ {
|
||||
err := this.web.Optimization.FilterResponse(this.URL(), resp)
|
||||
if err != nil {
|
||||
this.write50x(err, http.StatusBadGateway, "Page Optimization: Fail to read content from origin", "内容优化:从源站读取内容失败", false)
|
||||
this.write50x(err, http.StatusBadGateway, "Page Optimization: fail to read content from origin", "内容优化:从源站读取内容失败", false)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// HLS
|
||||
if this.web.HLS != nil &&
|
||||
this.web.HLS.Encrypting != nil &&
|
||||
this.web.HLS.Encrypting.IsOn &&
|
||||
resp.StatusCode == http.StatusOK {
|
||||
m3u8Err := this.processM3u8Response(resp)
|
||||
if m3u8Err != nil {
|
||||
this.write50x(m3u8Err, http.StatusBadGateway, "m3u8 encrypt: fail to read content from origin", "m3u8加密:从源站读取内容失败", false)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 设置Charset
|
||||
// TODO 这里应该可以设置文本类型的列表,以及是否强制覆盖所有文本类型的字符集
|
||||
// TODO 这里应该可以设置文本类型的列表
|
||||
if this.web.Charset != nil && this.web.Charset.IsOn && len(this.web.Charset.Charset) > 0 {
|
||||
contentTypes, ok := resp.Header["Content-Type"]
|
||||
if ok && len(contentTypes) > 0 {
|
||||
var contentType = contentTypes[0]
|
||||
if this.web.Charset.Force {
|
||||
var semiIndex = strings.Index(contentType, ";")
|
||||
if semiIndex > 0 {
|
||||
contentType = contentType[:semiIndex]
|
||||
}
|
||||
}
|
||||
if _, found := textMimeMap[contentType]; found {
|
||||
resp.Header["Content-Type"][0] = contentType + "; charset=" + this.web.Charset.Charset
|
||||
var newCharset = this.web.Charset.Charset
|
||||
if this.web.Charset.IsUpper {
|
||||
newCharset = strings.ToUpper(newCharset)
|
||||
}
|
||||
resp.Header["Content-Type"][0] = contentType + "; charset=" + newCharset
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package nodes
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
@@ -10,26 +11,45 @@ import (
|
||||
func TestHTTPRequest_RedirectToHTTPS(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
{
|
||||
req := &HTTPRequest{
|
||||
rawReq, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var req = &HTTPRequest{
|
||||
RawReq: rawReq,
|
||||
RawWriter: NewEmptyResponseWriter(nil),
|
||||
ReqServer: &serverconfigs.ServerConfig{
|
||||
IsOn: true,
|
||||
Web: &serverconfigs.HTTPWebConfig{
|
||||
IsOn: true,
|
||||
RedirectToHttps: &serverconfigs.HTTPRedirectToHTTPSConfig{},
|
||||
},
|
||||
},
|
||||
}
|
||||
req.init()
|
||||
req.Do()
|
||||
|
||||
a.IsBool(req.web.RedirectToHttps.IsOn == false)
|
||||
}
|
||||
{
|
||||
req := &HTTPRequest{
|
||||
rawReq, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var req = &HTTPRequest{
|
||||
RawReq: rawReq,
|
||||
RawWriter: NewEmptyResponseWriter(nil),
|
||||
ReqServer: &serverconfigs.ServerConfig{
|
||||
IsOn: true,
|
||||
Web: &serverconfigs.HTTPWebConfig{
|
||||
IsOn: true,
|
||||
RedirectToHttps: &serverconfigs.HTTPRedirectToHTTPSConfig{
|
||||
IsOn: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
req.init()
|
||||
req.Do()
|
||||
a.IsBool(req.web.RedirectToHttps.IsOn == true)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,19 @@ import (
|
||||
)
|
||||
|
||||
// 流量限制
|
||||
func (this *HTTPRequest) doTrafficLimit() {
|
||||
func (this *HTTPRequest) doTrafficLimit(status *serverconfigs.TrafficLimitStatus) (blocked bool) {
|
||||
if status == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果是网站单独设置的流量限制,则检查是否已关闭
|
||||
var config = this.ReqServer.TrafficLimit
|
||||
if (config == nil || !config.IsOn) && status.PlanId == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果是套餐设置的流量限制,即使套餐变更了(变更套餐或者变更套餐的限制),仍然会提示流量超限
|
||||
|
||||
this.tags = append(this.tags, "trafficLimit")
|
||||
|
||||
var statusCode = 509
|
||||
@@ -17,10 +29,19 @@ func (this *HTTPRequest) doTrafficLimit() {
|
||||
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
this.writer.WriteHeader(statusCode)
|
||||
|
||||
var config = this.ReqServer.TrafficLimit
|
||||
// check plan traffic limit
|
||||
if (config == nil || !config.IsOn) && this.ReqServer.PlanId() > 0 && this.nodeConfig != nil {
|
||||
var planConfig = this.nodeConfig.FindPlan(this.ReqServer.PlanId())
|
||||
if planConfig != nil && planConfig.TrafficLimit != nil && planConfig.TrafficLimit.IsOn {
|
||||
config = planConfig.TrafficLimit
|
||||
}
|
||||
}
|
||||
|
||||
if config != nil && len(config.NoticePageBody) != 0 {
|
||||
_, _ = this.writer.WriteString(this.Format(config.NoticePageBody))
|
||||
} else {
|
||||
_, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultTrafficLimitNoticePageBody))
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -4,7 +4,13 @@
|
||||
|
||||
package nodes
|
||||
|
||||
// UAM
|
||||
func (this *HTTPRequest) doUAM() (block bool) {
|
||||
func (this *HTTPRequest) isUAMRequest() bool {
|
||||
// stub
|
||||
return false
|
||||
}
|
||||
|
||||
// UAM
|
||||
func (this *HTTPRequest) doUAM() (block bool) {
|
||||
// stub
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -67,8 +67,8 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
|
||||
|
||||
// 当前服务的独立设置
|
||||
if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
|
||||
blocked, breakChecking := this.checkWAFRequest(this.web.FirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, false)
|
||||
if blocked {
|
||||
blockedRequest, breakChecking := this.checkWAFRequest(this.web.FirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, false)
|
||||
if blockedRequest {
|
||||
return true
|
||||
}
|
||||
if breakChecking {
|
||||
@@ -78,8 +78,8 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
|
||||
|
||||
// 公用的防火墙设置
|
||||
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
|
||||
blocked, breakChecking := this.checkWAFRequest(this.ReqServer.HTTPFirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, this.web.FirewallRef.IgnoreGlobalRules)
|
||||
if blocked {
|
||||
blockedRequest, breakChecking := this.checkWAFRequest(this.ReqServer.HTTPFirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, this.web.FirewallRef.IgnoreGlobalRules)
|
||||
if blockedRequest {
|
||||
return true
|
||||
}
|
||||
if breakChecking {
|
||||
@@ -266,8 +266,11 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
|
||||
return
|
||||
}
|
||||
|
||||
goNext, hasRequestBody, ruleGroup, ruleSet, err := w.MatchRequest(this, this.writer, this.web.FirewallRef.DefaultCaptchaType)
|
||||
if forceLog && logRequestBody && hasRequestBody && ruleSet != nil && ruleSet.HasAttackActions() {
|
||||
result, err := w.MatchRequest(this, this.writer, this.web.FirewallRef.DefaultCaptchaType)
|
||||
if result.IsAllowed && (len(result.AllowScope) == 0 || result.AllowScope == waf.AllowScopeGlobal) {
|
||||
breakChecking = true
|
||||
}
|
||||
if forceLog && logRequestBody && result.HasRequestBody && result.Set != nil && result.Set.HasAttackActions() {
|
||||
this.wafHasRequestBody = true
|
||||
}
|
||||
if err != nil {
|
||||
@@ -277,28 +280,28 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
|
||||
return
|
||||
}
|
||||
|
||||
if ruleSet != nil {
|
||||
if result.Set != nil {
|
||||
if forceLog {
|
||||
this.forceLog = true
|
||||
}
|
||||
|
||||
if ruleSet.HasSpecialActions() {
|
||||
if result.Set.HasSpecialActions() {
|
||||
this.firewallPolicyId = firewallPolicy.Id
|
||||
this.firewallRuleGroupId = types.Int64(ruleGroup.Id)
|
||||
this.firewallRuleSetId = types.Int64(ruleSet.Id)
|
||||
this.firewallRuleGroupId = types.Int64(result.Group.Id)
|
||||
this.firewallRuleSetId = types.Int64(result.Set.Id)
|
||||
|
||||
if ruleSet.HasAttackActions() {
|
||||
if result.Set.HasAttackActions() {
|
||||
this.isAttack = true
|
||||
}
|
||||
|
||||
// 添加统计
|
||||
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, ruleSet.Actions)
|
||||
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, result.Set.Actions)
|
||||
}
|
||||
|
||||
this.firewallActions = append(ruleSet.ActionCodes(), firewallPolicy.Mode)
|
||||
this.firewallActions = append(result.Set.ActionCodes(), firewallPolicy.Mode)
|
||||
}
|
||||
|
||||
return !goNext, false
|
||||
return !result.GoNext, breakChecking
|
||||
}
|
||||
|
||||
// call response waf
|
||||
@@ -316,23 +319,26 @@ func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) {
|
||||
}
|
||||
|
||||
if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
|
||||
blocked = this.checkWAFResponse(this.web.FirewallPolicy, resp, forceLog, forceLogRequestBody, false)
|
||||
if blocked {
|
||||
blockedRequest, breakChecking := this.checkWAFResponse(this.web.FirewallPolicy, resp, forceLog, forceLogRequestBody, false)
|
||||
if blockedRequest {
|
||||
return true
|
||||
}
|
||||
if breakChecking {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 公用的防火墙设置
|
||||
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
|
||||
blocked = this.checkWAFResponse(this.ReqServer.HTTPFirewallPolicy, resp, forceLog, forceLogRequestBody, this.web.FirewallRef.IgnoreGlobalRules)
|
||||
if blocked {
|
||||
blockedRequest, _ := this.checkWAFResponse(this.ReqServer.HTTPFirewallPolicy, resp, forceLog, forceLogRequestBody, this.web.FirewallRef.IgnoreGlobalRules)
|
||||
if blockedRequest {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFirewallPolicy, resp *http.Response, forceLog bool, logRequestBody bool, ignoreRules bool) (blocked bool) {
|
||||
func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFirewallPolicy, resp *http.Response, forceLog bool, logRequestBody bool, ignoreRules bool) (blocked bool, breakChecking bool) {
|
||||
if firewallPolicy == nil || !firewallPolicy.IsOn || !firewallPolicy.Outbound.IsOn || firewallPolicy.Mode == firewallconfigs.FirewallModeBypass {
|
||||
return
|
||||
}
|
||||
@@ -347,8 +353,11 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi
|
||||
return
|
||||
}
|
||||
|
||||
goNext, hasRequestBody, ruleGroup, ruleSet, err := w.MatchResponse(this, resp, this.writer)
|
||||
if forceLog && logRequestBody && hasRequestBody && ruleSet != nil && ruleSet.HasAttackActions() {
|
||||
result, err := w.MatchResponse(this, resp, this.writer)
|
||||
if result.IsAllowed && (len(result.AllowScope) == 0 || result.AllowScope == waf.AllowScopeGlobal) {
|
||||
breakChecking = true
|
||||
}
|
||||
if forceLog && logRequestBody && result.HasRequestBody && result.Set != nil && result.Set.HasAttackActions() {
|
||||
this.wafHasRequestBody = true
|
||||
}
|
||||
if err != nil {
|
||||
@@ -358,28 +367,28 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi
|
||||
return
|
||||
}
|
||||
|
||||
if ruleSet != nil {
|
||||
if result.Set != nil {
|
||||
if forceLog {
|
||||
this.forceLog = true
|
||||
}
|
||||
|
||||
if ruleSet.HasSpecialActions() {
|
||||
if result.Set.HasSpecialActions() {
|
||||
this.firewallPolicyId = firewallPolicy.Id
|
||||
this.firewallRuleGroupId = types.Int64(ruleGroup.Id)
|
||||
this.firewallRuleSetId = types.Int64(ruleSet.Id)
|
||||
this.firewallRuleGroupId = types.Int64(result.Group.Id)
|
||||
this.firewallRuleSetId = types.Int64(result.Set.Id)
|
||||
|
||||
if ruleSet.HasAttackActions() {
|
||||
if result.Set.HasAttackActions() {
|
||||
this.isAttack = true
|
||||
}
|
||||
|
||||
// 添加统计
|
||||
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, ruleSet.Actions)
|
||||
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, result.Set.Actions)
|
||||
}
|
||||
|
||||
this.firewallActions = append(ruleSet.ActionCodes(), firewallPolicy.Mode)
|
||||
this.firewallActions = append(result.Set.ActionCodes(), firewallPolicy.Mode)
|
||||
}
|
||||
|
||||
return !goNext
|
||||
return !result.GoNext, breakChecking
|
||||
}
|
||||
|
||||
// WAFRaw 原始请求
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
setutils "github.com/TeaOSLab/EdgeNode/internal/utils/sets"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/writers"
|
||||
_ "github.com/biessek/golang-ico"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/iwind/gowebp"
|
||||
_ "golang.org/x/image/bmp"
|
||||
@@ -340,7 +339,7 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
|
||||
|
||||
cacheWriter, err := storage.OpenWriter(cacheKey, expiresAt, this.StatusCode(), this.calculateHeaderLength(), totalSize, cacheRef.MaxSizeBytes(), this.isPartial)
|
||||
if err != nil {
|
||||
if err == caches.ErrEntityTooLarge && addStatusHeader {
|
||||
if errors.Is(err, caches.ErrEntityTooLarge) && addStatusHeader {
|
||||
this.Header().Set("X-Cache", "BYPASS, entity too large")
|
||||
}
|
||||
|
||||
@@ -494,7 +493,7 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
|
||||
return
|
||||
}
|
||||
|
||||
var cacheReader = readers.NewTeeReaderCloser(resp.Body, this.cacheWriter)
|
||||
var cacheReader = readers.NewTeeReaderCloser(resp.Body, this.cacheWriter, false)
|
||||
resp.Body = cacheReader
|
||||
this.rawReader = cacheReader
|
||||
|
||||
@@ -564,18 +563,18 @@ func (this *HTTPWriter) PrepareWebP(resp *http.Response, size int64) {
|
||||
}
|
||||
|
||||
var contentEncoding = this.GetHeader("Content-Encoding")
|
||||
switch contentEncoding {
|
||||
case "gzip", "deflate", "br", "zstd":
|
||||
reader, err := compressions.NewReader(resp.Body, contentEncoding)
|
||||
if err != nil {
|
||||
if len(contentEncoding) > 0 {
|
||||
if compressions.SupportEncoding(contentEncoding) {
|
||||
reader, err := compressions.NewReader(resp.Body, contentEncoding)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
this.Header().Del("Content-Encoding")
|
||||
this.Header().Del("Content-Length")
|
||||
this.rawReader = reader
|
||||
} else {
|
||||
return
|
||||
}
|
||||
this.Header().Del("Content-Encoding")
|
||||
this.Header().Del("Content-Length")
|
||||
this.rawReader = reader
|
||||
case "": // 空
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
this.webpOriginContentType = contentType
|
||||
@@ -603,7 +602,7 @@ func (this *HTTPWriter) PrepareCompression(resp *http.Response, size int64) {
|
||||
var contentEncoding = this.GetHeader("Content-Encoding")
|
||||
|
||||
if this.compressionConfig == nil || !this.compressionConfig.IsOn {
|
||||
if lists.ContainsString([]string{"gzip", "deflate", "br", "zstd"}, contentEncoding) && !httpAcceptEncoding(acceptEncodings, contentEncoding) {
|
||||
if compressions.SupportEncoding(contentEncoding) && !httpAcceptEncoding(acceptEncodings, contentEncoding) {
|
||||
reader, err := compressions.NewReader(resp.Body, contentEncoding)
|
||||
if err != nil {
|
||||
return
|
||||
@@ -625,7 +624,7 @@ func (this *HTTPWriter) PrepareCompression(resp *http.Response, size int64) {
|
||||
}
|
||||
|
||||
// 如果已经有编码则不处理
|
||||
if len(contentEncoding) > 0 && (!this.compressionConfig.DecompressData || !lists.ContainsString([]string{"gzip", "deflate", "br", "zstd"}, contentEncoding)) {
|
||||
if len(contentEncoding) > 0 && (!this.compressionConfig.DecompressData || !compressions.SupportEncoding(contentEncoding)) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -807,6 +806,8 @@ func (this *HTTPWriter) Write(data []byte) (n int, err error) {
|
||||
}
|
||||
n, err = this.writer.Write(data)
|
||||
|
||||
this.checkPlanBandwidth(n)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -967,6 +968,7 @@ func (this *HTTPWriter) Close() {
|
||||
func (this *HTTPWriter) Hijack() (conn net.Conn, buf *bufio.ReadWriter, err error) {
|
||||
hijack, ok := this.rawWriter.(http.Hijacker)
|
||||
if ok {
|
||||
this.req.isHijacked = true
|
||||
return hijack.Hijack()
|
||||
}
|
||||
return
|
||||
|
||||
@@ -11,3 +11,7 @@ import (
|
||||
func (this *HTTPWriter) canSendfile() (*os.File, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) checkPlanBandwidth(n int) {
|
||||
// stub
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package nodes
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"io"
|
||||
@@ -52,6 +53,8 @@ func (this *HTTPListener) Serve() error {
|
||||
atomic.AddInt64(&this.countActiveConnections, 1)
|
||||
case http.StateClosed:
|
||||
atomic.AddInt64(&this.countActiveConnections, -1)
|
||||
default:
|
||||
// do nothing
|
||||
}
|
||||
},
|
||||
ConnContext: func(ctx context.Context, conn net.Conn) context.Context {
|
||||
@@ -74,7 +77,7 @@ func (this *HTTPListener) Serve() error {
|
||||
// HTTP协议
|
||||
if this.isHTTP {
|
||||
err := this.httpServer.Serve(this.Listener)
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -84,7 +87,7 @@ func (this *HTTPListener) Serve() error {
|
||||
this.httpServer.TLSConfig = this.buildTLSConfig()
|
||||
|
||||
err := this.httpServer.ServeTLS(this.Listener, "", "")
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -105,8 +108,13 @@ func (this *HTTPListener) Reload(group *serverconfigs.ServerAddressGroup) {
|
||||
this.Reset()
|
||||
}
|
||||
|
||||
// ServerHTTP 处理HTTP请求
|
||||
// ServeHTTPWithAddr 处理HTTP请求
|
||||
func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.Request) {
|
||||
this.ServeHTTPWithAddr(rawWriter, rawReq, this.addr)
|
||||
}
|
||||
|
||||
// ServeHTTPWithAddr 处理HTTP请求并指定服务地址
|
||||
func (this *HTTPListener) ServeHTTPWithAddr(rawWriter http.ResponseWriter, rawReq *http.Request, serverAddr string) {
|
||||
if len(rawReq.Host) > 253 {
|
||||
http.Error(rawWriter, "Host too long.", http.StatusBadRequest)
|
||||
time.Sleep(1 * time.Second) // make connection slow down
|
||||
@@ -175,10 +183,12 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
|
||||
}
|
||||
|
||||
// 绑定连接
|
||||
var clientConn ClientConnInterface
|
||||
if server != nil && server.Id > 0 {
|
||||
var requestConn = rawReq.Context().Value(HTTPConnContextKey)
|
||||
if requestConn != nil {
|
||||
clientConn, ok := requestConn.(ClientConnInterface)
|
||||
var ok bool
|
||||
clientConn, ok = requestConn.(ClientConnInterface)
|
||||
if ok {
|
||||
var goNext = clientConn.SetServerId(server.Id)
|
||||
if !goNext {
|
||||
@@ -211,7 +221,7 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
|
||||
ReqServer: server,
|
||||
ReqHost: reqHost,
|
||||
ServerName: serverName,
|
||||
ServerAddr: this.addr,
|
||||
ServerAddr: serverAddr,
|
||||
IsHTTP: this.isHTTP,
|
||||
IsHTTPS: this.isHTTPS,
|
||||
IsHTTP3: this.isHTTP3,
|
||||
@@ -219,6 +229,14 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
|
||||
nodeConfig: sharedNodeConfig,
|
||||
}
|
||||
req.Do()
|
||||
|
||||
// fix hijacked connection state
|
||||
if req.isHijacked && clientConn != nil && this.httpServer.ConnState != nil {
|
||||
netConn, ok := clientConn.(net.Conn)
|
||||
if ok {
|
||||
this.httpServer.ConnState(netConn, http.StateClosed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查host是否为IP
|
||||
|
||||
@@ -784,9 +784,13 @@ func (this *Node) listenSock() error {
|
||||
var costSeconds = time.Since(before).Seconds()
|
||||
var gcStats = &debug.GCStats{}
|
||||
debug.ReadGCStats(gcStats)
|
||||
var pauseMS float64
|
||||
if len(gcStats.Pause) > 0 {
|
||||
pauseMS = gcStats.Pause[0].Seconds() * 1000
|
||||
}
|
||||
_ = cmd.Reply(&gosock.Command{
|
||||
Params: map[string]any{
|
||||
"pauseMS": gcStats.PauseTotal.Seconds() * 1000,
|
||||
"pauseMS": pauseMS,
|
||||
"costMS": costSeconds * 1000,
|
||||
},
|
||||
})
|
||||
|
||||
@@ -102,6 +102,8 @@ func (this *Node) execTask(rpcClient *rpc.RPCClient, task *pb.NodeTask) error {
|
||||
err = this.execNetworkSecurityPolicyChangedTask(rpcClient)
|
||||
case "webPPolicyChanged":
|
||||
err = this.execWebPPolicyChangedTask(rpcClient)
|
||||
case "planChanged":
|
||||
err = this.execPlanChangedTask(rpcClient)
|
||||
default:
|
||||
// 特殊任务
|
||||
if strings.HasPrefix(task.Type, "ipListDeleted") { // 删除IP名单
|
||||
|
||||
@@ -34,3 +34,7 @@ func (this *Node) execNetworkSecurityPolicyChangedTask(rpcClient *rpc.RPCClient)
|
||||
// stub
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Node) execPlanChangedTask(rpcClient *rpc.RPCClient) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,17 +1,26 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNode_Start(t *testing.T) {
|
||||
node := NewNode()
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var node = NewNode()
|
||||
node.Start()
|
||||
}
|
||||
|
||||
func TestNode_Test(t *testing.T) {
|
||||
node := NewNode()
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var node = NewNode()
|
||||
err := node.Test()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -3,11 +3,16 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUpgradeManager_install(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
err := NewUpgradeManager().install()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -55,6 +55,7 @@ type RPCClient struct {
|
||||
ClientAgentIPRPC pb.ClientAgentIPServiceClient
|
||||
AuthorityKeyRPC pb.AuthorityKeyServiceClient
|
||||
UpdatingServerListRPC pb.UpdatingServerListServiceClient
|
||||
PlanRPC pb.PlanServiceClient
|
||||
}
|
||||
|
||||
func NewRPCClient(apiConfig *configs.APIConfig) (*RPCClient, error) {
|
||||
@@ -91,6 +92,7 @@ func NewRPCClient(apiConfig *configs.APIConfig) (*RPCClient, error) {
|
||||
client.ClientAgentIPRPC = pb.NewClientAgentIPServiceClient(client)
|
||||
client.AuthorityKeyRPC = pb.NewAuthorityKeyServiceClient(client)
|
||||
client.UpdatingServerListRPC = pb.NewUpdatingServerListServiceClient(client)
|
||||
client.PlanRPC = pb.NewPlanServiceClient(client)
|
||||
|
||||
err := client.init()
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ package rpc_test
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
timeutil "github.com/iwind/TeaGo/utils/time"
|
||||
"sync"
|
||||
@@ -13,6 +14,10 @@ import (
|
||||
)
|
||||
|
||||
func TestRPCConcurrentCall(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -43,6 +48,10 @@ func TestRPCConcurrentCall(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRPC_Retry(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -3,38 +3,51 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fnv"
|
||||
syncutils "github.com/TeaOSLab/EdgeNode/internal/utils/sync"
|
||||
"github.com/mssola/useragent"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var SharedUserAgentParser = NewUserAgentParser()
|
||||
|
||||
const userAgentShardingCount = 8
|
||||
|
||||
// UserAgentParser UserAgent解析器
|
||||
type UserAgentParser struct {
|
||||
parser *useragent.UserAgent
|
||||
cacheMaps [userAgentShardingCount]map[uint64]UserAgentParserResult
|
||||
pool *sync.Pool
|
||||
mu *syncutils.RWMutex
|
||||
|
||||
cacheMap1 map[uint64]UserAgentParserResult
|
||||
cacheMap2 map[uint64]UserAgentParserResult
|
||||
maxCacheItems int
|
||||
|
||||
cacheCursor int
|
||||
locker sync.RWMutex
|
||||
gcTicker *time.Ticker
|
||||
gcIndex int
|
||||
}
|
||||
|
||||
// NewUserAgentParser 获取新解析器
|
||||
func NewUserAgentParser() *UserAgentParser {
|
||||
var parser = &UserAgentParser{
|
||||
parser: &useragent.UserAgent{},
|
||||
cacheMap1: map[uint64]UserAgentParserResult{},
|
||||
cacheMap2: map[uint64]UserAgentParserResult{},
|
||||
cacheCursor: 0,
|
||||
pool: &sync.Pool{
|
||||
New: func() any {
|
||||
return &useragent.UserAgent{}
|
||||
},
|
||||
},
|
||||
cacheMaps: [userAgentShardingCount]map[uint64]UserAgentParserResult{},
|
||||
mu: syncutils.NewRWMutex(userAgentShardingCount),
|
||||
}
|
||||
|
||||
for i := 0; i < userAgentShardingCount; i++ {
|
||||
parser.cacheMaps[i] = map[uint64]UserAgentParserResult{}
|
||||
}
|
||||
|
||||
parser.init()
|
||||
return parser
|
||||
}
|
||||
|
||||
// 初始化
|
||||
func (this *UserAgentParser) init() {
|
||||
var maxCacheItems = 10_000
|
||||
var systemMemory = utils.SystemMemoryGB()
|
||||
@@ -46,8 +59,16 @@ func (this *UserAgentParser) init() {
|
||||
maxCacheItems = 20_000
|
||||
}
|
||||
this.maxCacheItems = maxCacheItems
|
||||
|
||||
this.gcTicker = time.NewTicker(5 * time.Second)
|
||||
goman.New(func() {
|
||||
for range this.gcTicker.C {
|
||||
this.GC()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Parse 解析UserAgent
|
||||
func (this *UserAgentParser) Parse(userAgent string) (result UserAgentParserResult) {
|
||||
// 限制长度
|
||||
if len(userAgent) == 0 || len(userAgent) > 256 {
|
||||
@@ -55,28 +76,22 @@ func (this *UserAgentParser) Parse(userAgent string) (result UserAgentParserResu
|
||||
}
|
||||
|
||||
var userAgentKey = fnv.HashString(userAgent)
|
||||
var shardingIndex = int(userAgentKey % userAgentShardingCount)
|
||||
|
||||
this.locker.RLock()
|
||||
cacheResult, ok := this.cacheMap1[userAgentKey]
|
||||
this.mu.RLock(shardingIndex)
|
||||
cacheResult, ok := this.cacheMaps[shardingIndex][userAgentKey]
|
||||
if ok {
|
||||
this.locker.RUnlock()
|
||||
this.mu.RUnlock(shardingIndex)
|
||||
return cacheResult
|
||||
}
|
||||
this.mu.RUnlock(shardingIndex)
|
||||
|
||||
cacheResult, ok = this.cacheMap2[userAgentKey]
|
||||
if ok {
|
||||
this.locker.RUnlock()
|
||||
return cacheResult
|
||||
}
|
||||
this.locker.RUnlock()
|
||||
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
this.parser.Parse(userAgent)
|
||||
result.OS = this.parser.OSInfo()
|
||||
result.BrowserName, result.BrowserVersion = this.parser.Browser()
|
||||
result.IsMobile = this.parser.Mobile()
|
||||
var parser = this.pool.Get().(*useragent.UserAgent)
|
||||
parser.Parse(userAgent)
|
||||
result.OS = parser.OSInfo()
|
||||
result.BrowserName, result.BrowserVersion = parser.Browser()
|
||||
result.IsMobile = parser.Mobile()
|
||||
this.pool.Put(parser)
|
||||
|
||||
// 忽略特殊字符
|
||||
if len(result.BrowserName) > 0 {
|
||||
@@ -87,19 +102,45 @@ func (this *UserAgentParser) Parse(userAgent string) (result UserAgentParserResu
|
||||
}
|
||||
}
|
||||
|
||||
if this.cacheCursor == 0 {
|
||||
this.cacheMap1[userAgentKey] = result
|
||||
if len(this.cacheMap1) >= this.maxCacheItems {
|
||||
this.cacheCursor = 1
|
||||
this.cacheMap2 = map[uint64]UserAgentParserResult{}
|
||||
}
|
||||
} else {
|
||||
this.cacheMap2[userAgentKey] = result
|
||||
if len(this.cacheMap2) >= this.maxCacheItems {
|
||||
this.cacheCursor = 0
|
||||
this.cacheMap1 = map[uint64]UserAgentParserResult{}
|
||||
}
|
||||
}
|
||||
this.mu.Lock(shardingIndex)
|
||||
this.cacheMaps[shardingIndex][userAgentKey] = result
|
||||
this.mu.Unlock(shardingIndex)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// MaxCacheItems 读取能容纳的缓存最大数量
|
||||
func (this *UserAgentParser) MaxCacheItems() int {
|
||||
return this.maxCacheItems
|
||||
}
|
||||
|
||||
// Len 读取当前缓存数量
|
||||
func (this *UserAgentParser) Len() int {
|
||||
var total = 0
|
||||
for i := 0; i < userAgentShardingCount; i++ {
|
||||
this.mu.RLock(i)
|
||||
total += len(this.cacheMaps[i])
|
||||
this.mu.RUnlock(i)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// GC 回收多余的缓存
|
||||
func (this *UserAgentParser) GC() {
|
||||
var total = this.Len()
|
||||
if total > this.maxCacheItems {
|
||||
for {
|
||||
var shardingIndex = this.gcIndex
|
||||
|
||||
this.mu.Lock(shardingIndex)
|
||||
total -= len(this.cacheMaps[shardingIndex])
|
||||
this.cacheMaps[shardingIndex] = map[uint64]UserAgentParserResult{}
|
||||
this.gcIndex = (this.gcIndex + 1) % userAgentShardingCount
|
||||
this.mu.Unlock(shardingIndex)
|
||||
|
||||
if total <= this.maxCacheItems {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package stats
|
||||
package stats_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestUserAgentParser_Parse(t *testing.T) {
|
||||
var parser = NewUserAgentParser()
|
||||
var parser = stats.NewUserAgentParser()
|
||||
for i := 0; i < 4; i++ {
|
||||
t.Log(parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/1"))
|
||||
t.Log(parser.Parse("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36"))
|
||||
@@ -19,7 +23,7 @@ func TestUserAgentParser_Parse(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUserAgentParser_Parse_Unknown(t *testing.T) {
|
||||
var parser = NewUserAgentParser()
|
||||
var parser = stats.NewUserAgentParser()
|
||||
t.Log(parser.Parse("Mozilla/5.0 (Wind 10.0; WOW64; rv:49.0) Apple/537.36 (KHTML, like Gecko) Chr/88.0.4324.96 Sa/537.36 Test/1"))
|
||||
t.Log(parser.Parse(""))
|
||||
}
|
||||
@@ -28,10 +32,10 @@ func TestUserAgentParser_Memory(t *testing.T) {
|
||||
var stat1 = &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stat1)
|
||||
|
||||
var parser = NewUserAgentParser()
|
||||
var parser = stats.NewUserAgentParser()
|
||||
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/" + types.String(rands.Int(0, 100_000)))
|
||||
parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/" + types.String(rands.Int(0, 1_000_000)))
|
||||
}
|
||||
|
||||
runtime.GC()
|
||||
@@ -40,32 +44,76 @@ func TestUserAgentParser_Memory(t *testing.T) {
|
||||
var stat2 = &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stat2)
|
||||
|
||||
t.Log("max cache items:", parser.maxCacheItems)
|
||||
t.Log("cache1:", len(parser.cacheMap1), "cache2:", len(parser.cacheMap2), "cache3:", (stat2.HeapInuse-stat1.HeapInuse)/1024/1024, "MB")
|
||||
t.Log("max cache items:", parser.MaxCacheItems())
|
||||
t.Log("cache:", parser.Len(), "usage:", (stat2.HeapInuse-stat1.HeapInuse)>>20, "MB")
|
||||
}
|
||||
|
||||
func BenchmarkUserAgentParser_Parse(b *testing.B) {
|
||||
var parser = NewUserAgentParser()
|
||||
for i := 0; i < b.N; i++ {
|
||||
parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/" + types.String(rands.Int(0, 1_000_000)))
|
||||
func TestNewUserAgentParser_GC(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
b.Log(len(parser.cacheMap1), len(parser.cacheMap2))
|
||||
}
|
||||
|
||||
func BenchmarkUserAgentParser_Parse2(b *testing.B) {
|
||||
var parser = NewUserAgentParser()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var parser = stats.NewUserAgentParser()
|
||||
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/" + types.String(rands.Int(0, 100_000)))
|
||||
}
|
||||
b.Log(len(parser.cacheMap1), len(parser.cacheMap2))
|
||||
|
||||
time.Sleep(60 * time.Second) // wait to gc
|
||||
t.Log(parser.Len(), "cache items")
|
||||
}
|
||||
|
||||
func BenchmarkUserAgentParser_Parse3(b *testing.B) {
|
||||
var parser = NewUserAgentParser()
|
||||
func TestNewUserAgentParser_Mobile(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
var parser = stats.NewUserAgentParser()
|
||||
for _, userAgent := range []string{
|
||||
"Mozilla/5.0 (iPhone; CPU iPhone OS 12_2 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Mobile/15E148",
|
||||
"Mozilla/5.0 (Linux; U; Android 2.2.1; en-us; Nexus One Build/FRG83) AppleWebKit/533.1 (KHTML, like Gecko) Version/4.0 Mobile Safari/533.1",
|
||||
} {
|
||||
a.IsTrue(parser.Parse(userAgent).IsMobile)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUserAgentParser_Parse_Many_LimitCPU(b *testing.B) {
|
||||
runtime.GOMAXPROCS(4)
|
||||
|
||||
var parser = stats.NewUserAgentParser()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/" + types.String(rands.Int(0, 1_000_000)))
|
||||
}
|
||||
})
|
||||
b.Log(parser.Len())
|
||||
}
|
||||
|
||||
func BenchmarkUserAgentParser_Parse_Many(b *testing.B) {
|
||||
var parser = stats.NewUserAgentParser()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/" + types.String(rands.Int(0, 1_000_000)))
|
||||
}
|
||||
})
|
||||
b.Log(parser.Len())
|
||||
}
|
||||
|
||||
func BenchmarkUserAgentParser_Parse_Few_LimitCPU(b *testing.B) {
|
||||
runtime.GOMAXPROCS(4)
|
||||
|
||||
var parser = stats.NewUserAgentParser()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/" + types.String(rands.Int(0, 100_000)))
|
||||
}
|
||||
})
|
||||
b.Log(len(parser.cacheMap1), len(parser.cacheMap2))
|
||||
b.Log(parser.Len())
|
||||
}
|
||||
|
||||
func BenchmarkUserAgentParser_Parse_Few(b *testing.B) {
|
||||
var parser = stats.NewUserAgentParser()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/" + types.String(rands.Int(0, 100_000)))
|
||||
}
|
||||
})
|
||||
b.Log(parser.Len())
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package ttlcache
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
"github.com/iwind/TeaGo/types"
|
||||
timeutil "github.com/iwind/TeaGo/utils/time"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@@ -48,16 +50,34 @@ func TestCache_Memory(t *testing.T) {
|
||||
}
|
||||
|
||||
var cache = NewCache[int]()
|
||||
var isReady bool
|
||||
|
||||
testutils.StartMemoryStats(t, func() {
|
||||
if !isReady {
|
||||
return
|
||||
}
|
||||
t.Log(cache.Count(), "items")
|
||||
})
|
||||
|
||||
var count = 20_000_000
|
||||
var count = 1_000_000
|
||||
if utils.SystemMemoryGB() > 4 {
|
||||
count = 20_000_000
|
||||
}
|
||||
for i := 0; i < count; i++ {
|
||||
cache.Write("a"+strconv.Itoa(i), 1, time.Now().Unix()+int64(rands.Int(0, 300)))
|
||||
}
|
||||
|
||||
func() {
|
||||
var before = time.Now()
|
||||
runtime.GC()
|
||||
var costSeconds = time.Since(before).Seconds()
|
||||
var stats = &debug.GCStats{}
|
||||
debug.ReadGCStats(stats)
|
||||
t.Log("GC pause:", stats.Pause[0].Seconds()*1000, "ms", "cost:", costSeconds*1000, "ms")
|
||||
}()
|
||||
|
||||
isReady = true
|
||||
|
||||
t.Log(cache.Count())
|
||||
|
||||
time.Sleep(10 * time.Second)
|
||||
@@ -105,6 +125,10 @@ func TestCache_IncreaseInt64(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCache_Read(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
var cache = NewCache[int](PiecesOption{Count: 32})
|
||||
|
||||
@@ -4,12 +4,17 @@ package agents_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/agents"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var db = agents.NewDB(Tea.Root + "/data/agents.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
|
||||
@@ -4,6 +4,7 @@ package agents_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/agents"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"testing"
|
||||
@@ -11,6 +12,10 @@ import (
|
||||
)
|
||||
|
||||
func TestParseQueue_Process(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var queue = agents.NewQueue()
|
||||
go queue.Start()
|
||||
time.Sleep(1 * time.Second)
|
||||
@@ -19,6 +24,10 @@ func TestParseQueue_Process(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseQueue_ParseIP(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var queue = agents.NewQueue()
|
||||
for _, ip := range []string{
|
||||
"192.168.1.100",
|
||||
|
||||
@@ -4,9 +4,14 @@ package clock_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/clock"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReadServer(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
t.Log(clock.NewClockManager().ReadServer("pool.ntp.org"))
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ type SupportedUIntType interface {
|
||||
type Counter[T SupportedUIntType] struct {
|
||||
countMaps uint64
|
||||
locker *syncutils.RWMutex
|
||||
itemMaps []map[uint64]*Item[T]
|
||||
itemMaps []map[uint64]Item[T]
|
||||
|
||||
gcTicker *time.Ticker
|
||||
gcIndex int
|
||||
@@ -36,9 +36,9 @@ func NewCounter[T SupportedUIntType]() *Counter[T] {
|
||||
count = 8
|
||||
}
|
||||
|
||||
var itemMaps = []map[uint64]*Item[T]{}
|
||||
var itemMaps = []map[uint64]Item[T]{}
|
||||
for i := 0; i < count; i++ {
|
||||
itemMaps = append(itemMaps, map[uint64]*Item[T]{})
|
||||
itemMaps = append(itemMaps, map[uint64]Item[T]{})
|
||||
}
|
||||
|
||||
var counter = &Counter[T]{
|
||||
@@ -69,19 +69,22 @@ func (this *Counter[T]) WithGC() *Counter[T] {
|
||||
func (this *Counter[T]) Increase(key uint64, lifeSeconds int) T {
|
||||
var index = int(key % this.countMaps)
|
||||
this.locker.RLock(index)
|
||||
var item = this.itemMaps[index][key]
|
||||
var item = this.itemMaps[index][key] // item MUST NOT be pointer
|
||||
this.locker.RUnlock(index)
|
||||
if item == nil {
|
||||
if !item.IsOk() {
|
||||
// no need to care about duplication
|
||||
// always insert new item even when itemMap is full
|
||||
item = NewItem[T](lifeSeconds)
|
||||
var result = item.Increase()
|
||||
this.locker.Lock(index)
|
||||
this.itemMaps[index][key] = item
|
||||
this.locker.Unlock(index)
|
||||
return result
|
||||
}
|
||||
|
||||
this.locker.Lock(index)
|
||||
var result = item.Increase()
|
||||
this.itemMaps[index][key] = item // overwrite
|
||||
this.locker.Unlock(index)
|
||||
return result
|
||||
}
|
||||
@@ -97,7 +100,7 @@ func (this *Counter[T]) Get(key uint64) T {
|
||||
this.locker.RLock(index)
|
||||
defer this.locker.RUnlock(index)
|
||||
var item = this.itemMaps[index][key]
|
||||
if item != nil {
|
||||
if item.IsOk() {
|
||||
return item.Sum()
|
||||
}
|
||||
return 0
|
||||
@@ -115,7 +118,7 @@ func (this *Counter[T]) Reset(key uint64) {
|
||||
var item = this.itemMaps[index][key]
|
||||
this.locker.RUnlock(index)
|
||||
|
||||
if item != nil {
|
||||
if item.IsOk() {
|
||||
this.locker.Lock(index)
|
||||
delete(this.itemMaps[index], key)
|
||||
this.locker.Unlock(index)
|
||||
|
||||
@@ -101,7 +101,7 @@ func TestCounterMemory(t *testing.T) {
|
||||
var costSeconds = time.Since(before).Seconds()
|
||||
var stats = &debug.GCStats{}
|
||||
debug.ReadGCStats(stats)
|
||||
t.Log("GC pause:", stats.PauseTotal.Seconds()*1000, "ms", "cost:", costSeconds*1000, "ms")
|
||||
t.Log("GC pause:", stats.Pause[0].Seconds()*1000, "ms", "cost:", costSeconds*1000, "ms")
|
||||
}
|
||||
|
||||
gcPause()
|
||||
@@ -113,12 +113,14 @@ func BenchmarkCounter_Increase(b *testing.B) {
|
||||
runtime.GOMAXPROCS(4)
|
||||
|
||||
var counter = counters.NewCounter[uint32]()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
var i uint64
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
counter.Increase(atomic.AddUint64(&i, 1)%1e6, 20)
|
||||
counter.Increase(atomic.AddUint64(&i, 1)%1_000_000, 20)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -138,11 +140,12 @@ func BenchmarkCounter_IncreaseKey(b *testing.B) {
|
||||
}()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
var i uint64
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
counter.IncreaseKey(types.String(atomic.AddUint64(&i, 1)%1e6), 20)
|
||||
counter.IncreaseKey(types.String(atomic.AddUint64(&i, 1)%1_000_000), 20)
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ type Item[T SupportedUIntType] struct {
|
||||
spanSeconds int64
|
||||
}
|
||||
|
||||
func NewItem[T SupportedUIntType](lifeSeconds int) *Item[T] {
|
||||
func NewItem[T SupportedUIntType](lifeSeconds int) Item[T] {
|
||||
if lifeSeconds <= 0 {
|
||||
lifeSeconds = 60
|
||||
}
|
||||
@@ -27,7 +27,7 @@ func NewItem[T SupportedUIntType](lifeSeconds int) *Item[T] {
|
||||
spanSeconds++
|
||||
}
|
||||
|
||||
return &Item[T]{
|
||||
return Item[T]{
|
||||
lifeSeconds: int64(lifeSeconds),
|
||||
spanSeconds: int64(spanSeconds),
|
||||
lastUpdateTime: fasttime.Now().Unix(),
|
||||
@@ -126,3 +126,7 @@ func (this *Item[T]) calculateSpanIndex(timestamp int64) int {
|
||||
}
|
||||
return index
|
||||
}
|
||||
|
||||
func (this *Item[T]) IsOk() bool {
|
||||
return this.lifeSeconds > 0
|
||||
}
|
||||
|
||||
@@ -42,12 +42,12 @@ func TestIsIPv4(t *testing.T) {
|
||||
func TestIsIPv6(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
a.IsFalse(utils.IsIPv6("192.168.1.1"))
|
||||
a.IsFloat32(utils.IsIPv6("0.0.0.0"))
|
||||
a.IsFalse(utils.IsIPv6("0.0.0.0"))
|
||||
a.IsFalse(utils.IsIPv6("192.168.1.256"))
|
||||
a.IsFalse(utils.IsIPv6("192.168.1"))
|
||||
a.IsTrue(utils.IsIPv6("::1"))
|
||||
a.IsTrue(utils.IsIPv6("2001:0db8:85a3:0000:0000:8a2e:0370:7334"))
|
||||
a.IsTrue(utils.IsIPv4("::ffff:192.168.0.1"))
|
||||
a.IsFalse(utils.IsIPv4("::ffff:192.168.0.1"))
|
||||
a.IsTrue(utils.IsIPv6("::ffff:192.168.0.1"))
|
||||
}
|
||||
|
||||
|
||||
61
internal/utils/ratelimit/bandwidth.go
Normal file
61
internal/utils/ratelimit/bandwidth.go
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Bandwidth lossy bandwidth limiter
|
||||
type Bandwidth struct {
|
||||
totalBytes int64
|
||||
|
||||
currentTimestamp int64
|
||||
currentBytes int64
|
||||
}
|
||||
|
||||
// NewBandwidth create new bandwidth limiter
|
||||
func NewBandwidth(totalBytes int64) *Bandwidth {
|
||||
return &Bandwidth{totalBytes: totalBytes}
|
||||
}
|
||||
|
||||
// Ack acquire next chance to send data
|
||||
func (this *Bandwidth) Ack(ctx context.Context, newBytes int) {
|
||||
if newBytes <= 0 {
|
||||
return
|
||||
}
|
||||
if this.totalBytes <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var timestamp = fasttime.Now().Unix()
|
||||
if this.currentTimestamp != 0 && this.currentTimestamp != timestamp {
|
||||
this.currentTimestamp = timestamp
|
||||
this.currentBytes = int64(newBytes)
|
||||
|
||||
// 第一次发送直接放行,不需要判断
|
||||
return
|
||||
}
|
||||
|
||||
if this.currentTimestamp == 0 {
|
||||
this.currentTimestamp = timestamp
|
||||
}
|
||||
if atomic.AddInt64(&this.currentBytes, int64(newBytes)) <= this.totalBytes {
|
||||
return
|
||||
}
|
||||
|
||||
var timeout = time.NewTimer(1 * time.Second)
|
||||
if ctx != nil {
|
||||
select {
|
||||
case <-timeout.C:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
} else {
|
||||
select {
|
||||
case <-timeout.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
27
internal/utils/ratelimit/bandwidth_test.go
Normal file
27
internal/utils/ratelimit/bandwidth_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package ratelimit_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/ratelimit"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBandwidth(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var bandwidth = ratelimit.NewBandwidth(32 << 10)
|
||||
bandwidth.Ack(context.Background(), 123)
|
||||
bandwidth.Ack(context.Background(), 16 << 10)
|
||||
bandwidth.Ack(context.Background(), 32 << 10)
|
||||
}
|
||||
|
||||
func TestBandwidth_0(t *testing.T) {
|
||||
var bandwidth = ratelimit.NewBandwidth(0)
|
||||
bandwidth.Ack(context.Background(), 123)
|
||||
bandwidth.Ack(context.Background(), 123456)
|
||||
}
|
||||
@@ -1,14 +1,20 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package ratelimit
|
||||
package ratelimit_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/ratelimit"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCounter_ACK(t *testing.T) {
|
||||
var counter = NewCounter(10)
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var counter = ratelimit.NewCounter(10)
|
||||
|
||||
go func() {
|
||||
for i := 0; i < 10; i++ {
|
||||
@@ -26,7 +32,7 @@ func TestCounter_ACK(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCounter_Release(t *testing.T) {
|
||||
var counter = NewCounter(10)
|
||||
var counter = ratelimit.NewCounter(10)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
counter.Ack()
|
||||
@@ -34,5 +40,5 @@ func TestCounter_Release(t *testing.T) {
|
||||
for i := 0; i < 10; i++ {
|
||||
counter.Release()
|
||||
}
|
||||
t.Log(len(counter.sem))
|
||||
t.Log(counter.Len())
|
||||
}
|
||||
@@ -12,12 +12,17 @@ type TeeReaderCloser struct {
|
||||
|
||||
onFail func(err error)
|
||||
onEOF func()
|
||||
|
||||
mustWrite bool
|
||||
}
|
||||
|
||||
func NewTeeReaderCloser(reader io.Reader, writer io.Writer) *TeeReaderCloser {
|
||||
// NewTeeReaderCloser
|
||||
// mustWrite - ensure writing MUST be successfully
|
||||
func NewTeeReaderCloser(reader io.Reader, writer io.Writer, mustWrite bool) *TeeReaderCloser {
|
||||
return &TeeReaderCloser{
|
||||
r: reader,
|
||||
w: writer,
|
||||
r: reader,
|
||||
w: writer,
|
||||
mustWrite: mustWrite,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +31,13 @@ func (this *TeeReaderCloser) Read(p []byte) (n int, err error) {
|
||||
if n > 0 {
|
||||
_, wErr := this.w.Write(p[:n])
|
||||
if (err == nil || err == io.EOF) && wErr != nil {
|
||||
err = wErr
|
||||
if this.mustWrite {
|
||||
err = wErr
|
||||
} else {
|
||||
if this.onFail != nil {
|
||||
this.onFail(wErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRawTicker(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var ticker = time.NewTicker(2 * time.Second)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
@@ -21,6 +26,10 @@ func TestRawTicker(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTicker(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := NewTicker(3 * time.Second)
|
||||
go func() {
|
||||
time.Sleep(10 * time.Second)
|
||||
@@ -33,6 +42,10 @@ func TestTicker(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTicker2(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := NewTicker(1 * time.Second)
|
||||
go func() {
|
||||
time.Sleep(5 * time.Second)
|
||||
@@ -50,6 +63,10 @@ func TestTicker2(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTickerEvery(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
i := 0
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
@@ -64,8 +81,11 @@ func TestTickerEvery(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
|
||||
func TestTicker_StopTwice(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := NewTicker(3 * time.Second)
|
||||
go func() {
|
||||
time.Sleep(10 * time.Second)
|
||||
|
||||
@@ -5,8 +5,18 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type AllowScope = string
|
||||
|
||||
const (
|
||||
AllowScopeGroup AllowScope = "group"
|
||||
AllowScopeServer AllowScope = "server"
|
||||
AllowScopeGlobal AllowScope = "global"
|
||||
)
|
||||
|
||||
type AllowAction struct {
|
||||
BaseAction
|
||||
|
||||
Scope AllowScope `yaml:"scope" json:"scope"`
|
||||
}
|
||||
|
||||
func (this *AllowAction) Init(waf *WAF) error {
|
||||
@@ -25,7 +35,12 @@ func (this *AllowAction) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *AllowAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
|
||||
func (this *AllowAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
|
||||
// do nothing
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
GoNextGroup: this.Scope == AllowScopeGroup,
|
||||
IsAllowed: true,
|
||||
AllowScope: this.Scope,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func (this *BlockAction) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
|
||||
func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
|
||||
// 加入到黑名单
|
||||
var timeout = this.Timeout
|
||||
if timeout <= 0 {
|
||||
@@ -93,14 +93,14 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
|
||||
req, err := http.NewRequest(http.MethodGet, this.URL, nil)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
return false, false
|
||||
return PerformResult{}
|
||||
}
|
||||
req.Header.Set("User-Agent", teaconst.GlobalProductName+"/"+teaconst.Version)
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
return false, false
|
||||
return PerformResult{}
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
@@ -124,11 +124,11 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
return false, false
|
||||
return PerformResult{}
|
||||
}
|
||||
_, _ = writer.Write(data)
|
||||
}
|
||||
return false, false
|
||||
return PerformResult{}
|
||||
}
|
||||
if len(this.Body) > 0 {
|
||||
_, _ = writer.Write([]byte(this.Body))
|
||||
@@ -137,5 +137,5 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
|
||||
}
|
||||
}
|
||||
|
||||
return false, false
|
||||
return PerformResult{}
|
||||
}
|
||||
|
||||
@@ -123,10 +123,12 @@ func (this *CaptchaAction) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
|
||||
func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) PerformResult {
|
||||
// 是否在白名单中
|
||||
if SharedIPWhiteList.Contains(wafutils.ComposeIPType(set.Id, req), this.Scope, req.WAFServerId(), req.WAFRemoteIP()) {
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
}
|
||||
}
|
||||
|
||||
var refURL = req.WAFRaw().URL.String()
|
||||
@@ -153,7 +155,9 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
|
||||
info, err := utils.SimpleEncryptMap(captchaConfig)
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_CAPTCHA_ACTION", "encode captcha config failed: "+err.Error())
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
}
|
||||
}
|
||||
|
||||
// 占用一次失败次数
|
||||
@@ -163,5 +167,5 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
|
||||
req.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect)
|
||||
http.Redirect(writer, req.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect)
|
||||
|
||||
return false, false
|
||||
return PerformResult{}
|
||||
}
|
||||
|
||||
@@ -41,15 +41,19 @@ func (this *Get302Action) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
|
||||
func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
|
||||
// 仅限于Get
|
||||
if request.WAFRaw().Method != http.MethodGet {
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
}
|
||||
}
|
||||
|
||||
// 是否已经在白名单中
|
||||
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
}
|
||||
}
|
||||
|
||||
var m = maps.Map{
|
||||
@@ -64,7 +68,9 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
|
||||
info, err := utils.SimpleEncryptMap(m)
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_GET_302_ACTION", "encode info failed: "+err.Error())
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
}
|
||||
}
|
||||
|
||||
request.DisableStat()
|
||||
@@ -75,6 +81,6 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
return false, false
|
||||
|
||||
return PerformResult{}
|
||||
}
|
||||
|
||||
@@ -29,20 +29,29 @@ func (this *GoGroupAction) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *GoGroupAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
|
||||
nextGroup := waf.FindRuleGroup(types.Int64(this.GroupId))
|
||||
func (this *GoGroupAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
|
||||
var nextGroup = waf.FindRuleGroup(types.Int64(this.GroupId))
|
||||
if nextGroup == nil || !nextGroup.IsOn {
|
||||
return true, true
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
GoNextSet: true,
|
||||
}
|
||||
}
|
||||
|
||||
b, _, nextSet, err := nextGroup.MatchRequest(request)
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF", "GO_GROUP_ACTION: "+err.Error())
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
GoNextSet: true,
|
||||
}
|
||||
}
|
||||
|
||||
if !b {
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
GoNextSet: true,
|
||||
}
|
||||
}
|
||||
|
||||
return nextSet.PerformActions(waf, nextGroup, request, writer)
|
||||
|
||||
@@ -30,23 +30,35 @@ func (this *GoSetAction) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *GoSetAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
|
||||
nextGroup := waf.FindRuleGroup(types.Int64(this.GroupId))
|
||||
func (this *GoSetAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
|
||||
var nextGroup = waf.FindRuleGroup(types.Int64(this.GroupId))
|
||||
if nextGroup == nil || !nextGroup.IsOn {
|
||||
return true, true
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
GoNextSet: true,
|
||||
}
|
||||
}
|
||||
nextSet := nextGroup.FindRuleSet(types.Int64(this.SetId))
|
||||
var nextSet = nextGroup.FindRuleSet(types.Int64(this.SetId))
|
||||
if nextSet == nil || !nextSet.IsOn {
|
||||
return true, true
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
GoNextSet: true,
|
||||
}
|
||||
}
|
||||
|
||||
b, _, err := nextSet.MatchRequest(request)
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF", "GO_GROUP_ACTION: "+err.Error())
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
GoNextSet: true,
|
||||
}
|
||||
}
|
||||
if !b {
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
GoNextSet: true,
|
||||
}
|
||||
}
|
||||
return nextSet.PerformActions(waf, nextGroup, request, writer)
|
||||
}
|
||||
|
||||
@@ -27,5 +27,5 @@ type ActionInterface interface {
|
||||
WillChange() bool
|
||||
|
||||
// Perform the action
|
||||
Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool)
|
||||
Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult
|
||||
}
|
||||
|
||||
@@ -42,15 +42,19 @@ func (this *JSCookieAction) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *JSCookieAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
|
||||
func (this *JSCookieAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) PerformResult {
|
||||
// 是否在白名单中
|
||||
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, req.WAFServerId(), req.WAFRemoteIP()) {
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
}
|
||||
}
|
||||
|
||||
nodeConfig, err := nodeconfigs.SharedNodeConfig()
|
||||
if err != nil {
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
}
|
||||
}
|
||||
|
||||
var life = this.Life
|
||||
@@ -69,7 +73,9 @@ func (this *JSCookieAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
|
||||
var timestamp = pieces[0]
|
||||
var sum = pieces[2]
|
||||
if types.Int64(timestamp) >= time.Now().Unix()-int64(life) && fmt.Sprintf("%x", md5.Sum([]byte(timestamp+"@"+types.String(set.Id)+"@"+nodeConfig.NodeId))) == sum {
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -103,7 +109,7 @@ window.location.reload();
|
||||
// 记录失败次数
|
||||
this.increaseFails(req, waf.Id, group.Id, set.Id)
|
||||
|
||||
return false, false
|
||||
return PerformResult{}
|
||||
}
|
||||
|
||||
func (this *JSCookieAction) increaseFails(req requests.Request, policyId int64, groupId int64, setId int64) (goNext bool) {
|
||||
|
||||
@@ -25,6 +25,8 @@ func (this *LogAction) WillChange() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *LogAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
|
||||
return true, false
|
||||
func (this *LogAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ func (this *NotifyAction) WillChange() bool {
|
||||
}
|
||||
|
||||
// Perform the action
|
||||
func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
|
||||
func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
|
||||
select {
|
||||
case notifyChan <- ¬ifyTask{
|
||||
ServerId: request.WAFServerId(),
|
||||
@@ -89,5 +89,7 @@ func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
|
||||
|
||||
}
|
||||
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,11 +10,21 @@ import (
|
||||
type PageAction struct {
|
||||
BaseAction
|
||||
|
||||
Status int `yaml:"status" json:"status"`
|
||||
Body string `yaml:"body" json:"body"`
|
||||
UseDefault bool `yaml:"useDefault" json:"useDefault"`
|
||||
Status int `yaml:"status" json:"status"`
|
||||
Body string `yaml:"body" json:"body"`
|
||||
}
|
||||
|
||||
func (this *PageAction) Init(waf *WAF) error {
|
||||
if waf.DefaultPageAction != nil {
|
||||
if this.Status <= 0 || this.UseDefault {
|
||||
this.Status = waf.DefaultPageAction.Status
|
||||
}
|
||||
if len(this.Body) == 0 || this.UseDefault {
|
||||
this.Body = waf.DefaultPageAction.Body
|
||||
}
|
||||
}
|
||||
|
||||
if this.Status <= 0 {
|
||||
this.Status = http.StatusForbidden
|
||||
}
|
||||
@@ -35,9 +45,9 @@ func (this *PageAction) WillChange() bool {
|
||||
}
|
||||
|
||||
// Perform the action
|
||||
func (this *PageAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
|
||||
func (this *PageAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
|
||||
if writer == nil {
|
||||
return
|
||||
return PerformResult{}
|
||||
}
|
||||
|
||||
request.ProcessResponseHeaders(writer.Header(), this.Status)
|
||||
@@ -48,10 +58,12 @@ func (this *PageAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reques
|
||||
if len(body) == 0 {
|
||||
body = `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<title>403 Forbidden</title>
|
||||
<head>
|
||||
<title>403 Forbidden</title>
|
||||
<style>
|
||||
address { line-height: 1.8; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>403 Forbidden By WAF</h1>
|
||||
<address>Connection: ${remoteAddr} (Client) -> ${serverAddr} (Server)</address>
|
||||
@@ -61,5 +73,5 @@ func (this *PageAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reques
|
||||
}
|
||||
_, _ = writer.Write([]byte(request.Format(body)))
|
||||
|
||||
return false, false
|
||||
return PerformResult{}
|
||||
}
|
||||
|
||||
@@ -34,17 +34,21 @@ func (this *Post307Action) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
|
||||
func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
|
||||
var cookieName = "WAF_VALIDATOR_ID"
|
||||
|
||||
// 仅限于POST
|
||||
if request.WAFRaw().Method != http.MethodPost {
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
}
|
||||
}
|
||||
|
||||
// 是否已经在白名单中
|
||||
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
}
|
||||
}
|
||||
|
||||
// 判断是否有Cookie
|
||||
@@ -58,7 +62,9 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
|
||||
}
|
||||
var setId = types.String(m.GetInt64("setId"))
|
||||
SharedIPWhiteList.RecordIP("set:"+setId, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life, m.GetInt64("policyId"), false, m.GetInt64("groupId"), m.GetInt64("setId"), "")
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +80,9 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
|
||||
info, err := utils.SimpleEncryptMap(m)
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_POST_307_ACTION", "encode info failed: "+err.Error())
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
}
|
||||
}
|
||||
|
||||
// 清空请求内容
|
||||
@@ -101,5 +109,5 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
return false, false
|
||||
return PerformResult{}
|
||||
}
|
||||
|
||||
@@ -132,7 +132,7 @@ func (this *RecordIPAction) WillChange() bool {
|
||||
return this.Type == "black"
|
||||
}
|
||||
|
||||
func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
|
||||
func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
|
||||
var ipListId = this.IPListId
|
||||
if ipListId <= 0 {
|
||||
ipListId = firewallconfigs.GlobalListId
|
||||
@@ -143,7 +143,11 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
|
||||
|
||||
// 是否在本地白名单中
|
||||
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
IsAllowed: true,
|
||||
AllowScope: AllowScopeGlobal,
|
||||
}
|
||||
}
|
||||
|
||||
var timeout = this.Timeout
|
||||
@@ -200,5 +204,10 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
|
||||
}
|
||||
}
|
||||
|
||||
return this.Type != "black", false
|
||||
var isWhite = this.Type != "black"
|
||||
return PerformResult{
|
||||
ContinueRequest: isWhite,
|
||||
IsAllowed: isWhite,
|
||||
AllowScope: AllowScopeGlobal,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,10 +35,10 @@ func (this *RedirectAction) WillChange() bool {
|
||||
}
|
||||
|
||||
// Perform the action
|
||||
func (this *RedirectAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
|
||||
func (this *RedirectAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
|
||||
request.ProcessResponseHeaders(writer.Header(), this.Status)
|
||||
writer.Header().Set("Location", this.URL)
|
||||
writer.WriteHeader(this.Status)
|
||||
|
||||
return false, false
|
||||
return PerformResult{}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,8 @@ func (this *TagAction) WillChange() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *TagAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
|
||||
return true, true
|
||||
func (this *TagAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,6 +182,7 @@ func (this *CaptchaValidator) showVerifyCodesForm(actionConfig *CaptchaAction, r
|
||||
var msgPrompt string
|
||||
var msgButtonTitle string
|
||||
var msgRequestId string
|
||||
var msgPlaceholder string
|
||||
|
||||
switch lang {
|
||||
case "en-US":
|
||||
@@ -189,16 +190,19 @@ func (this *CaptchaValidator) showVerifyCodesForm(actionConfig *CaptchaAction, r
|
||||
msgPrompt = "Input verify code above:"
|
||||
msgButtonTitle = "Verify Yourself"
|
||||
msgRequestId = "Request ID"
|
||||
msgPlaceholder = ""
|
||||
case "zh-CN":
|
||||
msgTitle = "身份验证"
|
||||
msgPrompt = "请输入上面的验证码"
|
||||
msgButtonTitle = "提交验证"
|
||||
msgRequestId = "请求ID"
|
||||
msgPlaceholder = "点此输入"
|
||||
case "zh-TW":
|
||||
msgTitle = "身份驗證"
|
||||
msgPrompt = "請輸入上面的驗證碼"
|
||||
msgButtonTitle = "提交驗證"
|
||||
msgRequestId = "請求ID"
|
||||
msgPlaceholder = "點此輸入"
|
||||
default:
|
||||
msgTitle = "Verify Yourself"
|
||||
msgPrompt = "Input verify code above:"
|
||||
@@ -232,7 +236,7 @@ func (this *CaptchaValidator) showVerifyCodesForm(actionConfig *CaptchaAction, r
|
||||
}
|
||||
}
|
||||
|
||||
var body = `<form method="POST">
|
||||
var body = `<form method="POST" id="captcha-form">
|
||||
<input type="hidden" name="` + captchaIdName + `" value="` + captchaId + `"/>
|
||||
<div class="ui-image">
|
||||
<p id="ui-captcha-image-prompt">loading ...</p>
|
||||
@@ -240,7 +244,7 @@ func (this *CaptchaValidator) showVerifyCodesForm(actionConfig *CaptchaAction, r
|
||||
</div>
|
||||
<div class="ui-input">
|
||||
<p class="ui-prompt">` + msgPrompt + `</p>
|
||||
<input type="text" name="GOEDGE_WAF_CAPTCHA_CODE" id="GOEDGE_WAF_CAPTCHA_CODE" size="` + types.String(countLetters*17/10) + `" maxlength="` + types.String(countLetters) + `" autocomplete="off" z-index="1" class="input"/>
|
||||
<input type="text" name="GOEDGE_WAF_CAPTCHA_CODE" id="GOEDGE_WAF_CAPTCHA_CODE" size="` + types.String(countLetters*17/10) + `" maxlength="` + types.String(countLetters) + `" autocomplete="off" z-index="1" class="input" placeholder="` + msgPlaceholder + `"/>
|
||||
</div>
|
||||
<div class="ui-button">
|
||||
<button type="submit" style="line-height:24px;margin-top:10px">` + msgButtonTitle + `</button>
|
||||
@@ -262,28 +266,22 @@ func (this *CaptchaValidator) showVerifyCodesForm(actionConfig *CaptchaAction, r
|
||||
}
|
||||
|
||||
var msgHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<html lang="` + lang + `">
|
||||
<head>
|
||||
<title>` + msgTitle + `</title>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=0">
|
||||
<meta charset="UTF-8"/>
|
||||
<script type="text/javascript">
|
||||
if (window.addEventListener != null) {
|
||||
document.addEventListener("DOMContentLoaded", function () {
|
||||
document.getElementById("ui-captcha-image").addEventListener("load", function () {
|
||||
var promptBox = document.getElementById("ui-captcha-image-prompt");
|
||||
promptBox.parentNode.removeChild(promptBox);
|
||||
});
|
||||
})
|
||||
window.addEventListener("load", function () {
|
||||
document.getElementById("GOEDGE_WAF_CAPTCHA_CODE").focus();
|
||||
});
|
||||
}
|
||||
var isValidated=!1;window.addEventListener("pageshow",function(){isValidated&&window.location.reload()}),null!=window.addEventListener&&(document.addEventListener("DOMContentLoaded",function(){document.getElementById("ui-captcha-image").addEventListener("load",function(){var e=document.getElementById("ui-captcha-image-prompt");e.parentNode.removeChild(e)})}),window.addEventListener("load",function(){document.getElementById("GOEDGE_WAF_CAPTCHA_CODE").focus();var e=document.getElementById("captcha-form");null!=e&&e.addEventListener("submit",function(){isValidated=!0})}));
|
||||
</script>
|
||||
<style type="text/css">
|
||||
form { max-width: 20em; margin: 0 auto; text-align: center; }
|
||||
.input { font-size:16px;line-height:24px; letter-spacing:0.2em; min-width: 5em; text-align: center; }
|
||||
* { font-size: 13px; }
|
||||
form { max-width: 20em; margin: 0 auto; text-align: center; font-family: Roboto,"Helvetica Neue Light","Helvetica Neue",Helvetica,Arial,"Lucida Grande",sans-serif; }
|
||||
.ui-prompt { font-size: 1.2rem; }
|
||||
.input { font-size:16px;line-height:24px; letter-spacing:0.2em; min-width: 5em; text-align: center; background: #fff; border: 1px solid rgba(0, 0, 0, 0.38); color: rgba(0, 0, 0, 0.87); outline: none; border-radius: 4px; padding: 0.75rem 0.75rem; }
|
||||
.input:focus { border: 1px #3f51b5 solid; outline: none; }
|
||||
address { margin-top: 1em; padding-top: 0.5em; border-top: 1px #ccc solid; text-align: center; }
|
||||
button { background: #3f51b5; color: #fff; cursor: pointer; padding: 0.571rem 0.75rem; min-width: 8rem; font-size: 1rem; border: 0 none; border-radius: 4px; }
|
||||
` + msgCss + `
|
||||
</style>
|
||||
</head>
|
||||
@@ -438,10 +436,10 @@ func (this *CaptchaValidator) showOneClickForm(actionConfig *CaptchaAction, req
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=0">
|
||||
<meta charset="UTF-8"/>
|
||||
<script type="text/javascript">
|
||||
window.addEventListener("load",function(){var t=document.getElementById("checkbox"),n=!1;t.addEventListener("click",function(){var e;t.className="ui-checkbox checked",n||(n=!0,(e=document.createElement("input")).setAttribute("name","nonce"),e.setAttribute("type","hidden"),e.setAttribute("value","` + types.String(nonce) + `"),document.getElementById("ui-form").appendChild(e),document.getElementById("ui-form").submit())})});
|
||||
var isValidated=!1;window.addEventListener("pageshow",function(){isValidated&&window.location.reload()}),window.addEventListener("load",function(){var t=document.getElementById("checkbox"),n=!1;t.addEventListener("click",function(){var e;t.className="ui-checkbox checked",n||(isValidated=n=!0,(e=document.createElement("input")).setAttribute("name","nonce"),e.setAttribute("type","hidden"),e.setAttribute("value","` + types.String(nonce) + `"),document.getElementById("ui-form").appendChild(e),document.getElementById("ui-form").submit())})});
|
||||
</script>
|
||||
<style type="text/css">
|
||||
form { max-width: 20em; margin: 0 auto; text-align: center; }
|
||||
form { max-width: 20em; margin: 0 auto; text-align: center; font-family: Roboto,"Helvetica Neue Light","Helvetica Neue",Helvetica,Arial,"Lucida Grande",sans-serif; }
|
||||
.ui-input { position: relative; padding-top: 1em; height: 2.2em; background: #eee; }
|
||||
.ui-checkbox { width: 16px; height: 16px; border: 1px #999 solid; float: left; margin-left: 1em; cursor: pointer; }
|
||||
.ui-checkbox.checked { background: #276AC6; }
|
||||
@@ -595,10 +593,10 @@ func (this *CaptchaValidator) showSlideForm(actionConfig *CaptchaAction, req req
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=0">
|
||||
<meta charset="UTF-8"/>
|
||||
<script type="text/javascript">
|
||||
window.addEventListener("load",function(){var n=document.getElementById("input"),s=document.getElementById("handler"),o=document.getElementById("progress-bar"),d=!1,u=0,t=n.offsetLeft,c=s.offsetLeft,i=n.offsetWidth-s.offsetWidth-s.offsetLeft,f=!1;function e(e){e.preventDefault(),d=!0,u=null!=e.touches&&0<e.touches.length?e.touches[0].clientX-t:e.offsetX}function l(e){var t;d&&(t=e.x,null!=e.touches&&0<e.touches.length&&(t=e.touches[0].clientX),(t=t-n.offsetLeft-u)<c?t=c:i<t&&(t=i),s.style.cssText="margin-left: "+t+"px",0<t&&(o.style.cssText="width: "+(t+s.offsetWidth+4)+"px"))}function r(e){var t;d=d&&!1,s.offsetLeft<i-4?(s.style.cssText="margin-left: "+c+"px",n.style.cssText="background: #eee",o.style.cssText="width: 0px"):(s.style.cssText="margin-left: "+i+"px",n.style.cssText="background: #a5dc86",f||(f=!0,(t=document.createElement("input")).setAttribute("name","nonce"),t.setAttribute("type","hidden"),t.setAttribute("value","` + types.String(nonce) + `"),document.getElementById("ui-form").appendChild(t),document.getElementById("ui-form").submit()))}void 0!==document.ontouchstart?(s.addEventListener("touchstart",e),document.addEventListener("touchmove",l),document.addEventListener("touchend",r)):(s.addEventListener("mousedown",e),window.addEventListener("mousemove",l),window.addEventListener("mouseup",r))});
|
||||
var isValidated=!1;window.addEventListener("pageshow",function(){isValidated&&window.location.reload()}),window.addEventListener("load",function(){var n=document.getElementById("input"),s=document.getElementById("handler"),d=document.getElementById("progress-bar"),o=!1,i=0,t=n.offsetLeft,u=s.offsetLeft,c=n.offsetWidth-s.offsetWidth-s.offsetLeft,a=!1;function e(e){e.preventDefault(),o=!0,i=null!=e.touches&&0<e.touches.length?e.touches[0].clientX-t:e.offsetX}function f(e){var t;o&&(t=e.x,null!=e.touches&&0<e.touches.length&&(t=e.touches[0].clientX),(t=t-n.offsetLeft-i)<u?t=u:c<t&&(t=c),s.style.cssText="margin-left: "+t+"px",0<t&&(d.style.cssText="width: "+(t+s.offsetWidth+4)+"px"))}function l(e){var t;o=o&&!1,s.offsetLeft<c-4?(s.style.cssText="margin-left: "+u+"px",n.style.cssText="background: #eee",d.style.cssText="width: 0px"):(s.style.cssText="margin-left: "+c+"px",n.style.cssText="background: #a5dc86",a||(isValidated=a=!0,(t=document.createElement("input")).setAttribute("name","nonce"),t.setAttribute("type","hidden"),t.setAttribute("value","` + types.String(nonce) + `"),document.getElementById("ui-form").appendChild(t),document.getElementById("ui-form").submit()))}void 0!==document.ontouchstart?(s.addEventListener("touchstart",e),document.addEventListener("touchmove",f),document.addEventListener("touchend",l)):(s.addEventListener("mousedown",e),window.addEventListener("mousemove",f),window.addEventListener("mouseup",l))});
|
||||
</script>
|
||||
<style type="text/css">
|
||||
form { max-width: 20em; margin: 5em auto; text-align: center; }
|
||||
form { max-width: 20em; margin: 5em auto; text-align: center; font-family: Roboto,"Helvetica Neue Light","Helvetica Neue",Helvetica,Arial,"Lucida Grande",sans-serif; }
|
||||
.ui-input {
|
||||
height: 4em;
|
||||
background: #eee;
|
||||
@@ -608,7 +606,7 @@ window.addEventListener("load",function(){var n=document.getElementById("input")
|
||||
}
|
||||
|
||||
.ui-input .ui-progress-bar {
|
||||
background: #a5dc86;
|
||||
background: #689F38;
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
@@ -619,7 +617,7 @@ window.addEventListener("load",function(){var n=document.getElementById("input")
|
||||
width: 3.6em;
|
||||
height: 3.6em;
|
||||
margin: 0.2em;
|
||||
background: #276AC6;
|
||||
background: #3f51b5;
|
||||
border-radius: 0.6em;
|
||||
display: inline-block;
|
||||
cursor: pointer;
|
||||
|
||||
@@ -58,7 +58,7 @@ int libinjection_sqli(const char* s, size_t slen, char fingerprint[]);
|
||||
* \return 1 if XSS found, 0 if benign
|
||||
*
|
||||
*/
|
||||
int libinjection_xss(const char* s, size_t slen);
|
||||
int libinjection_xss(const char* s, size_t slen, int strictMode);
|
||||
|
||||
LIBINJECTION_END_DECLS
|
||||
|
||||
|
||||
@@ -15,8 +15,8 @@ typedef enum attribute {
|
||||
} attribute_t;
|
||||
|
||||
|
||||
static attribute_t is_black_attr(const char* s, size_t len);
|
||||
static int is_black_tag(const char* s, size_t len);
|
||||
static attribute_t is_black_attr(const char* s, size_t len, int strictMode);
|
||||
static int is_black_tag(const char* s, size_t len, int strictMode);
|
||||
static int is_black_url(const char* s, size_t len);
|
||||
static int cstrcasecmp_with_null(const char *a, const char *b, size_t n);
|
||||
static int html_decode_char_at(const char* src, size_t len, size_t* consumed);
|
||||
@@ -451,7 +451,7 @@ static stringtype_t BLACKATTREVENT[] = {
|
||||
* data:
|
||||
* javascript:
|
||||
*/
|
||||
static stringtype_t BLACKATTR[] = {
|
||||
static stringtype_t STRICT_BLACKATTR[] = {
|
||||
{ "ACTION", TYPE_ATTR_URL } /* form */
|
||||
, { "ATTRIBUTENAME", TYPE_ATTR_INDIRECT } /* SVG allow indirection of attribute names */
|
||||
, { "BY", TYPE_ATTR_URL } /* SVG */
|
||||
@@ -475,6 +475,31 @@ static stringtype_t BLACKATTR[] = {
|
||||
, { NULL, TYPE_NONE }
|
||||
};
|
||||
|
||||
static stringtype_t BLACKATTR[] = {
|
||||
{ "ACTION", TYPE_ATTR_URL } /* form */
|
||||
, { "ATTRIBUTENAME", TYPE_ATTR_INDIRECT } /* SVG allow indirection of attribute names */
|
||||
, { "BY", TYPE_ATTR_URL } /* SVG */
|
||||
, { "BACKGROUND", TYPE_ATTR_URL } /* IE6, O11 */
|
||||
, { "DATAFORMATAS", TYPE_BLACK } /* IE */
|
||||
, { "DATASRC", TYPE_BLACK } /* IE */
|
||||
, { "DYNSRC", TYPE_ATTR_URL } /* Obsolete img attribute */
|
||||
, { "FILTER", TYPE_STYLE } /* Opera, SVG inline style */
|
||||
, { "FORMACTION", TYPE_ATTR_URL } /* HTML 5 */
|
||||
, { "FOLDER", TYPE_ATTR_URL } /* Only on A tags, IE-only */
|
||||
, { "FROM", TYPE_ATTR_URL } /* SVG */
|
||||
, { "HANDLER", TYPE_ATTR_URL } /* SVG Tiny, Opera */
|
||||
, { "HREF", TYPE_ATTR_URL }
|
||||
, { "LOWSRC", TYPE_ATTR_URL } /* Obsolete img attribute */
|
||||
, { "POSTER", TYPE_ATTR_URL } /* Opera 10,11 */
|
||||
, { "SRC", TYPE_ATTR_URL }
|
||||
, { "TO", TYPE_ATTR_URL } /* SVG */
|
||||
, { "VALUES", TYPE_ATTR_URL } /* SVG */
|
||||
, { "XLINK:HREF", TYPE_ATTR_URL }
|
||||
, { NULL, TYPE_NONE }
|
||||
};
|
||||
|
||||
|
||||
|
||||
/* xmlns */
|
||||
/* `xml-stylesheet` > <eval>, <if expr=> */
|
||||
|
||||
@@ -492,6 +517,35 @@ static stringtype_t BLACKATTR[] = {
|
||||
};
|
||||
*/
|
||||
|
||||
// GoEdge: change BLACKTAG to STRICT_BLACKTAG
|
||||
static const char* STRICT_BLACKTAG[] = {
|
||||
"APPLET"
|
||||
, "AUDIO"
|
||||
, "BASE"
|
||||
, "COMMENT" /* IE http://html5sec.org/#38 */
|
||||
, "EMBED"
|
||||
, "FORM"
|
||||
, "FRAME"
|
||||
, "FRAMESET"
|
||||
, "HANDLER" /* Opera SVG, effectively a script tag */
|
||||
, "IFRAME"
|
||||
, "IMPORT"
|
||||
, "ISINDEX"
|
||||
, "LINK"
|
||||
, "LISTENER"
|
||||
/* , "MARQUEE" */
|
||||
, "META"
|
||||
, "NOSCRIPT"
|
||||
, "OBJECT"
|
||||
, "SCRIPT"
|
||||
, "STYLE"
|
||||
, "VIDEO"
|
||||
, "VMLFRAME"
|
||||
, "XML"
|
||||
, "XSS"
|
||||
, NULL
|
||||
};
|
||||
|
||||
static const char* BLACKTAG[] = {
|
||||
"APPLET"
|
||||
/* , "AUDIO" */
|
||||
@@ -515,7 +569,6 @@ static const char* BLACKTAG[] = {
|
||||
, "STYLE"
|
||||
/* , "VIDEO" */
|
||||
, "VMLFRAME"
|
||||
, "XML"
|
||||
, "XSS"
|
||||
, NULL
|
||||
};
|
||||
@@ -606,7 +659,7 @@ static int htmlencode_startswith(const char *a, const char *b, size_t n)
|
||||
return (*a == 0) ? 1 : 0;
|
||||
}
|
||||
|
||||
static int is_black_tag(const char* s, size_t len)
|
||||
static int is_black_tag(const char* s, size_t len, int strictMode)
|
||||
{
|
||||
const char** black;
|
||||
|
||||
@@ -614,7 +667,11 @@ static int is_black_tag(const char* s, size_t len)
|
||||
return 0;
|
||||
}
|
||||
|
||||
black = BLACKTAG;
|
||||
if (strictMode == 1) {
|
||||
black = STRICT_BLACKTAG;
|
||||
} else {
|
||||
black = BLACKTAG;
|
||||
}
|
||||
while (*black != NULL) {
|
||||
if (cstrcasecmp_with_null(*black, s, len) == 0) {
|
||||
/* printf("Got black tag %s\n", *black); */
|
||||
@@ -642,7 +699,7 @@ static int is_black_tag(const char* s, size_t len)
|
||||
return 0;
|
||||
}
|
||||
|
||||
static attribute_t is_black_attr(const char* s, size_t len)
|
||||
static attribute_t is_black_attr(const char* s, size_t len, int strictMode)
|
||||
{
|
||||
stringtype_t* black;
|
||||
|
||||
@@ -674,7 +731,11 @@ static attribute_t is_black_attr(const char* s, size_t len)
|
||||
//}
|
||||
}
|
||||
|
||||
black = BLACKATTR;
|
||||
if (strictMode == 1) {
|
||||
black = STRICT_BLACKATTR;
|
||||
} else {
|
||||
black = BLACKATTR;
|
||||
}
|
||||
while (black->name != NULL) {
|
||||
if (cstrcasecmp_with_null(black->name, s, len) == 0) {
|
||||
/* printf("Got banned attribute name %s\n", black->name); */
|
||||
@@ -729,7 +790,7 @@ static int is_black_url(const char* s, size_t len)
|
||||
return 0;
|
||||
}
|
||||
|
||||
int libinjection_is_xss(const char* s, size_t len, int flags)
|
||||
int libinjection_is_xss(const char* s, size_t len, int flags, int strictMode)
|
||||
{
|
||||
h5_state_t h5;
|
||||
attribute_t attr = TYPE_NONE;
|
||||
@@ -743,11 +804,11 @@ int libinjection_is_xss(const char* s, size_t len, int flags)
|
||||
if (h5.token_type == DOCTYPE) {
|
||||
return 1;
|
||||
} else if (h5.token_type == TAG_NAME_OPEN) {
|
||||
if (is_black_tag(h5.token_start, h5.token_len)) {
|
||||
if (is_black_tag(h5.token_start, h5.token_len, strictMode)) {
|
||||
return 1;
|
||||
}
|
||||
} else if (h5.token_type == ATTR_NAME) {
|
||||
attr = is_black_attr(h5.token_start, h5.token_len);
|
||||
attr = is_black_attr(h5.token_start, h5.token_len, strictMode);
|
||||
} else if (h5.token_type == ATTR_VALUE) {
|
||||
/*
|
||||
* IE6,7,8 parsing works a bit differently so
|
||||
@@ -778,7 +839,7 @@ int libinjection_is_xss(const char* s, size_t len, int flags)
|
||||
return 1;
|
||||
case TYPE_ATTR_INDIRECT:
|
||||
/* an attribute name is specified in a _value_ */
|
||||
if (is_black_attr(h5.token_start, h5.token_len)) {
|
||||
if (is_black_attr(h5.token_start, h5.token_len, strictMode)) {
|
||||
return 1;
|
||||
}
|
||||
break;
|
||||
@@ -835,21 +896,21 @@ int libinjection_is_xss(const char* s, size_t len, int flags)
|
||||
*
|
||||
*
|
||||
*/
|
||||
int libinjection_xss(const char* s, size_t slen)
|
||||
int libinjection_xss(const char* s, size_t slen, int strictMode)
|
||||
{
|
||||
if (libinjection_is_xss(s, slen, DATA_STATE)) {
|
||||
if (libinjection_is_xss(s, slen, DATA_STATE, strictMode)) {
|
||||
return 1;
|
||||
}
|
||||
if (libinjection_is_xss(s, slen, VALUE_NO_QUOTE)) {
|
||||
if (libinjection_is_xss(s, slen, VALUE_NO_QUOTE, strictMode)) {
|
||||
return 1;
|
||||
}
|
||||
if (libinjection_is_xss(s, slen, VALUE_SINGLE_QUOTE)) {
|
||||
if (libinjection_is_xss(s, slen, VALUE_SINGLE_QUOTE, strictMode)) {
|
||||
return 1;
|
||||
}
|
||||
if (libinjection_is_xss(s, slen, VALUE_DOUBLE_QUOTE)) {
|
||||
if (libinjection_is_xss(s, slen, VALUE_DOUBLE_QUOTE, strictMode)) {
|
||||
return 1;
|
||||
}
|
||||
if (libinjection_is_xss(s, slen, VALUE_BACK_QUOTE)) {
|
||||
if (libinjection_is_xss(s, slen, VALUE_BACK_QUOTE, strictMode)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ extern "C" {
|
||||
|
||||
#include <string.h>
|
||||
|
||||
int libinjection_is_xss(const char* s, size_t len, int flags);
|
||||
int libinjection_is_xss(const char* s, size_t len, int flags, int strictMode);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@@ -3,4 +3,4 @@
|
||||
#include "libinjection/src/libinjection_xss.c"
|
||||
#include "libinjection/src/libinjection_html5.c"
|
||||
|
||||
#define GOEDGE_VERSION "23" // last version is for GoEdge change
|
||||
#define GOEDGE_VERSION "25" // last version is for GoEdge change
|
||||
@@ -3,7 +3,7 @@
|
||||
package injectionutils
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -I./libinjection/src
|
||||
#cgo CFLAGS: -O2 -I./libinjection/src
|
||||
|
||||
#include <libinjection.h>
|
||||
#include <stdlib.h>
|
||||
@@ -16,11 +16,12 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// DetectSQLInjectionCache detect sql injection in string with cache
|
||||
func DetectSQLInjectionCache(input string, cacheLife utils.CacheLife) bool {
|
||||
func DetectSQLInjectionCache(input string, isStrict bool, cacheLife utils.CacheLife) bool {
|
||||
var l = len(input)
|
||||
|
||||
if l == 0 {
|
||||
@@ -28,7 +29,7 @@ func DetectSQLInjectionCache(input string, cacheLife utils.CacheLife) bool {
|
||||
}
|
||||
|
||||
if cacheLife <= 0 || l < 128 || l > utils.MaxCacheDataSize {
|
||||
return DetectSQLInjection(input)
|
||||
return DetectSQLInjection(input, isStrict)
|
||||
}
|
||||
|
||||
var hash = xxhash.Sum64String(input)
|
||||
@@ -38,7 +39,7 @@ func DetectSQLInjectionCache(input string, cacheLife utils.CacheLife) bool {
|
||||
return item.Value == 1
|
||||
}
|
||||
|
||||
var result = DetectSQLInjection(input)
|
||||
var result = DetectSQLInjection(input, isStrict)
|
||||
if result {
|
||||
utils.SharedCache.Write(key, 1, fasttime.Now().Unix()+cacheLife)
|
||||
} else {
|
||||
@@ -48,11 +49,23 @@ func DetectSQLInjectionCache(input string, cacheLife utils.CacheLife) bool {
|
||||
}
|
||||
|
||||
// DetectSQLInjection detect sql injection in string
|
||||
func DetectSQLInjection(input string) bool {
|
||||
func DetectSQLInjection(input string, isStrict bool) bool {
|
||||
if len(input) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if !isStrict {
|
||||
if len(input) > 1024 {
|
||||
if !utf8.ValidString(input[:1024]) && !utf8.ValidString(input[:1023]) && !utf8.ValidString(input[:1022]) {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if !utf8.ValidString(input) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if detectSQLInjectionOne(input) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -15,21 +15,23 @@ import (
|
||||
|
||||
func TestDetectSQLInjection(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("' UNION SELECT * FROM myTable"))
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("id=1 ' UNION select * from a"))
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("asdf asd ; -1' and 1=1 union/* foo */select load_file('/etc/passwd')--"))
|
||||
a.IsFalse(injectionutils.DetectSQLInjection("' UNION SELECT1 * FROM myTable"))
|
||||
a.IsFalse(injectionutils.DetectSQLInjection("1234"))
|
||||
a.IsFalse(injectionutils.DetectSQLInjection(""))
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("id=123 OR 1=1&b=2"))
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("id=123&b=456&c=1' or 2=2"))
|
||||
a.IsFalse(injectionutils.DetectSQLInjection("?"))
|
||||
a.IsFalse(injectionutils.DetectSQLInjection("/hello?age=22"))
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("/sql/injection?id=123 or 1=1"))
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("/sql/injection?id=123%20or%201=1"))
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("https://example.com/sql/injection?id=123%20or%201=1"))
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("id=123%20or%201=1"))
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("https://example.com/' or 1=1"))
|
||||
for _, isStrict := range []bool{true, false} {
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("' UNION SELECT * FROM myTable", isStrict))
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("id=1 ' UNION select * from a", isStrict))
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("asdf asd ; -1' and 1=1 union/* foo */select load_file('/etc/passwd')--", isStrict))
|
||||
a.IsFalse(injectionutils.DetectSQLInjection("' UNION SELECT1 * FROM myTable", isStrict))
|
||||
a.IsFalse(injectionutils.DetectSQLInjection("1234", isStrict))
|
||||
a.IsFalse(injectionutils.DetectSQLInjection("", isStrict))
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("id=123 OR 1=1&b=2", isStrict))
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("id=123&b=456&c=1' or 2=2", isStrict))
|
||||
a.IsFalse(injectionutils.DetectSQLInjection("?", isStrict))
|
||||
a.IsFalse(injectionutils.DetectSQLInjection("/hello?age=22", isStrict))
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("/sql/injection?id=123 or 1=1", isStrict))
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("/sql/injection?id=123%20or%201=1", isStrict))
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("https://example.com/sql/injection?id=123%20or%201=1", isStrict))
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("id=123%20or%201=1", isStrict))
|
||||
a.IsTrue(injectionutils.DetectSQLInjection("https://example.com/' or 1=1", isStrict))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDetectSQLInjection(b *testing.B) {
|
||||
@@ -37,7 +39,7 @@ func BenchmarkDetectSQLInjection(b *testing.B) {
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_ = injectionutils.DetectSQLInjection("asdf asd ; -1' and 1=1 union/* foo */select load_file('/etc/passwd')--")
|
||||
_ = injectionutils.DetectSQLInjection("asdf asd ; -1' and 1=1 union/* foo */select load_file('/etc/passwd')--", false)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -47,7 +49,7 @@ func BenchmarkDetectSQLInjection_URL(b *testing.B) {
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_ = injectionutils.DetectSQLInjection("/sql/injection?id=123 or 1=1")
|
||||
_ = injectionutils.DetectSQLInjection("/sql/injection?id=123 or 1=1", false)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -57,7 +59,7 @@ func BenchmarkDetectSQLInjection_Normal_Small(b *testing.B) {
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_ = injectionutils.DetectSQLInjection("a/sql/injection?id=1234")
|
||||
_ = injectionutils.DetectSQLInjection("a/sql/injection?id=1234", false)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -67,7 +69,7 @@ func BenchmarkDetectSQLInjection_URL_Normal_Small(b *testing.B) {
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_ = injectionutils.DetectSQLInjection("/sql/injection?id=" + types.String(rands.Int64()%10000))
|
||||
_ = injectionutils.DetectSQLInjection("/sql/injection?id="+types.String(rands.Int64()%10000), false)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -77,7 +79,7 @@ func BenchmarkDetectSQLInjection_URL_Normal_Middle(b *testing.B) {
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_ = injectionutils.DetectSQLInjection("/search?q=libinjection+fingerprint&newwindow=1&sca_esv=589290862&sxsrf=AMwHvKnxuLoejn2XlNniffC12E_xc35M7Q%3A1702090118361&ei=htvzzebfFZfo1e8PvLGggAk&ved=0ahUKEwjTsYmnq4GDAxUWdPOHHbwkCJAQ4ddDCBA&uact=5&oq=libinjection+fingerprint&gs_lp=Egxnd3Mtd2l6LXNlcnAiGIxpYmluamVjdGlvbmBmaW5nKXJwcmludTIEEAAYHjIGVAAYCBgeSiEaUPkRWKFZcAJ4AZABAJgBHgGgAfoEqgwDMC40uAEGyAEA-AEBwgIKEAFYTxjWMuiwA-IDBBgAVteIBgGQBgI&sclient=gws-wiz-serp#ip=1")
|
||||
_ = injectionutils.DetectSQLInjection("/search?q=libinjection+fingerprint&newwindow=1&sca_esv=589290862&sxsrf=AMwHvKnxuLoejn2XlNniffC12E_xc35M7Q%3A1702090118361&ei=htvzzebfFZfo1e8PvLGggAk&ved=0ahUKEwjTsYmnq4GDAxUWdPOHHbwkCJAQ4ddDCBA&uact=5&oq=libinjection+fingerprint&gs_lp=Egxnd3Mtd2l6LXNlcnAiGIxpYmluamVjdGlvbmBmaW5nKXJwcmludTIEEAAYHjIGVAAYCBgeSiEaUPkRWKFZcAJ4AZABAJgBHgGgAfoEqgwDMC40uAEGyAEA-AEBwgIKEAFYTxjWMuiwA-IDBBgAVteIBgGQBgI&sclient=gws-wiz-serp#ip=1", false)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -87,7 +89,7 @@ func BenchmarkDetectSQLInjection_URL_Normal_Small_Cache(b *testing.B) {
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_ = injectionutils.DetectSQLInjectionCache("/sql/injection?id="+types.String(rands.Int64()%10000), utils.CacheMiddleLife)
|
||||
_ = injectionutils.DetectSQLInjectionCache("/sql/injection?id="+types.String(rands.Int64()%10000), false, utils.CacheMiddleLife)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -100,7 +102,7 @@ func BenchmarkDetectSQLInjection_Normal_Large(b *testing.B) {
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_ = injectionutils.DetectSQLInjection("a/sql/injection?id=" + types.String(rands.Int64()%10000) + "&s=" + s + "&v=%20")
|
||||
_ = injectionutils.DetectSQLInjection("a/sql/injection?id="+types.String(rands.Int64()%10000)+"&s="+s+"&v=%20", false)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -112,7 +114,7 @@ func BenchmarkDetectSQLInjection_Normal_Large_Cache(b *testing.B) {
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_ = injectionutils.DetectSQLInjectionCache("a/sql/injection?id="+types.String(rands.Int64()%10000)+"&s="+s, utils.CacheMiddleLife)
|
||||
_ = injectionutils.DetectSQLInjectionCache("a/sql/injection?id="+types.String(rands.Int64()%10000)+"&s="+s, false, utils.CacheMiddleLife)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -122,7 +124,7 @@ func BenchmarkDetectSQLInjection_URL_Unescape(b *testing.B) {
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_ = injectionutils.DetectSQLInjection("/sql/injection?id=123%20or%201=1")
|
||||
_ = injectionutils.DetectSQLInjection("/sql/injection?id=123%20or%201=1", false)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
package injectionutils
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -I./libinjection/src
|
||||
#cgo CFLAGS: -O2 -I./libinjection/src
|
||||
|
||||
#include <libinjection.h>
|
||||
#include <stdlib.h>
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func DetectXSSCache(input string, cacheLife utils.CacheLife) bool {
|
||||
func DetectXSSCache(input string, isStrict bool, cacheLife utils.CacheLife) bool {
|
||||
var l = len(input)
|
||||
|
||||
if l == 0 {
|
||||
@@ -27,17 +27,20 @@ func DetectXSSCache(input string, cacheLife utils.CacheLife) bool {
|
||||
}
|
||||
|
||||
if cacheLife <= 0 || l < 512 || l > utils.MaxCacheDataSize {
|
||||
return DetectXSS(input)
|
||||
return DetectXSS(input, isStrict)
|
||||
}
|
||||
|
||||
var hash = xxhash.Sum64String(input)
|
||||
var key = "WAF@XSS@" + strconv.FormatUint(hash, 10)
|
||||
if isStrict {
|
||||
key += "@1"
|
||||
}
|
||||
var item = utils.SharedCache.Read(key)
|
||||
if item != nil {
|
||||
return item.Value == 1
|
||||
}
|
||||
|
||||
var result = DetectXSS(input)
|
||||
var result = DetectXSS(input, isStrict)
|
||||
if result {
|
||||
utils.SharedCache.Write(key, 1, fasttime.Now().Unix()+cacheLife)
|
||||
} else {
|
||||
@@ -47,12 +50,12 @@ func DetectXSSCache(input string, cacheLife utils.CacheLife) bool {
|
||||
}
|
||||
|
||||
// DetectXSS detect XSS in string
|
||||
func DetectXSS(input string) bool {
|
||||
func DetectXSS(input string, isStrict bool) bool {
|
||||
if len(input) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if detectXSSOne(input) {
|
||||
if detectXSSOne(input, isStrict) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -63,22 +66,22 @@ func DetectXSS(input string) bool {
|
||||
var args = input[argsIndex+1:]
|
||||
unescapeArgs, err := url.QueryUnescape(args)
|
||||
if err == nil && args != unescapeArgs {
|
||||
return detectXSSOne(args) || detectXSSOne(unescapeArgs)
|
||||
return detectXSSOne(args, isStrict) || detectXSSOne(unescapeArgs, isStrict)
|
||||
} else {
|
||||
return detectXSSOne(args)
|
||||
return detectXSSOne(args, isStrict)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
unescapedInput, err := url.QueryUnescape(input)
|
||||
if err == nil && input != unescapedInput {
|
||||
return detectXSSOne(unescapedInput)
|
||||
return detectXSSOne(unescapedInput, isStrict)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func detectXSSOne(input string) bool {
|
||||
func detectXSSOne(input string, isStrict bool) bool {
|
||||
if len(input) == 0 {
|
||||
return false
|
||||
}
|
||||
@@ -86,5 +89,9 @@ func detectXSSOne(input string) bool {
|
||||
var cInput = C.CString(input)
|
||||
defer C.free(unsafe.Pointer(cInput))
|
||||
|
||||
return C.libinjection_xss(cInput, C.size_t(len(input))) == 1
|
||||
var isStrictInt = 0
|
||||
if isStrict {
|
||||
isStrictInt = 1
|
||||
}
|
||||
return C.libinjection_xss(cInput, C.size_t(len(input)), C.int(isStrictInt)) == 1
|
||||
}
|
||||
|
||||
@@ -12,18 +12,18 @@ import (
|
||||
|
||||
func TestDetectXSS(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
a.IsFalse(injectionutils.DetectXSS(""))
|
||||
a.IsFalse(injectionutils.DetectXSS("abc"))
|
||||
a.IsTrue(injectionutils.DetectXSS("<script>"))
|
||||
a.IsTrue(injectionutils.DetectXSS("<link>"))
|
||||
a.IsFalse(injectionutils.DetectXSS("<html><span>"))
|
||||
a.IsFalse(injectionutils.DetectXSS("<script>"))
|
||||
a.IsTrue(injectionutils.DetectXSS("/path?onmousedown=a"))
|
||||
a.IsTrue(injectionutils.DetectXSS("/path?onkeyup=a"))
|
||||
a.IsTrue(injectionutils.DetectXSS("onkeyup=a"))
|
||||
a.IsTrue(injectionutils.DetectXSS("<iframe scrolling='no'>"))
|
||||
a.IsFalse(injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>"))
|
||||
a.IsTrue(injectionutils.DetectXSS("name=s&description=%3Cscript+src%3D%22a.js%22%3Edddd%3C%2Fscript%3E"))
|
||||
a.IsFalse(injectionutils.DetectXSS("", true))
|
||||
a.IsFalse(injectionutils.DetectXSS("abc", true))
|
||||
a.IsTrue(injectionutils.DetectXSS("<script>", true))
|
||||
a.IsTrue(injectionutils.DetectXSS("<link>", true))
|
||||
a.IsFalse(injectionutils.DetectXSS("<html><span>", true))
|
||||
a.IsFalse(injectionutils.DetectXSS("<script>", true))
|
||||
a.IsTrue(injectionutils.DetectXSS("/path?onmousedown=a", true))
|
||||
a.IsTrue(injectionutils.DetectXSS("/path?onkeyup=a", true))
|
||||
a.IsTrue(injectionutils.DetectXSS("onkeyup=a", true))
|
||||
a.IsTrue(injectionutils.DetectXSS("<iframe scrolling='no'>", true))
|
||||
a.IsFalse(injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>", true))
|
||||
a.IsTrue(injectionutils.DetectXSS("name=s&description=%3Cscript+src%3D%22a.js%22%3Edddd%3C%2Fscript%3E", true))
|
||||
a.IsFalse(injectionutils.DetectXSS(`<x:xmpmeta xmlns:x="adobe:ns:meta/" x:xmptk="XMP Core 6.0.0">
|
||||
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
|
||||
<rdf:Description rdf:about=""
|
||||
@@ -31,11 +31,25 @@ func TestDetectXSS(t *testing.T) {
|
||||
<tiff:Orientation>1</tiff:Orientation>
|
||||
</rdf:Description>
|
||||
</rdf:RDF>
|
||||
</x:xmpmeta>`)) // included in some photo files
|
||||
</x:xmpmeta>`, true)) // included in some photo files
|
||||
a.IsFalse(injectionutils.DetectXSS(`<xml></xml>`, false))
|
||||
}
|
||||
|
||||
func TestDetectXSS_Strict(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
a.IsFalse(injectionutils.DetectXSS(`<xml></xml>`, false))
|
||||
a.IsTrue(injectionutils.DetectXSS(`<xml></xml>`, true))
|
||||
a.IsFalse(injectionutils.DetectXSS(`<img src=\"\"/>`, false))
|
||||
a.IsFalse(injectionutils.DetectXSS(`<img src=\"test.jpg\"/>`, true))
|
||||
a.IsFalse(injectionutils.DetectXSS(`<a href="aaaa"></a>`, true))
|
||||
a.IsFalse(injectionutils.DetectXSS(`<span style="color: red"></span>`, false))
|
||||
a.IsTrue(injectionutils.DetectXSS(`<span style="color: red"></span>`, true))
|
||||
a.IsFalse(injectionutils.DetectXSS("https://example.com?style=list", false))
|
||||
a.IsTrue(injectionutils.DetectXSS("https://example.com?style=list", true))
|
||||
}
|
||||
|
||||
func BenchmarkDetectXSS_MISS(b *testing.B) {
|
||||
var result = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>")
|
||||
var result = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>", false)
|
||||
if result {
|
||||
b.Fatal("'result' should not be 'true'")
|
||||
}
|
||||
@@ -44,13 +58,13 @@ func BenchmarkDetectXSS_MISS(b *testing.B) {
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_ = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>")
|
||||
_ = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>", false)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkDetectXSS_MISS_Cache(b *testing.B) {
|
||||
var result = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>")
|
||||
var result = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>", false)
|
||||
if result {
|
||||
b.Fatal("'result' should not be 'true'")
|
||||
}
|
||||
@@ -59,13 +73,13 @@ func BenchmarkDetectXSS_MISS_Cache(b *testing.B) {
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_ = injectionutils.DetectXSSCache("<html><body><span>RequestId: 1234567890</span></body></html>", utils.CacheMiddleLife)
|
||||
_ = injectionutils.DetectXSSCache("<html><body><span>RequestId: 1234567890</span></body></html>", false, utils.CacheMiddleLife)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkDetectXSS_HIT(b *testing.B) {
|
||||
var result = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span><script src=\"\"></script></body></html>")
|
||||
var result = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span><script src=\"\"></script></body></html>", false)
|
||||
if !result {
|
||||
b.Fatal("'result' should not be 'false'")
|
||||
}
|
||||
@@ -74,7 +88,7 @@ func BenchmarkDetectXSS_HIT(b *testing.B) {
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_ = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span><script src=\"\"></script></body></html>")
|
||||
_ = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span><script src=\"\"></script></body></html>", false)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -45,8 +45,10 @@ func init() {
|
||||
|
||||
// load
|
||||
go func() {
|
||||
_ = SharedIPWhiteList.Load(cacheFile)
|
||||
_ = os.Remove(cacheFile)
|
||||
if !Tea.IsTesting() {
|
||||
_ = SharedIPWhiteList.Load(cacheFile)
|
||||
_ = os.Remove(cacheFile)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ package waf_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
@@ -32,6 +33,10 @@ func TestNewIPList(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIPList_Expire(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
var list = waf.NewIPList(waf.IPListTypeDeny)
|
||||
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix())
|
||||
list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1)
|
||||
|
||||
22
internal/waf/results.go
Normal file
22
internal/waf/results.go
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package waf
|
||||
|
||||
// PerformResult action performing result
|
||||
type PerformResult struct {
|
||||
ContinueRequest bool
|
||||
GoNextGroup bool
|
||||
GoNextSet bool
|
||||
IsAllowed bool
|
||||
AllowScope AllowScope
|
||||
}
|
||||
|
||||
// MatchResult request match result
|
||||
type MatchResult struct {
|
||||
GoNext bool
|
||||
HasRequestBody bool
|
||||
Group *RuleGroup
|
||||
Set *RuleSet
|
||||
IsAllowed bool
|
||||
AllowScope AllowScope
|
||||
}
|
||||
@@ -578,49 +578,51 @@ func (this *Rule) Test(value any) bool {
|
||||
return runes.ContainsAllWords(this.stringifyValue(value), this.stringValues, this.IsCaseInsensitive)
|
||||
case RuleOperatorNotContainsAnyWord:
|
||||
return !runes.ContainsAnyWordRunes(this.stringifyValue(value), this.stringValueRunes, this.IsCaseInsensitive)
|
||||
case RuleOperatorContainsSQLInjection:
|
||||
case RuleOperatorContainsSQLInjection, RuleOperatorContainsSQLInjectionStrictly:
|
||||
if value == nil {
|
||||
return false
|
||||
}
|
||||
var isStrict = this.Operator == RuleOperatorContainsSQLInjectionStrictly
|
||||
switch xValue := value.(type) {
|
||||
case []string:
|
||||
for _, v := range xValue {
|
||||
if injectionutils.DetectSQLInjectionCache(v, this.cacheLife) {
|
||||
if injectionutils.DetectSQLInjectionCache(v, isStrict, this.cacheLife) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
case [][]byte:
|
||||
for _, v := range xValue {
|
||||
if injectionutils.DetectSQLInjectionCache(string(v), this.cacheLife) {
|
||||
if injectionutils.DetectSQLInjectionCache(string(v), isStrict, this.cacheLife) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
default:
|
||||
return injectionutils.DetectSQLInjectionCache(this.stringifyValue(value), this.cacheLife)
|
||||
return injectionutils.DetectSQLInjectionCache(this.stringifyValue(value), isStrict, this.cacheLife)
|
||||
}
|
||||
case RuleOperatorContainsXSS:
|
||||
case RuleOperatorContainsXSS, RuleOperatorContainsXSSStrictly:
|
||||
if value == nil {
|
||||
return false
|
||||
}
|
||||
var isStrict = this.Operator == RuleOperatorContainsXSSStrictly
|
||||
switch xValue := value.(type) {
|
||||
case []string:
|
||||
for _, v := range xValue {
|
||||
if injectionutils.DetectXSSCache(v, this.cacheLife) {
|
||||
if injectionutils.DetectXSSCache(v, isStrict, this.cacheLife) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
case [][]byte:
|
||||
for _, v := range xValue {
|
||||
if injectionutils.DetectXSSCache(string(v), this.cacheLife) {
|
||||
if injectionutils.DetectXSSCache(string(v), isStrict, this.cacheLife) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
default:
|
||||
return injectionutils.DetectXSSCache(this.stringifyValue(value), this.cacheLife)
|
||||
return injectionutils.DetectXSSCache(this.stringifyValue(value), isStrict, this.cacheLife)
|
||||
}
|
||||
case RuleOperatorContainsBinary:
|
||||
data, _ := base64.StdEncoding.DecodeString(this.stringifyValue(this.Value))
|
||||
|
||||
@@ -4,34 +4,36 @@ type RuleOperator = string
|
||||
type RuleCaseInsensitive = string
|
||||
|
||||
const (
|
||||
RuleOperatorGt RuleOperator = "gt"
|
||||
RuleOperatorGte RuleOperator = "gte"
|
||||
RuleOperatorLt RuleOperator = "lt"
|
||||
RuleOperatorLte RuleOperator = "lte"
|
||||
RuleOperatorEq RuleOperator = "eq"
|
||||
RuleOperatorNeq RuleOperator = "neq"
|
||||
RuleOperatorEqString RuleOperator = "eq string"
|
||||
RuleOperatorNeqString RuleOperator = "neq string"
|
||||
RuleOperatorMatch RuleOperator = "match"
|
||||
RuleOperatorNotMatch RuleOperator = "not match"
|
||||
RuleOperatorWildcardMatch RuleOperator = "wildcard match"
|
||||
RuleOperatorWildcardNotMatch RuleOperator = "wildcard not match"
|
||||
RuleOperatorContains RuleOperator = "contains"
|
||||
RuleOperatorNotContains RuleOperator = "not contains"
|
||||
RuleOperatorPrefix RuleOperator = "prefix"
|
||||
RuleOperatorSuffix RuleOperator = "suffix"
|
||||
RuleOperatorContainsAny RuleOperator = "contains any"
|
||||
RuleOperatorContainsAll RuleOperator = "contains all"
|
||||
RuleOperatorContainsAnyWord RuleOperator = "contains any word"
|
||||
RuleOperatorContainsAllWords RuleOperator = "contains all words"
|
||||
RuleOperatorNotContainsAnyWord RuleOperator = "not contains any word"
|
||||
RuleOperatorContainsSQLInjection RuleOperator = "contains sql injection"
|
||||
RuleOperatorContainsXSS RuleOperator = "contains xss"
|
||||
RuleOperatorInIPList RuleOperator = "in ip list"
|
||||
RuleOperatorHasKey RuleOperator = "has key" // has key in slice or map
|
||||
RuleOperatorVersionGt RuleOperator = "version gt"
|
||||
RuleOperatorVersionLt RuleOperator = "version lt"
|
||||
RuleOperatorVersionRange RuleOperator = "version range"
|
||||
RuleOperatorGt RuleOperator = "gt"
|
||||
RuleOperatorGte RuleOperator = "gte"
|
||||
RuleOperatorLt RuleOperator = "lt"
|
||||
RuleOperatorLte RuleOperator = "lte"
|
||||
RuleOperatorEq RuleOperator = "eq"
|
||||
RuleOperatorNeq RuleOperator = "neq"
|
||||
RuleOperatorEqString RuleOperator = "eq string"
|
||||
RuleOperatorNeqString RuleOperator = "neq string"
|
||||
RuleOperatorMatch RuleOperator = "match"
|
||||
RuleOperatorNotMatch RuleOperator = "not match"
|
||||
RuleOperatorWildcardMatch RuleOperator = "wildcard match"
|
||||
RuleOperatorWildcardNotMatch RuleOperator = "wildcard not match"
|
||||
RuleOperatorContains RuleOperator = "contains"
|
||||
RuleOperatorNotContains RuleOperator = "not contains"
|
||||
RuleOperatorPrefix RuleOperator = "prefix"
|
||||
RuleOperatorSuffix RuleOperator = "suffix"
|
||||
RuleOperatorContainsAny RuleOperator = "contains any"
|
||||
RuleOperatorContainsAll RuleOperator = "contains all"
|
||||
RuleOperatorContainsAnyWord RuleOperator = "contains any word"
|
||||
RuleOperatorContainsAllWords RuleOperator = "contains all words"
|
||||
RuleOperatorNotContainsAnyWord RuleOperator = "not contains any word"
|
||||
RuleOperatorContainsSQLInjection RuleOperator = "contains sql injection"
|
||||
RuleOperatorContainsSQLInjectionStrictly RuleOperator = "contains sql injection strictly"
|
||||
RuleOperatorContainsXSS RuleOperator = "contains xss"
|
||||
RuleOperatorContainsXSSStrictly RuleOperator = "contains xss strictly"
|
||||
RuleOperatorInIPList RuleOperator = "in ip list"
|
||||
RuleOperatorHasKey RuleOperator = "has key" // has key in slice or map
|
||||
RuleOperatorVersionGt RuleOperator = "version gt"
|
||||
RuleOperatorVersionLt RuleOperator = "version lt"
|
||||
RuleOperatorVersionRange RuleOperator = "version range"
|
||||
|
||||
RuleOperatorContainsBinary RuleOperator = "contains binary" // contains binary
|
||||
RuleOperatorNotContainsBinary RuleOperator = "not contains binary" // not contains binary
|
||||
|
||||
@@ -34,6 +34,9 @@ type RuleSet struct {
|
||||
actionCodes []string
|
||||
actionInstances []ActionInterface
|
||||
|
||||
hasAllowActions bool
|
||||
allowScope string
|
||||
|
||||
hasRules bool
|
||||
}
|
||||
|
||||
@@ -62,6 +65,12 @@ func (this *RuleSet) Init(waf *WAF) error {
|
||||
// action codes
|
||||
var actionCodes = []string{}
|
||||
for _, action := range this.Actions {
|
||||
if action.Code == ActionAllow {
|
||||
this.hasAllowActions = true
|
||||
if action.Options != nil {
|
||||
this.allowScope = action.Options.GetString("scope")
|
||||
}
|
||||
}
|
||||
if !lists.ContainsString(actionCodes, action.Code) {
|
||||
actionCodes = append(actionCodes, action.Code)
|
||||
}
|
||||
@@ -141,19 +150,37 @@ func (this *RuleSet) ActionCodes() []string {
|
||||
return this.actionCodes
|
||||
}
|
||||
|
||||
func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
|
||||
func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Request, writer http.ResponseWriter) PerformResult {
|
||||
if len(waf.Mode) != 0 && waf.Mode != firewallconfigs.FirewallModeDefend {
|
||||
return true, false
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
}
|
||||
}
|
||||
|
||||
var isAllowed = this.hasAllowActions
|
||||
var allowScope = this.allowScope
|
||||
var continueRequest bool
|
||||
var goNextGroup bool
|
||||
var goNextSet bool
|
||||
|
||||
// 先执行allow
|
||||
for _, instance := range this.actionInstances {
|
||||
if !instance.WillChange() {
|
||||
continueRequest = req.WAFOnAction(instance)
|
||||
if !continueRequest {
|
||||
return false, false
|
||||
return PerformResult{
|
||||
IsAllowed: isAllowed,
|
||||
AllowScope: allowScope,
|
||||
}
|
||||
}
|
||||
var performResult = instance.Perform(waf, group, this, req, writer)
|
||||
continueRequest = performResult.ContinueRequest
|
||||
goNextSet = performResult.GoNextSet
|
||||
if performResult.IsAllowed {
|
||||
isAllowed = true
|
||||
allowScope = performResult.AllowScope
|
||||
goNextGroup = performResult.GoNextGroup
|
||||
}
|
||||
_, goNextSet = instance.Perform(waf, group, this, req, writer)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,13 +190,36 @@ func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Req
|
||||
if instance.WillChange() {
|
||||
continueRequest = req.WAFOnAction(instance)
|
||||
if !continueRequest {
|
||||
return false, false
|
||||
return PerformResult{
|
||||
IsAllowed: isAllowed,
|
||||
AllowScope: allowScope,
|
||||
}
|
||||
}
|
||||
var performResult = instance.Perform(waf, group, this, req, writer)
|
||||
continueRequest = performResult.ContinueRequest
|
||||
goNextSet = performResult.GoNextSet
|
||||
if performResult.IsAllowed {
|
||||
isAllowed = true
|
||||
allowScope = performResult.AllowScope
|
||||
goNextGroup = performResult.GoNextGroup
|
||||
}
|
||||
return PerformResult{
|
||||
ContinueRequest: performResult.ContinueRequest,
|
||||
GoNextGroup: goNextGroup,
|
||||
GoNextSet: performResult.GoNextSet,
|
||||
IsAllowed: isAllowed,
|
||||
AllowScope: allowScope,
|
||||
}
|
||||
return instance.Perform(waf, group, this, req, writer)
|
||||
}
|
||||
}
|
||||
|
||||
return true, goNextSet
|
||||
return PerformResult{
|
||||
ContinueRequest: true,
|
||||
GoNextGroup: goNextGroup,
|
||||
GoNextSet: goNextSet,
|
||||
IsAllowed: isAllowed,
|
||||
AllowScope: allowScope,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *RuleSet) MatchRequest(req requests.Request) (b bool, hasRequestBody bool, err error) {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user