Compare commits

...

27 Commits

Author SHA1 Message Date
刘祥超
a8c8d80e3b 尝试根据端口号自动纠正源站地址中的scheme 2023-06-18 18:05:28 +08:00
刘祥超
c43b6b37ea 优化代码 2023-06-18 10:01:22 +08:00
刘祥超
ac2d57d2f1 同时设置Websocket允许来源域和防盗链时,以Websocket设置为优先 2023-06-16 09:56:37 +08:00
刘祥超
83ac62cda3 缓存条件增加"强制返回区间内容"选项 2023-06-15 15:14:06 +08:00
刘祥超
c0909a2cd0 部分WAF动作输出内容时增加自定义报头 2023-06-12 18:07:07 +08:00
刘祥超
a73b9f2674 版本号改为1.2.0 2023-06-12 14:43:07 +08:00
刘祥超
3e79b71afc WAF在输出内容时也加入自定义的响应报头 2023-06-11 10:46:20 +08:00
刘祥超
d3caccbb55 上传日志时检查节点ID是否为0 2023-06-10 16:47:27 +08:00
刘祥超
f95bac8d38 在Linux上不通过交叉编译器编译时,也可以支持边缘脚本(在有商业版本源码的情况下) 2023-06-10 15:16:06 +08:00
刘祥超
41d2ab728b 手动发送数据(Send()方法)时也可以使用HTTP Header策略等 2023-06-09 14:49:32 +08:00
刘祥超
b319061e85 优化OSS相关代码 2023-06-08 17:47:04 +08:00
刘祥超
99b8686a49 修复部分测试用例 2023-06-07 21:49:42 +08:00
刘祥超
fe8c5b505a 修复一处编译问题 2023-06-07 20:30:52 +08:00
刘祥超
f88d0982ed 修复User-Agent为空时,使用了默认的Go-http-client/1.1的问题 2023-06-07 20:17:07 +08:00
刘祥超
a9389d53e1 优化代码 2023-06-07 19:30:51 +08:00
刘祥超
fc4b45fec7 HTTP服务反向代理时只把HTTP(S)源站加入到状态管理中 2023-06-07 19:28:16 +08:00
刘祥超
9b22e6cf69 初步实现对象存储源站 2023-06-07 17:27:55 +08:00
刘祥超
7bd7f7da45 允许在集群设置 -- “网站设置” 中设置节点IP访问显示的内容 2023-06-05 19:28:01 +08:00
刘祥超
b68e6517df 网站全局设置增加“强制Ln请求“选项 2023-06-05 17:06:03 +08:00
刘祥超
bbae229d08 优化Ln连接性能 2023-06-05 16:38:29 +08:00
刘祥超
c73a6cbfe8 节点监控数据增加UDP数据报速率 2023-06-04 09:58:43 +08:00
刘祥超
e869e8e4d6 缓存写入Header时忽略Strict-Transport-Security和Alt-Svc 2023-06-02 15:23:54 +08:00
刘祥超
bde4e8507f 优化代码 2023-06-02 14:23:54 +08:00
刘祥超
5ae25cffa0 连接列表增加udp支持 2023-06-02 10:54:17 +08:00
刘祥超
44b721d0d3 优化代码 2023-06-01 19:40:15 +08:00
刘祥超
a2d6b7e0a8 初步实现HTTP3 2023-06-01 17:49:06 +08:00
刘祥超
95d65481e3 优化代码 2023-05-29 20:39:08 +08:00
60 changed files with 856 additions and 380 deletions

1
.gitignore vendored
View File

@@ -1,2 +1,3 @@
*_plus.go
*_plus_test.go
*-plus.sh

View File

@@ -55,7 +55,7 @@ function build() {
cp -R "$ROOT"/pages "$DIST"/
# we support TOA on linux/amd64 only
if [ "$OS" == "linux" -a "$ARCH" == "amd64" ]
if [ "$OS" == "linux" ] && [ "$ARCH" == "amd64" ]
then
cp -R "$ROOT"/edge-toa "$DIST"
fi
@@ -114,7 +114,10 @@ function build() {
if [ ! -z $CC_PATH ]; then
env CC=$MUSL_DIR/$CC_PATH CXX=$MUSL_DIR/$CXX_PATH GOOS="${OS}" GOARCH="${ARCH}" CGO_ENABLED=1 go build -trimpath -tags $BUILD_TAG -o "$DIST"/bin/${NAME} -ldflags "-linkmode external -extldflags -static -s -w" "$ROOT"/../cmd/edge-node/main.go
else
env GOOS="${OS}" GOARCH="${ARCH}" CGO_ENABLED=1 go build -trimpath -tags $TAG -o "$DIST"/bin/${NAME} -ldflags="-s -w" "$ROOT"/../cmd/edge-node/main.go
if [[ `uname` == *"Linux"* ]] && [ "$OS" == "linux" ] && [[ "$ARCH" == "amd64" || "$ARCH" == "arm64" ]] && [ "$TAG" == "plus" ]; then
BUILD_TAG="plus,script"
fi
env GOOS="${OS}" GOARCH="${ARCH}" CGO_ENABLED=1 go build -trimpath -tags $BUILD_TAG -o "$DIST"/bin/${NAME} -ldflags="-s -w" "$ROOT"/../cmd/edge-node/main.go
fi
# delete hidden files

9
build/test.sh Executable file
View File

@@ -0,0 +1,9 @@
#!/usr/bin/env bash
TAG=${1}
if [ -z "$TAG" ]; then
TAG="community"
fi
go test -v ../... -tags=${TAG}

View File

@@ -12,6 +12,11 @@ import (
func TestFileListDB_ListLFUItems(t *testing.T) {
var db = caches.NewFileListDB()
defer func() {
_ = db.Close()
}()
err := db.Open(Tea.Root + "/data/cache-db-large.db")
//err := db.Open(Tea.Root + "/data/cache-index/p1/db-0.db")
if err != nil {
@@ -22,10 +27,6 @@ func TestFileListDB_ListLFUItems(t *testing.T) {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
hashList, err := db.ListLFUItems(100)
if err != nil {
t.Fatal(err)
@@ -35,25 +36,38 @@ func TestFileListDB_ListLFUItems(t *testing.T) {
func TestFileListDB_IncreaseHitAsync(t *testing.T) {
var db = caches.NewFileListDB()
defer func() {
_ = db.Close()
}()
err := db.Open(Tea.Root + "/data/cache-db-large.db")
if err != nil {
t.Fatal(err)
}
err = db.Init()
err = db.IncreaseHitAsync("4598e5231ba47d6ec7aa9ea640ff2eaf")
if err != nil {
t.Fatal(err)
}
// wait transaction
time.Sleep(1 * time.Second)
}
func TestFileListDB_CleanMatchKey(t *testing.T) {
var db = caches.NewFileListDB()
defer func() {
_ = db.Close()
}()
err := db.Open(Tea.Root + "/data/cache-db-large.db")
if err != nil {
t.Fatal(err)
}
err = db.Init()
err = db.CleanMatchKey("https://*.goedge.cn/large-text")
@@ -69,10 +83,16 @@ func TestFileListDB_CleanMatchKey(t *testing.T) {
func TestFileListDB_CleanMatchPrefix(t *testing.T) {
var db = caches.NewFileListDB()
defer func() {
_ = db.Close()
}()
err := db.Open(Tea.Root + "/data/cache-db-large.db")
if err != nil {
t.Fatal(err)
}
err = db.Init()
err = db.CleanMatchPrefix("https://*.goedge.cn/large-text")

View File

@@ -59,15 +59,16 @@ func TestFileListHashMap_BigInt(t *testing.T) {
func TestFileListHashMap_Load(t *testing.T) {
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1").(*caches.FileList)
err := list.Init()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = list.Close()
}()
err := list.Init()
if err != nil {
t.Fatal(err)
}
var m = caches.NewFileListHashMap()
var before = time.Now()
var db = list.GetDB("abc")

View File

@@ -5,6 +5,7 @@ package caches_test
import (
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
@@ -17,6 +18,11 @@ import (
func TestFileList_Init(t *testing.T) {
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
defer func() {
_ = list.Close()
}()
err := list.Init()
if err != nil {
t.Fatal(err)
@@ -29,6 +35,11 @@ func TestFileList_Init(t *testing.T) {
func TestFileList_Add(t *testing.T) {
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1").(*caches.FileList)
defer func() {
_ = list.Close()
}()
err := list.Init()
if err != nil {
t.Fatal(err)
@@ -59,16 +70,21 @@ func TestFileList_Add(t *testing.T) {
}
func TestFileList_Add_Many(t *testing.T) {
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
err := list.Init()
if err != nil {
t.Fatal(err)
if !testutils.IsSingleTesting() {
return
}
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
defer func() {
_ = list.Close()
}()
err := list.Init()
if err != nil {
t.Fatal(err)
}
var before = time.Now()
for i := 0; i < 10_000_000; i++ {
u := "https://edge.teaos.cn/123456" + strconv.Itoa(i)
@@ -92,15 +108,15 @@ func TestFileList_Add_Many(t *testing.T) {
func TestFileList_Exist(t *testing.T) {
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1").(*caches.FileList)
defer func() {
_ = list.Close()
}()
err := list.Init()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = list.Close()
}()
total, _ := list.Count()
t.Log("total:", total)
@@ -130,7 +146,7 @@ func TestFileList_Exist_Many_DB(t *testing.T) {
// 测试在多个数据库下的性能
var listSlice = []caches.ListInterface{}
for i := 1; i <= 10; i++ {
list := caches.NewFileList(Tea.Root + "/data/data" + strconv.Itoa(i))
var list = caches.NewFileList(Tea.Root + "/data/data" + strconv.Itoa(i))
err := list.Init()
if err != nil {
t.Fatal(err)
@@ -138,6 +154,12 @@ func TestFileList_Exist_Many_DB(t *testing.T) {
listSlice = append(listSlice, list)
}
defer func() {
for _, list := range listSlice {
_ = list.Close()
}
}()
var wg = sync.WaitGroup{}
var threads = 8
wg.Add(threads)
@@ -181,15 +203,16 @@ func TestFileList_Exist_Many_DB(t *testing.T) {
func TestFileList_CleanPrefix(t *testing.T) {
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
err := list.Init()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = list.Close()
}()
err := list.Init()
if err != nil {
t.Fatal(err)
}
before := time.Now()
err = list.CleanPrefix("123")
if err != nil {
@@ -200,15 +223,15 @@ func TestFileList_CleanPrefix(t *testing.T) {
func TestFileList_Remove(t *testing.T) {
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1").(*caches.FileList)
defer func() {
_ = list.Close()
}()
err := list.Init()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = list.Close()
}()
list.OnRemove(func(item *caches.Item) {
t.Logf("remove %#v", item)
})
@@ -224,13 +247,15 @@ func TestFileList_Remove(t *testing.T) {
func TestFileList_Purge(t *testing.T) {
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
defer func() {
_ = list.Close()
}()
err := list.Init()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = list.Close()
}()
var count = 0
_, err = list.Purge(caches.CountFileDB*2, func(hash string) error {
@@ -246,13 +271,15 @@ func TestFileList_Purge(t *testing.T) {
func TestFileList_PurgeLFU(t *testing.T) {
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
defer func() {
_ = list.Close()
}()
err := list.Init()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = list.Close()
}()
err = list.IncreaseHit(stringutil.Md5("123456"))
if err != nil {
@@ -273,15 +300,16 @@ func TestFileList_PurgeLFU(t *testing.T) {
func TestFileList_Stat(t *testing.T) {
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
err := list.Init()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = list.Close()
}()
err := list.Init()
if err != nil {
t.Fatal(err)
}
stat, err := list.Stat(nil)
if err != nil {
t.Fatal(err)
@@ -291,6 +319,11 @@ func TestFileList_Stat(t *testing.T) {
func TestFileList_Count(t *testing.T) {
var list = caches.NewFileList(Tea.Root + "/data")
defer func() {
_ = list.Close()
}()
err := list.Init()
if err != nil {
t.Fatal(err)
@@ -305,7 +338,12 @@ func TestFileList_Count(t *testing.T) {
}
func TestFileList_CleanAll(t *testing.T) {
list := caches.NewFileList(Tea.Root + "/data")
var list = caches.NewFileList(Tea.Root + "/data")
defer func() {
_ = list.Close()
}()
err := list.Init()
if err != nil {
t.Fatal(err)
@@ -320,6 +358,11 @@ func TestFileList_CleanAll(t *testing.T) {
func TestFileList_IncreaseHit(t *testing.T) {
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
defer func() {
_ = list.Close()
}()
err := list.Init()
if err != nil {
t.Fatal(err)
@@ -333,7 +376,13 @@ func TestFileList_IncreaseHit(t *testing.T) {
defer func() {
t.Log(time.Since(before).Seconds()*1000, "ms")
}()
for i := 0; i < 1000_000; i++ {
var count = 1_000_000
if !testutils.IsSingleTesting() {
count = 10
}
for i := 0; i < count; i++ {
err = list.IncreaseHit(stringutil.Md5("abc" + types.String(i)))
}
if err != nil {
@@ -344,6 +393,11 @@ func TestFileList_IncreaseHit(t *testing.T) {
func TestFileList_UpgradeV3(t *testing.T) {
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p43").(*caches.FileList)
defer func() {
_ = list.Close()
}()
err := list.Init()
if err != nil {
t.Fatal(err)
@@ -363,6 +417,11 @@ func TestFileList_UpgradeV3(t *testing.T) {
func BenchmarkFileList_Exist(b *testing.B) {
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
defer func() {
_ = list.Close()
}()
err := list.Init()
if err != nil {
b.Fatal(err)

View File

@@ -2,6 +2,7 @@ package caches
import (
"fmt"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/cespare/xxhash"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/rands"
@@ -107,7 +108,9 @@ func TestMemoryList_Purge_Large_List(t *testing.T) {
})
}
time.Sleep(1 * time.Hour)
if testutils.IsSingleTesting() {
time.Sleep(1 * time.Hour)
}
}
func TestMemoryList_Stat(t *testing.T) {
@@ -255,9 +258,11 @@ func TestMemoryList_GC(t *testing.T) {
//runtime.GC()
t.Log("gc cost:", time.Since(before).Seconds()*1000, "ms")
timeout := time.NewTimer(2 * time.Minute)
<-timeout.C
t.Log("2 minutes passed")
if testutils.IsSingleTesting() {
timeout := time.NewTimer(2 * time.Minute)
<-timeout.C
t.Log("2 minutes passed")
time.Sleep(30 * time.Minute)
time.Sleep(30 * time.Minute)
}
}

View File

@@ -4,6 +4,7 @@ package caches_test
import (
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"testing"
"time"
)
@@ -23,7 +24,9 @@ func TestNewOpenFileCache_Close(t *testing.T) {
cache.Get("d.txt")
cache.Close("a.txt")
time.Sleep(100 * time.Second)
if testutils.IsSingleTesting() {
time.Sleep(100 * time.Second)
}
}
func TestNewOpenFileCache_CloseAll(t *testing.T) {

View File

@@ -8,21 +8,29 @@ import (
)
func TestFileReader(t *testing.T) {
storage := NewFileStorage(&serverconfigs.HTTPCachePolicy{
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
Options: map[string]interface{}{
"dir": Tea.Root + "/caches",
},
})
defer storage.Stop()
err := storage.Init()
if err != nil {
t.Fatal(err)
}
_, path, _ := storage.keyPath("my-key")
fp, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
t.Log("file '" + path + "' not exists")
return
}
t.Fatal(err)
}
defer func() {
@@ -58,6 +66,10 @@ func TestFileReader_ReadHeader(t *testing.T) {
var path = "/Users/WorkSpace/EdgeProject/EdgeCache/p43/12/6b/126bbed90fc80f2bdfb19558948b0d49.cache"
fp, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
t.Log("'" + path + "' not exists")
return
}
t.Fatal(err)
}
defer func() {
@@ -66,6 +78,11 @@ func TestFileReader_ReadHeader(t *testing.T) {
var reader = NewFileReader(fp)
err = reader.Init()
if err != nil {
if os.IsNotExist(err) {
t.Log("file '" + path + "' not exists")
return
}
t.Fatal(err)
}
var buf = make([]byte, 16*1024)
@@ -79,13 +96,16 @@ func TestFileReader_ReadHeader(t *testing.T) {
}
func TestFileReader_Range(t *testing.T) {
storage := NewFileStorage(&serverconfigs.HTTPCachePolicy{
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
Options: map[string]interface{}{
"dir": Tea.Root + "/caches",
},
})
defer storage.Stop()
err := storage.Init()
if err != nil {
t.Fatal(err)
@@ -109,6 +129,10 @@ func TestFileReader_Range(t *testing.T) {
fp, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
t.Log("'" + path + "' not exists")
return
}
t.Fatal(err)
}
defer func() {

View File

@@ -899,7 +899,10 @@ func (this *FileStorage) Stop() {
memoryStorage.Stop()
})
_ = this.list.Reset()
if this.list != nil {
_ = this.list.Reset()
}
if this.purgeTicker != nil {
this.purgeTicker.Stop()
}
@@ -907,7 +910,9 @@ func (this *FileStorage) Stop() {
this.hotTicker.Stop()
}
_ = this.list.Close()
if this.list != nil {
_ = this.list.Close()
}
var openFileCache = this.openFileCache
if openFileCache != nil {

View File

@@ -18,7 +18,7 @@ import (
)
func TestFileStorage_Init(t *testing.T) {
storage := NewFileStorage(&serverconfigs.HTTPCachePolicy{
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
Options: map[string]interface{}{
@@ -26,6 +26,8 @@ func TestFileStorage_Init(t *testing.T) {
},
})
defer storage.Stop()
err := storage.Init()
if err != nil {
t.Fatal(err)
@@ -44,13 +46,16 @@ func TestFileStorage_Init(t *testing.T) {
}
func TestFileStorage_OpenWriter(t *testing.T) {
storage := NewFileStorage(&serverconfigs.HTTPCachePolicy{
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
Options: map[string]interface{}{
"dir": Tea.Root + "/caches",
},
})
defer storage.Stop()
err := storage.Init()
if err != nil {
t.Fatal(err)
@@ -95,6 +100,9 @@ func TestFileStorage_OpenWriter_Partial(t *testing.T) {
"dir": Tea.Root + "/caches",
},
})
defer storage.Stop()
err := storage.Init()
if err != nil {
t.Fatal(err)
@@ -123,13 +131,16 @@ func TestFileStorage_OpenWriter_Partial(t *testing.T) {
}
func TestFileStorage_OpenWriter_HTTP(t *testing.T) {
storage := NewFileStorage(&serverconfigs.HTTPCachePolicy{
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
Options: map[string]interface{}{
"dir": Tea.Root + "/caches",
},
})
defer storage.Stop()
err := storage.Init()
if err != nil {
t.Fatal(err)
@@ -188,13 +199,16 @@ func TestFileStorage_OpenWriter_HTTP(t *testing.T) {
}
func TestFileStorage_Concurrent_Open_DifferentFile(t *testing.T) {
storage := NewFileStorage(&serverconfigs.HTTPCachePolicy{
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
Options: map[string]interface{}{
"dir": Tea.Root + "/caches",
},
})
defer storage.Stop()
err := storage.Init()
if err != nil {
t.Fatal(err)
@@ -243,13 +257,16 @@ func TestFileStorage_Concurrent_Open_DifferentFile(t *testing.T) {
}
func TestFileStorage_Concurrent_Open_SameFile(t *testing.T) {
storage := NewFileStorage(&serverconfigs.HTTPCachePolicy{
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
Options: map[string]interface{}{
"dir": Tea.Root + "/caches",
},
})
defer storage.Stop()
err := storage.Init()
if err != nil {
t.Fatal(err)
@@ -299,13 +316,16 @@ func TestFileStorage_Concurrent_Open_SameFile(t *testing.T) {
}
func TestFileStorage_Read(t *testing.T) {
storage := NewFileStorage(&serverconfigs.HTTPCachePolicy{
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
Options: map[string]interface{}{
"dir": Tea.Root + "/caches",
},
})
defer storage.Stop()
err := storage.Init()
if err != nil {
t.Fatal(err)
@@ -335,13 +355,16 @@ func TestFileStorage_Read(t *testing.T) {
}
func TestFileStorage_Read_HTTP_Response(t *testing.T) {
storage := NewFileStorage(&serverconfigs.HTTPCachePolicy{
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
Options: map[string]interface{}{
"dir": Tea.Root + "/caches",
},
})
defer storage.Stop()
err := storage.Init()
if err != nil {
t.Fatal(err)
@@ -388,13 +411,16 @@ func TestFileStorage_Read_HTTP_Response(t *testing.T) {
}
func TestFileStorage_Read_NotFound(t *testing.T) {
storage := NewFileStorage(&serverconfigs.HTTPCachePolicy{
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
Options: map[string]interface{}{
"dir": Tea.Root + "/caches",
},
})
defer storage.Stop()
err := storage.Init()
if err != nil {
t.Fatal(err)
@@ -421,13 +447,16 @@ func TestFileStorage_Read_NotFound(t *testing.T) {
}
func TestFileStorage_Delete(t *testing.T) {
storage := NewFileStorage(&serverconfigs.HTTPCachePolicy{
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
Options: map[string]interface{}{
"dir": Tea.Root + "/caches",
},
})
defer storage.Stop()
err := storage.Init()
if err != nil {
t.Fatal(err)
@@ -440,13 +469,16 @@ func TestFileStorage_Delete(t *testing.T) {
}
func TestFileStorage_Stat(t *testing.T) {
storage := NewFileStorage(&serverconfigs.HTTPCachePolicy{
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
Options: map[string]interface{}{
"dir": Tea.Root + "/caches",
},
})
defer storage.Stop()
err := storage.Init()
if err != nil {
t.Fatal(err)
@@ -465,13 +497,16 @@ func TestFileStorage_Stat(t *testing.T) {
}
func TestFileStorage_CleanAll(t *testing.T) {
storage := NewFileStorage(&serverconfigs.HTTPCachePolicy{
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
Options: map[string]interface{}{
"dir": Tea.Root + "/caches",
},
})
defer storage.Stop()
err := storage.Init()
if err != nil {
t.Fatal(err)
@@ -496,13 +531,16 @@ func TestFileStorage_CleanAll(t *testing.T) {
}
func TestFileStorage_Stop(t *testing.T) {
storage := NewFileStorage(&serverconfigs.HTTPCachePolicy{
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
Options: map[string]interface{}{
"dir": Tea.Root + "/caches",
},
})
defer storage.Stop()
err := storage.Init()
if err != nil {
t.Fatal(err)
@@ -518,6 +556,9 @@ func TestFileStorage_DecodeFile(t *testing.T) {
"dir": Tea.Root + "/caches",
},
})
defer storage.Stop()
err := storage.Init()
if err != nil {
t.Fatal(err)
@@ -528,6 +569,9 @@ func TestFileStorage_DecodeFile(t *testing.T) {
func TestFileStorage_RemoveCacheFile(t *testing.T) {
var storage = NewFileStorage(nil)
defer storage.Stop()
t.Log(storage.removeCacheFile("/Users/WorkSpace/EdgeProject/EdgeCache/p43/15/7e/157eba0dfc6dfb6fbbf20b1f9e584674.cache"))
}
@@ -536,13 +580,16 @@ func BenchmarkFileStorage_Read(b *testing.B) {
_ = utils.SetRLimit(1024 * 1024)
storage := NewFileStorage(&serverconfigs.HTTPCachePolicy{
var storage = NewFileStorage(&serverconfigs.HTTPCachePolicy{
Id: 1,
IsOn: true,
Options: map[string]interface{}{
"dir": Tea.Root + "/caches",
},
})
defer storage.Stop()
err := storage.Init()
if err != nil {
b.Fatal(err)

View File

@@ -3,6 +3,7 @@
package conns
import (
"github.com/iwind/TeaGo/types"
"net"
"sync"
)
@@ -10,14 +11,14 @@ import (
var SharedMap = NewMap()
type Map struct {
m map[string]map[int]net.Conn // ip => { port => Conn }
m map[string]map[string]net.Conn // ip => { network_port => Conn }
locker sync.RWMutex
}
func NewMap() *Map {
return &Map{
m: map[string]map[int]net.Conn{},
m: map[string]map[string]net.Conn{},
}
}
@@ -25,21 +26,19 @@ func (this *Map) Add(conn net.Conn) {
if conn == nil {
return
}
tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
key, ip, ok := this.connAddr(conn)
if !ok {
return
}
var ip = tcpAddr.IP.String()
var port = tcpAddr.Port
this.locker.Lock()
defer this.locker.Unlock()
connMap, ok := this.m[ip]
if !ok {
this.m[ip] = map[int]net.Conn{port: conn}
this.m[ip] = map[string]net.Conn{key: conn}
} else {
connMap[port] = conn
connMap[key] = conn
}
}
@@ -47,14 +46,11 @@ func (this *Map) Remove(conn net.Conn) {
if conn == nil {
return
}
tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
key, ip, ok := this.connAddr(conn)
if !ok {
return
}
var ip = tcpAddr.IP.String()
var port = tcpAddr.Port
this.locker.Lock()
defer this.locker.Unlock()
@@ -62,7 +58,7 @@ func (this *Map) Remove(conn net.Conn) {
if !ok {
return
}
delete(connMap, port)
delete(connMap, key)
if len(connMap) == 0 {
delete(this.m, ip)
@@ -121,3 +117,24 @@ func (this *Map) AllConns() []net.Conn {
return result
}
func (this *Map) connAddr(conn net.Conn) (key string, ip string, ok bool) {
if conn == nil {
return
}
var addr = conn.RemoteAddr()
switch realAddr := addr.(type) {
case *net.TCPAddr:
return addr.Network() + types.String(realAddr.Port), realAddr.IP.String(), true
case *net.UDPAddr:
return addr.Network() + types.String(realAddr.Port), realAddr.IP.String(), true
default:
var s = addr.String()
host, port, err := net.SplitHostPort(s)
if err != nil {
return
}
return addr.Network() + port, host, true
}
}

View File

@@ -1,7 +1,7 @@
package teaconst
const (
Version = "1.1.0"
Version = "1.2.0"
ProductName = "Edge Node"
ProcessName = "edge-node"

View File

@@ -30,7 +30,6 @@ func TestIPListManager_check(t *testing.T) {
func TestIPListManager_loop(t *testing.T) {
manager := NewIPListManager()
manager.Start()
manager.pageSize = 10
err := manager.loop()
if err != nil {
t.Fatal(err)

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/metrics"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/rands"
"testing"
@@ -79,6 +80,10 @@ func TestTask_Add(t *testing.T) {
}
func TestTask_Add_Many(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var task = metrics.NewTask(&serverconfigs.MetricItemConfig{
Id: 1,
IsOn: false,

View File

@@ -61,11 +61,16 @@ func NewClientConn(rawConn net.Conn, isHTTP bool, isTLS bool, isInAllowList bool
isTLS: isTLS,
isHTTP: isHTTP,
isLO: strings.HasPrefix(remoteAddr, "127.0.0.1:") || strings.HasPrefix(remoteAddr, "[::1]:"),
isNoStat: connutils.IsNoStatConn(rawConn.RemoteAddr().String()),
isNoStat: connutils.IsNoStatConn(remoteAddr),
isInAllowList: isInAllowList,
createdAt: fasttime.Now().Unix(),
}
if existsLnNodeIP(conn.RawIP()) {
conn.SetIsPersistent(true)
}
// 超时等设置
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
if globalServerConfig != nil {
var performanceConfig = globalServerConfig.Performance
@@ -129,7 +134,7 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
}
// 忽略白名单和局域网
if this.isHTTP && !this.isInAllowList && !utils.IsLocalIP(this.RawIP()) {
if !this.isPersistent && this.isHTTP && !this.isInAllowList && !utils.IsLocalIP(this.RawIP()) {
// SYN Flood检测
if this.serverId == 0 || !this.hasResetSYNFlood {
var synFloodConfig = sharedNodeConfig.SYNFloodConfig()
@@ -165,8 +170,7 @@ func (this *ClientConn) Write(b []byte) (n int, err error) {
}
// 设置写超时时间
if this.autoWriteTimeout {
// TODO L2 -> L1 写入时不限制时间
if !this.isPersistent && this.autoWriteTimeout {
var timeoutSeconds = len(b) / 1024
if timeoutSeconds < 3 {
timeoutSeconds = 3

View File

@@ -54,12 +54,13 @@ func (this *BaseClientConn) SetServerId(serverId int64) (goNext bool) {
goNext = true
// 检查服务相关IP黑名单
if serverId > 0 && len(this.rawIP) > 0 {
var rawIP = this.RawIP()
if serverId > 0 && len(rawIP) > 0 {
// 是否在白名单中
ok, _, expiresAt := iplibrary.AllowIP(this.rawIP, serverId)
ok, _, expiresAt := iplibrary.AllowIP(rawIP, serverId)
if !ok {
_ = this.rawConn.Close()
firewalls.DropTemporaryTo(this.rawIP, expiresAt)
firewalls.DropTemporaryTo(rawIP, expiresAt)
return false
}
}
@@ -123,8 +124,8 @@ func (this *BaseClientConn) TCPConn() (tcpConn *net.TCPConn, ok bool) {
switch conn := this.rawConn.(type) {
case *tls.Conn:
var internalConn = conn.NetConn()
clientConn, ok := internalConn.(*ClientConn)
if ok {
clientConn, isClientConn := internalConn.(*ClientConn)
if isClientConn {
return clientConn.TCPConn()
}
tcpConn, ok = internalConn.(*net.TCPConn)

View File

@@ -46,6 +46,7 @@ type HTTPRequest struct {
ServerAddr string // 实际启动的服务器监听地址
IsHTTP bool
IsHTTPS bool
IsHTTP3 bool
// 共享参数
nodeConfig *nodeconfigs.NodeConfig
@@ -479,6 +480,17 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
// remote addr
if web.RemoteAddr != nil && (web.RemoteAddr.IsPrior || isTop) && web.RemoteAddr.IsOn {
this.web.RemoteAddr = web.RemoteAddr
// check if from proxy
if len(this.web.RemoteAddr.Value) > 0 && this.web.RemoteAddr.Value != "${rawRemoteAddr}" {
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
if requestConn != nil {
requestClientConn, ok := requestConn.(ClientConnInterface)
if ok {
requestClientConn.SetIsPersistent(true)
}
}
}
}
// charset
@@ -1698,8 +1710,8 @@ func (this *HTTPRequest) fixRequestHeader(header http.Header) {
}
}
// 处理自定义Response Header
func (this *HTTPRequest) processResponseHeaders(responseHeader http.Header, statusCode int) {
// ProcessResponseHeaders 处理自定义Response Header
func (this *HTTPRequest) ProcessResponseHeaders(responseHeader http.Header, statusCode int) {
// 删除/添加/替换Header
// TODO 实现AddTrailers
if this.web.ResponseHeaderPolicy != nil && this.web.ResponseHeaderPolicy.IsOn {
@@ -1828,6 +1840,11 @@ func (this *HTTPRequest) processResponseHeaders(responseHeader http.Header, stat
this.ReqServer.HTTPS.SSLPolicy.HSTS.Match(this.ReqHost) {
responseHeader.Set(this.ReqServer.HTTPS.SSLPolicy.HSTS.HeaderKey(), this.ReqServer.HTTPS.SSLPolicy.HSTS.HeaderValue())
}
// HTTP/3
if this.IsHTTPS && !this.IsHTTP3 && this.ReqServer.SupportsHTTP3() {
this.processHTTP3Headers(responseHeader)
}
}
// 添加错误信息
@@ -1897,7 +1914,7 @@ func (this *HTTPRequest) canIgnore(err error) bool {
// 检查连接是否已关闭
func (this *HTTPRequest) isConnClosed() bool {
requestConn := this.RawReq.Context().Value(HTTPConnContextKey)
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
if requestConn == nil {
return true
}

View File

@@ -40,7 +40,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
var addStatusHeader = this.web.Cache.AddStatusHeader
if addStatusHeader {
defer func() {
cacheStatus := this.varMapping["cache.status"]
var cacheStatus = this.varMapping["cache.status"]
if cacheStatus != "HIT" {
this.writer.Header().Set("X-Cache", cacheStatus)
}
@@ -48,7 +48,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
}
// 检查服务独立的缓存条件
refType := ""
var refType = ""
for _, cacheRef := range this.web.Cache.CacheRefs {
if !cacheRef.IsOn {
continue
@@ -131,7 +131,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
this.varMapping["cache.key"] = key
// 读取缓存
storage := caches.SharedManager.FindStorageWithPolicy(cachePolicy.Id)
var storage = caches.SharedManager.FindStorageWithPolicy(cachePolicy.Id)
if storage == nil {
this.cacheRef = nil
return
@@ -241,16 +241,19 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
reader, err = storage.OpenReader(key, useStale, false)
if err != nil && this.cacheRef.AllowPartialContent {
// 尝试读取分片的缓存内容
if len(rangeHeader) == 0 {
if len(rangeHeader) == 0 && this.cacheRef.ForcePartialContent {
// 默认读取开头
rangeHeader = "bytes=0-"
}
pReader, ranges := this.tryPartialReader(storage, key, useStale, rangeHeader)
if pReader != nil {
isPartialCache = true
reader = pReader
partialRanges = ranges
err = nil
if len(rangeHeader) > 0 {
pReader, ranges := this.tryPartialReader(storage, key, useStale, rangeHeader)
if pReader != nil {
isPartialCache = true
reader = pReader
partialRanges = ranges
err = nil
}
}
}
@@ -301,13 +304,13 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
this.writer.SetSentHeaderBytes(reader.HeaderSize())
var headerPool = this.bytePool(reader.HeaderSize())
var headerBuf = headerPool.Get()
err = reader.ReadHeader(headerBuf, func(n int) (goNext bool, err error) {
err = reader.ReadHeader(headerBuf, func(n int) (goNext bool, readErr error) {
headerData = append(headerData, headerBuf[:n]...)
for {
nIndex := bytes.Index(headerData, []byte{'\n'})
var nIndex = bytes.Index(headerData, []byte{'\n'})
if nIndex >= 0 {
row := headerData[:nIndex]
spaceIndex := bytes.Index(row, []byte{':'})
var row = headerData[:nIndex]
var spaceIndex = bytes.Index(row, []byte{':'})
if spaceIndex <= 0 {
return false, errors.New("invalid header '" + string(row) + "'")
}
@@ -375,7 +378,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
// 支持 If-None-Match
if !this.isLnRequest && !isPartialCache && len(eTag) > 0 && this.requestHeader("If-None-Match") == eTag {
// 自定义Header
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.addExpiresHeader(reader.ExpiresAt())
this.writer.WriteHeader(http.StatusNotModified)
this.isCached = true
@@ -387,7 +390,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
// 支持 If-Modified-Since
if !this.isLnRequest && !isPartialCache && len(modifiedTime) > 0 && this.requestHeader("If-Modified-Since") == modifiedTime {
// 自定义Header
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.addExpiresHeader(reader.ExpiresAt())
this.writer.WriteHeader(http.StatusNotModified)
this.isCached = true
@@ -396,7 +399,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
return true
}
this.processResponseHeaders(this.writer.Header(), reader.Status())
this.ProcessResponseHeaders(this.writer.Header(), reader.Status())
this.addExpiresHeader(reader.ExpiresAt())
// 返回上级节点过期时间
@@ -425,7 +428,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
if supportRange {
if len(rangeHeader) > 0 {
if fileSize == 0 {
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -433,7 +436,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
if len(ranges) == 0 {
ranges, ok = httpRequestParseRangeHeader(rangeHeader)
if !ok {
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -442,7 +445,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
for k, r := range ranges {
r2, ok := r.Convert(fileSize)
if !ok {
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -460,9 +463,9 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
var pool = this.bytePool(fileSize)
var bodyBuf = pool.Get()
err = reader.ReadBodyRange(bodyBuf, ranges[0].Start(), ranges[0].End(), func(n int) (goNext bool, err error) {
_, err = this.writer.Write(bodyBuf[:n])
if err != nil {
err = reader.ReadBodyRange(bodyBuf, ranges[0].Start(), ranges[0].End(), func(n int) (goNext bool, readErr error) {
_, readErr = this.writer.Write(bodyBuf[:n])
if readErr != nil {
return false, errWritingToClient
}
return true, nil
@@ -472,7 +475,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
this.varMapping["cache.status"] = "MISS"
if err == caches.ErrInvalidRange {
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -485,7 +488,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
var boundary = httpRequestGenBoundary()
respHeader.Set("Content-Type", "multipart/byteranges; boundary="+boundary)
respHeader.Del("Content-Length")
contentType := respHeader.Get("Content-Type")
var contentType = respHeader.Get("Content-Type")
this.writer.WriteHeader(http.StatusPartialContent)
@@ -516,9 +519,9 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
var pool = this.bytePool(fileSize)
var bodyBuf = pool.Get()
err := reader.ReadBodyRange(bodyBuf, r.Start(), r.End(), func(n int) (goNext bool, err error) {
_, err = this.writer.Write(bodyBuf[:n])
if err != nil {
err = reader.ReadBodyRange(bodyBuf, r.Start(), r.End(), func(n int) (goNext bool, readErr error) {
_, readErr = this.writer.Write(bodyBuf[:n])
if readErr != nil {
return false, errWritingToClient
}
return true, nil

View File

@@ -57,7 +57,7 @@ func (this *HTTPRequest) writeCode(statusCode int, enMessage string, zhMessage s
return "${" + varName + "}"
})
this.processResponseHeaders(this.writer.Header(), statusCode)
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
_, _ = this.writer.Write([]byte(pageContent))
@@ -110,7 +110,7 @@ func (this *HTTPRequest) write50x(err error, statusCode int, enMessage string, z
return "${" + varName + "}"
})
this.processResponseHeaders(this.writer.Header(), statusCode)
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
_, _ = this.writer.Write([]byte(pageContent))

View File

@@ -197,7 +197,7 @@ func (this *HTTPRequest) doFastcgi() (shouldStop bool) {
// 响应Header
this.writer.AddHeaders(resp.Header)
this.processResponseHeaders(this.writer.Header(), resp.StatusCode)
this.ProcessResponseHeaders(this.writer.Header(), resp.StatusCode)
// 准备
this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true)

View File

@@ -54,7 +54,7 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
return false
}
this.processResponseHeaders(this.writer.Header(), status)
this.ProcessResponseHeaders(this.writer.Header(), status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, status)
return true
}
@@ -96,7 +96,7 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
}
}
this.processResponseHeaders(this.writer.Header(), status)
this.ProcessResponseHeaders(this.writer.Header(), status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, status)
return true
} else { // 精准匹配
@@ -119,7 +119,7 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
}
}
this.processResponseHeaders(this.writer.Header(), status)
this.ProcessResponseHeaders(this.writer.Header(), status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, status)
return true
}
@@ -155,7 +155,7 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
return false
}
this.processResponseHeaders(this.writer.Header(), status)
this.ProcessResponseHeaders(this.writer.Header(), status)
// 参数
var qIndex = strings.Index(this.uri, "?")
@@ -211,7 +211,7 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
return false
}
this.processResponseHeaders(this.writer.Header(), status)
this.ProcessResponseHeaders(this.writer.Header(), status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, status)
return true
}

View File

@@ -0,0 +1,10 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package nodes
import "net/http"
func (this *HTTPRequest) processHTTP3Headers(respHeader http.Header) {
// stub
}

View File

@@ -12,6 +12,10 @@ const (
LNExpiresHeader = "X-Edge-Ln-Expires"
)
func existsLnNodeIP(nodeIP string) bool {
return false
}
func (this *HTTPRequest) checkLnRequest() bool {
return false
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"net"
"net/http"
"time"
)
@@ -32,7 +33,14 @@ func (this *HTTPRequest) doMismatch() {
}
// 根据配置进行相应的处理
if sharedNodeConfig.GlobalServerConfig != nil && sharedNodeConfig.GlobalServerConfig.HTTPAll.MatchDomainStrictly {
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
if globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly {
// 是否正在访问IP
if globalServerConfig.HTTPAll.NodeIPShowPage && net.ParseIP(this.ReqHost) != nil {
_, _ = this.writer.WriteString(globalServerConfig.HTTPAll.NodeIPPageHTML)
return
}
// 检查cc
// TODO 可以在管理端配置是否开启以及最多尝试次数
// 要考虑到服务在切换集群时,域名未生效状态时,用户访问的仍然是老集群中的节点,就会产生找不到域名的情况
@@ -47,7 +55,7 @@ func (this *HTTPRequest) doMismatch() {
}
// 处理当前连接
var httpAllConfig = sharedNodeConfig.GlobalServerConfig.HTTPAll
var httpAllConfig = globalServerConfig.HTTPAll
var mismatchAction = httpAllConfig.DomainMismatchAction
if mismatchAction != nil && mismatchAction.Code == "page" {
if mismatchAction.Options != nil {

View File

@@ -0,0 +1,15 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package nodes
import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"net/http"
)
func (this *HTTPRequest) doOSSOrigin(origin *serverconfigs.OriginConfig) (resp *http.Response, goNext bool, errorCode string, err error) {
// stub
return nil, false, "", errors.New("not implemented")
}

View File

@@ -9,11 +9,8 @@ import (
"github.com/iwind/TeaGo/logs"
"net/http"
"os"
"regexp"
)
var urlPrefixRegexp = regexp.MustCompile("^(?i)(http|https|ftp)://")
// 请求特殊页面
func (this *HTTPRequest) doPage(status int) (shouldStop bool) {
if len(this.web.Pages) == 0 {
@@ -49,7 +46,7 @@ func (this *HTTPRequest) doPageLookup(pages []*serverconfigs.HTTPPageConfig, sta
for _, page := range pages {
if page.Match(status) {
if len(page.BodyType) == 0 || page.BodyType == shared.BodyTypeURL {
if urlPrefixRegexp.MatchString(page.URL) {
if urlSchemeRegexp.MatchString(page.URL) {
var newStatus = page.NewStatus
if newStatus <= 0 {
newStatus = status
@@ -87,11 +84,11 @@ func (this *HTTPRequest) doPageLookup(pages []*serverconfigs.HTTPPageConfig, sta
// 修改状态码
if page.NewStatus > 0 {
// 自定义响应Headers
this.processResponseHeaders(this.writer.Header(), page.NewStatus)
this.ProcessResponseHeaders(this.writer.Header(), page.NewStatus)
this.writer.Prepare(nil, stat.Size(), page.NewStatus, true)
this.writer.WriteHeader(page.NewStatus)
} else {
this.processResponseHeaders(this.writer.Header(), status)
this.ProcessResponseHeaders(this.writer.Header(), status)
this.writer.Prepare(nil, stat.Size(), status, true)
this.writer.WriteHeader(status)
}
@@ -126,11 +123,11 @@ func (this *HTTPRequest) doPageLookup(pages []*serverconfigs.HTTPPageConfig, sta
// 修改状态码
if page.NewStatus > 0 {
// 自定义响应Headers
this.processResponseHeaders(this.writer.Header(), page.NewStatus)
this.ProcessResponseHeaders(this.writer.Header(), page.NewStatus)
this.writer.Prepare(nil, int64(len(content)), page.NewStatus, true)
this.writer.WriteHeader(page.NewStatus)
} else {
this.processResponseHeaders(this.writer.Header(), status)
this.ProcessResponseHeaders(this.writer.Header(), status)
this.writer.Prepare(nil, int64(len(content)), status, true)
this.writer.WriteHeader(status)
}

View File

@@ -12,7 +12,7 @@ func (this *HTTPRequest) doPlanExpires() {
this.tags = append(this.tags, "plan")
var statusCode = http.StatusNotFound
this.processResponseHeaders(this.writer.Header(), statusCode)
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
_, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultPlanExpireNoticePageBody))

View File

@@ -42,7 +42,7 @@ func (this *HTTPRequest) doRedirectToHTTPS(redirectToHTTPSConfig *serverconfigs.
}
var newURL = "https://" + host + this.RawReq.RequestURI
this.processResponseHeaders(this.writer.Header(), statusCode)
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
http.Redirect(this.writer, this.RawReq, newURL, statusCode)
return true

View File

@@ -12,13 +12,29 @@ func (this *HTTPRequest) doCheckReferers() (shouldStop bool) {
return
}
var origin = this.RawReq.Header.Get("Origin")
const cacheSeconds = "3600" // 时间不能过长,防止修改设置后长期无法生效
// 处理用到Origin的特殊功能
if this.web.Referers.CheckOrigin && len(origin) > 0 {
// 处理Websocket
if this.web.Websocket != nil && this.web.Websocket.IsOn && this.RawReq.Header.Get("Upgrade") == "websocket" {
originHost, _ := httpParseHost(origin)
if len(originHost) > 0 && this.web.Websocket.MatchOrigin(originHost) {
return
}
}
}
var refererURL = this.RawReq.Header.Get("Referer")
if len(refererURL) == 0 && this.web.Referers.CheckOrigin {
var origin = this.RawReq.Header.Get("Origin")
if len(origin) > 0 && origin != "null" {
refererURL = "https://" + origin // 因为Origin都只有域名部分所以为了下面的URL 分析需要加上https://
if urlSchemeRegexp.MatchString(origin) {
refererURL = origin
} else {
refererURL = "https://" + origin
}
}
}

View File

@@ -66,19 +66,21 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
// 二级节点
var hasMultipleLnNodes = false
if this.cacheRef != nil {
if this.cacheRef != nil || (this.nodeConfig != nil && this.nodeConfig.GlobalServerConfig != nil && this.nodeConfig.GlobalServerConfig.HTTPAll.ForceLnRequest) {
origin, lnNodeId, hasMultipleLnNodes = this.getLnOrigin(failedLnNodeIds)
if origin != nil {
// 强制变更原来访问的域名
requestHost = this.ReqHost
}
// 回源Header中去除If-None-Match和If-Modified-Since
if !this.cacheRef.EnableIfNoneMatch {
this.DeleteHeader("If-None-Match")
}
if !this.cacheRef.EnableIfModifiedSince {
this.DeleteHeader("If-Modified-Since")
if this.cacheRef != nil {
// 回源Header中去除If-None-Match和If-Modified-Since
if !this.cacheRef.EnableIfNoneMatch {
this.DeleteHeader("If-None-Match")
}
if !this.cacheRef.EnableIfModifiedSince {
this.DeleteHeader("If-Modified-Since")
}
}
}
@@ -115,14 +117,20 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
requestHostHasVariables = origin.RequestHostHasVariables()
}
// 处理OSS
var isHTTPOrigin = origin.OSS == nil
// 处理Scheme
if origin.Addr == nil {
if isHTTPOrigin && origin.Addr == nil {
err := errors.New(this.URL() + ": Origin '" + strconv.FormatInt(origin.Id, 10) + "' does not has a address")
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", err.Error())
this.write50x(err, http.StatusBadGateway, "Origin site did not has a valid address", "源站尚未配置地址", true)
return
}
this.RawReq.URL.Scheme = origin.Addr.Protocol.Primary().Scheme()
if isHTTPOrigin {
this.RawReq.URL.Scheme = origin.Addr.Protocol.Primary().Scheme()
}
// StripPrefix
if len(stripPrefix) > 0 {
@@ -159,63 +167,66 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
this.uri = utils.CleanPath(this.uri)
}
// 获取源站地址
var originAddr = origin.Addr.PickAddress()
if origin.Addr.HostHasVariables() {
originAddr = this.Format(originAddr)
}
// 端口跟随
if origin.FollowPort {
var originHostIndex = strings.Index(originAddr, ":")
if originHostIndex < 0 {
var originErr = errors.New(this.URL() + ": Invalid origin address '" + originAddr + "', lacking port")
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", originErr.Error())
this.write50x(originErr, http.StatusBadGateway, "No port in origin site address", "源站地址中没有配置端口", true)
return
var originAddr = ""
if isHTTPOrigin {
// 获取源站地址
originAddr = origin.Addr.PickAddress()
if origin.Addr.HostHasVariables() {
originAddr = this.Format(originAddr)
}
originAddr = originAddr[:originHostIndex+1] + types.String(this.requestServerPort())
}
this.originAddr = originAddr
// RequestHost
if len(requestHost) > 0 {
if requestHostHasVariables {
this.RawReq.Host = this.Format(requestHost)
// 端口跟随
if origin.FollowPort {
var originHostIndex = strings.Index(originAddr, ":")
if originHostIndex < 0 {
var originErr = errors.New(this.URL() + ": Invalid origin address '" + originAddr + "', lacking port")
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", originErr.Error())
this.write50x(originErr, http.StatusBadGateway, "No port in origin site address", "源站地址中没有配置端口", true)
return
}
originAddr = originAddr[:originHostIndex+1] + types.String(this.requestServerPort())
}
this.originAddr = originAddr
// RequestHost
if len(requestHost) > 0 {
if requestHostHasVariables {
this.RawReq.Host = this.Format(requestHost)
} else {
this.RawReq.Host = requestHost
}
// 是否移除端口
if this.reverseProxy.RequestHostExcludingPort {
this.RawReq.Host = utils.ParseAddrHost(this.RawReq.Host)
}
this.RawReq.URL.Host = this.RawReq.Host
} else if this.reverseProxy.RequestHostType == serverconfigs.RequestHostTypeOrigin {
// 源站主机名
var hostname = originAddr
if origin.Addr.Protocol.IsHTTPFamily() {
hostname = strings.TrimSuffix(hostname, ":80")
} else if origin.Addr.Protocol.IsHTTPSFamily() {
hostname = strings.TrimSuffix(hostname, ":443")
}
this.RawReq.Host = hostname
// 是否移除端口
if this.reverseProxy.RequestHostExcludingPort {
this.RawReq.Host = utils.ParseAddrHost(this.RawReq.Host)
}
this.RawReq.URL.Host = this.RawReq.Host
} else {
this.RawReq.Host = requestHost
}
this.RawReq.URL.Host = this.ReqHost
// 是否移除端口
if this.reverseProxy.RequestHostExcludingPort {
this.RawReq.Host = utils.ParseAddrHost(this.RawReq.Host)
}
this.RawReq.URL.Host = this.RawReq.Host
} else if this.reverseProxy.RequestHostType == serverconfigs.RequestHostTypeOrigin {
// 源站主机名
var hostname = originAddr
if origin.Addr.Protocol.IsHTTPFamily() {
hostname = strings.TrimSuffix(hostname, ":80")
} else if origin.Addr.Protocol.IsHTTPSFamily() {
hostname = strings.TrimSuffix(hostname, ":443")
}
this.RawReq.Host = hostname
// 是否移除端口
if this.reverseProxy.RequestHostExcludingPort {
this.RawReq.Host = utils.ParseAddrHost(this.RawReq.Host)
}
this.RawReq.URL.Host = this.RawReq.Host
} else {
this.RawReq.URL.Host = this.ReqHost
// 是否移除端口
if this.reverseProxy.RequestHostExcludingPort {
this.RawReq.Host = utils.ParseAddrHost(this.RawReq.Host)
this.RawReq.URL.Host = utils.ParseAddrHost(this.RawReq.URL.Host)
// 是否移除端口
if this.reverseProxy.RequestHostExcludingPort {
this.RawReq.Host = utils.ParseAddrHost(this.RawReq.Host)
this.RawReq.URL.Host = utils.ParseAddrHost(this.RawReq.URL.Host)
}
}
}
@@ -241,34 +252,72 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
}
// 判断是否为Websocket请求
if this.RawReq.Header.Get("Upgrade") == "websocket" {
if isHTTPOrigin && this.RawReq.Header.Get("Upgrade") == "websocket" {
shouldRetry = this.doWebsocket(requestHost, isLastRetry)
return
}
// 获取请求客户端
client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol, this.reverseProxy.FollowRedirects)
if err != nil {
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Create client failed: "+err.Error())
this.write50x(err, http.StatusBadGateway, "Failed to create origin site client", "构造源站客户端失败", true)
var resp *http.Response
var requestErr error
var requestErrCode string
if isHTTPOrigin { // 普通HTTP(S)源站
// 修复空User-Agent问题
_, existsUserAgent := this.RawReq.Header["User-Agent"]
if !existsUserAgent {
this.RawReq.Header["User-Agent"] = []string{""}
}
// 获取请求客户端
client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol, this.reverseProxy.FollowRedirects)
if err != nil {
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Create client failed: "+err.Error())
this.write50x(err, http.StatusBadGateway, "Failed to create origin site client", "构造源站客户端失败", true)
return
}
// 尝试自动纠正源站地址中的scheme
if this.RawReq.URL.Scheme == "http" && strings.HasSuffix(originAddr, ":443") {
this.RawReq.URL.Scheme = "https"
} else if this.RawReq.URL.Scheme == "https" && strings.HasSuffix(originAddr, ":80") {
this.RawReq.URL.Scheme = "http"
}
// 开始请求
resp, requestErr = client.Do(this.RawReq)
} else if origin.OSS != nil { // OSS源站
var goNext bool
resp, goNext, requestErrCode, requestErr = this.doOSSOrigin(origin)
if requestErr == nil {
if resp == nil || !goNext {
return
}
}
} else {
this.writeCode(http.StatusBadGateway, "The type of origin site has not been supported", "设置的源站类型尚未支持")
return
}
// 开始请求
resp, err := client.Do(this.RawReq)
if err != nil {
if requestErr != nil {
// 客户端取消请求,则不提示
httpErr, ok := err.(*url.Error)
httpErr, ok := requestErr.(*url.Error)
if !ok {
SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() {
this.reverseProxy.ResetScheduling()
})
this.write50x(err, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.RawReq.URL.String()+": Request origin server failed: "+err.Error())
if isHTTPOrigin {
SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() {
this.reverseProxy.ResetScheduling()
})
}
if len(requestErrCode) > 0 {
this.write50x(requestErr, http.StatusBadGateway, "Failed to read origin site (error code: "+requestErrCode+")", "源站读取失败(错误代号:"+requestErrCode+"", true)
} else {
this.write50x(requestErr, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
}
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.RawReq.URL.String()+": Request origin server failed: "+requestErr.Error())
} else if httpErr.Err != context.Canceled {
SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() {
this.reverseProxy.ResetScheduling()
})
if isHTTPOrigin {
SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() {
this.reverseProxy.ResetScheduling()
})
}
// 是否需要重试
if (originId > 0 || (lnNodeId > 0 && hasMultipleLnNodes)) && !isLastRetry {
@@ -280,21 +329,21 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
}
if httpErr.Err != io.EOF {
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Request origin server failed: "+err.Error())
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Request origin server failed: "+requestErr.Error())
}
return
}
if httpErr.Timeout() {
this.write50x(err, http.StatusGatewayTimeout, "Read origin site timeout", "源站读取超时", true)
this.write50x(requestErr, http.StatusGatewayTimeout, "Read origin site timeout", "源站读取超时", true)
} else if httpErr.Temporary() {
this.write50x(err, http.StatusServiceUnavailable, "Origin site unavailable now", "源站当前不可用", true)
this.write50x(requestErr, http.StatusServiceUnavailable, "Origin site unavailable now", "源站当前不可用", true)
} else {
this.write50x(err, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
this.write50x(requestErr, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
}
if httpErr.Err != io.EOF {
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Request origin server failed: "+err.Error())
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Request origin server failed: "+requestErr.Error())
}
} else {
// 是否为客户端方面的错误
@@ -314,7 +363,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
}
if !isClientError {
this.write50x(err, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
this.write50x(requestErr, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
}
}
if resp != nil && resp.Body != nil {
@@ -337,7 +386,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
this.originStatus = int32(resp.StatusCode)
// 恢复源站状态
if !origin.IsOk {
if !origin.IsOk && isHTTPOrigin {
SharedOriginStateManager.Success(origin, func() {
this.reverseProxy.ResetScheduling()
})
@@ -346,7 +395,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
// WAF对出站进行检查
if this.web.FirewallRef != nil && this.web.FirewallRef.IsOn {
if this.doWAFResponse(resp) {
err = resp.Body.Close()
err := resp.Body.Close()
if err != nil {
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Closing Error (WAF): "+err.Error())
}
@@ -356,7 +405,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
// 特殊页面
if len(this.web.Pages) > 0 && this.doPage(resp.StatusCode) {
err = resp.Body.Close()
err := resp.Body.Close()
if err != nil {
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Closing error (Page): "+err.Error())
}
@@ -409,7 +458,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
// 响应Header
this.writer.AddHeaders(resp.Header)
this.processResponseHeaders(this.writer.Header(), resp.StatusCode)
this.ProcessResponseHeaders(this.writer.Header(), resp.StatusCode)
// 是否需要刷新
var shouldAutoFlush = this.reverseProxy.AutoFlush || this.RawReq.Header.Get("Accept") == "text/event-stream"
@@ -437,6 +486,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
// 输出到客户端
var pool = this.bytePool(resp.ContentLength)
var buf = pool.Get()
var err error
if shouldAutoFlush {
for {
n, readErr := resp.Body.Read(buf)

View File

@@ -30,10 +30,10 @@ func (this *HTTPRequest) doRewrite() (shouldShop bool) {
// 跳转
if this.rewriteRule.Mode == serverconfigs.HTTPRewriteModeRedirect {
if this.rewriteRule.RedirectStatus > 0 {
this.processResponseHeaders(this.writer.Header(), this.rewriteRule.RedirectStatus)
this.ProcessResponseHeaders(this.writer.Header(), this.rewriteRule.RedirectStatus)
http.Redirect(this.writer, this.RawReq, this.rewriteReplace, this.rewriteRule.RedirectStatus)
} else {
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.writer, this.RawReq, this.rewriteReplace, http.StatusTemporaryRedirect)
}
return true

View File

@@ -217,7 +217,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
// 支持 If-None-Match
if this.requestHeader("If-None-Match") == eTag {
// 自定义Header
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.writer.WriteHeader(http.StatusNotModified)
return true
}
@@ -225,7 +225,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
// 支持 If-Modified-Since
if this.requestHeader("If-Modified-Since") == modifiedTime {
// 自定义Header
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.writer.WriteHeader(http.StatusNotModified)
return true
}
@@ -253,14 +253,14 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
var contentRange = this.RawReq.Header.Get("Range")
if len(contentRange) > 0 {
if fileSize == 0 {
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
set, ok := httpRequestParseRangeHeader(contentRange)
if !ok {
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -269,7 +269,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
for k, r := range ranges {
r2, ok := r.Convert(fileSize)
if !ok {
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -290,7 +290,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
}
// 自定义Header
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusOK)
// 在Range请求中不能缓存
if len(ranges) > 0 {
@@ -325,7 +325,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
return true
}
if !ok {
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -377,7 +377,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
return true
}
if !ok {
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}

View File

@@ -19,7 +19,7 @@ func (this *HTTPRequest) doShutdown() {
if len(shutdown.BodyType) == 0 || shutdown.BodyType == shared.BodyTypeURL {
// URL
if urlPrefixRegexp.MatchString(shutdown.URL) {
if urlSchemeRegexp.MatchString(shutdown.URL) {
this.doURL(http.MethodGet, shutdown.URL, "", shutdown.Status, true)
return
}
@@ -28,10 +28,10 @@ func (this *HTTPRequest) doShutdown() {
if len(shutdown.URL) == 0 {
// 自定义响应Headers
if shutdown.Status > 0 {
this.processResponseHeaders(this.writer.Header(), shutdown.Status)
this.ProcessResponseHeaders(this.writer.Header(), shutdown.Status)
this.writer.WriteHeader(shutdown.Status)
} else {
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusOK)
this.writer.WriteHeader(http.StatusOK)
}
_, err := this.writer.WriteString("The site have been shutdown.")
@@ -59,10 +59,10 @@ func (this *HTTPRequest) doShutdown() {
// 自定义响应Headers
if shutdown.Status > 0 {
this.processResponseHeaders(this.writer.Header(), shutdown.Status)
this.ProcessResponseHeaders(this.writer.Header(), shutdown.Status)
this.writer.WriteHeader(shutdown.Status)
} else {
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusOK)
this.writer.WriteHeader(http.StatusOK)
}
buf := utils.BytePool1k.Get()
@@ -85,10 +85,10 @@ func (this *HTTPRequest) doShutdown() {
} else if shutdown.BodyType == shared.BodyTypeHTML {
// 自定义响应Headers
if shutdown.Status > 0 {
this.processResponseHeaders(this.writer.Header(), shutdown.Status)
this.ProcessResponseHeaders(this.writer.Header(), shutdown.Status)
this.writer.WriteHeader(shutdown.Status)
} else {
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusOK)
this.writer.WriteHeader(http.StatusOK)
}

View File

@@ -13,7 +13,7 @@ func (this *HTTPRequest) doTrafficLimit() {
this.tags = append(this.tags, "bandwidth")
var statusCode = 509
this.processResponseHeaders(this.writer.Header(), statusCode)
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
if len(config.NoticePageBody) != 0 {

View File

@@ -44,9 +44,9 @@ func (this *HTTPRequest) doURL(method string, url string, host string, statusCod
// Header
if statusCode <= 0 {
this.processResponseHeaders(this.writer.Header(), resp.StatusCode)
this.ProcessResponseHeaders(this.writer.Header(), resp.StatusCode)
} else {
this.processResponseHeaders(this.writer.Header(), statusCode)
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
}
if supportVariables {

View File

@@ -9,6 +9,7 @@ import (
"github.com/iwind/TeaGo/types"
"io"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
@@ -22,6 +23,9 @@ var spiderRegexp = regexp.MustCompile(`(?i)(python|pycurl|http-client|httpclient
// 内容范围正则,其中的每个括号里的内容都在被引用,不能轻易修改
var contentRangeRegexp = regexp.MustCompile(`^bytes (\d+)-(\d+)/(\d+|\*)`)
// URL协议前缀
var urlSchemeRegexp = regexp.MustCompile("^(?i)(http|https|ftp)://")
// 分解Range
func httpRequestParseRangeHeader(rangeValue string) (result []rangeutils.Range, ok bool) {
// 参考RFChttps://tools.ietf.org/html/rfc7233
@@ -222,3 +226,16 @@ func httpRedirect(writer http.ResponseWriter, req *http.Request, url string, cod
http.Redirect(writer, req, url, code)
}
// 分析URL中的Host部分
func httpParseHost(urlString string) (host string, err error) {
if !urlSchemeRegexp.MatchString(urlString) {
urlString = "https://" + urlString
}
u, err := url.Parse(urlString)
if err != nil && u != nil {
return "", err
}
return u.Host, nil
}

View File

@@ -145,6 +145,23 @@ func TestHTTPRequest_httpRequestNextId_Concurrent(t *testing.T) {
a.IsTrue(countDuplicated == 0)
}
func TestHTTPParseURL(t *testing.T) {
for _, s := range []string{
"",
"null",
"example.com",
"https://example.com",
"https://example.com/hello",
} {
host, err := httpParseHost(s)
if err == nil {
t.Log(s, "=>", host)
} else {
t.Log(s, "=>")
}
}
}
func BenchmarkHTTPRequest_httpRequestNextId(b *testing.B) {
runtime.GOMAXPROCS(1)

View File

@@ -137,7 +137,7 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
return
}
this.processResponseHeaders(resp.Header, resp.StatusCode)
this.ProcessResponseHeaders(resp.Header, resp.StatusCode)
this.writer.statusCode = resp.StatusCode
// 将响应写回客户端

View File

@@ -362,7 +362,10 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
// 写入Header
var headerBuf = utils.SharedBufferPool.Get()
for k, v := range this.Header() {
if k == "Set-Cookie" || (this.isPartial && k == "Content-Range") {
if k == "Set-Cookie" ||
k == "Strict-Transport-Security" ||
k == "Alt-Svc" ||
(this.isPartial && k == "Content-Range") {
continue
}
for _, v1 := range v {
@@ -690,7 +693,10 @@ func (this *HTTPWriter) PrepareCompression(resp *http.Response, size int64) {
// 写入Header
var headerBuffer = utils.SharedBufferPool.Get()
for k, v := range this.Header() {
if k == "Set-Cookie" || (this.isPartial && k == "Content-Range") {
if k == "Set-Cookie" ||
k == "Strict-Transport-Security" ||
k == "Alt-Svc" ||
(this.isPartial && k == "Content-Range") {
continue
}
for _, v1 := range v {
@@ -837,6 +843,14 @@ func (this *HTTPWriter) WriteHeader(statusCode int) {
// Send 直接发送内容,并终止请求
func (this *HTTPWriter) Send(status int, body string) {
this.req.ProcessResponseHeaders(this.Header(), status)
// content-length
_, hasContentLength := this.Header()["Content-Length"]
if !hasContentLength {
this.Header()["Content-Length"] = []string{types.String(len(body))}
}
this.WriteHeader(status)
_, _ = this.WriteString(body)
this.isFinished = true
@@ -882,6 +896,7 @@ func (this *HTTPWriter) SendResp(resp *http.Response) (int64, error) {
for k, v := range resp.Header {
this.SetHeader(k, v)
}
this.WriteHeader(resp.StatusCode)
var bufPool = this.req.bytePool(resp.ContentLength)
var buf = bufPool.Get()
@@ -1018,7 +1033,9 @@ func (this *HTTPWriter) finishWebP() {
if webpCacheWriter != nil {
// 写入Header
for k, v := range this.Header() {
if k == "Set-Cookie" {
if k == "Set-Cookie" ||
k == "Strict-Transport-Security" ||
k == "Alt-Svc" {
continue
}
@@ -1237,7 +1254,10 @@ func (this *HTTPWriter) finishRequest() {
// 计算Header长度
func (this *HTTPWriter) calculateHeaderLength() (result int) {
for k, v := range this.Header() {
if k == "Set-Cookie" || (this.isPartial && k == "Content-Range") {
if k == "Set-Cookie" ||
k == "Strict-Transport-Security" ||
k == "Alt-Svc" ||
(this.isPartial && k == "Content-Range") {
continue
}
for _, v1 := range v {

View File

@@ -116,7 +116,7 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
return nil, nil, errors.New("no tls server name found")
}
// 通过代理服务域名配置匹配
// 通过网站域名配置匹配
server, _ := this.findNamedServer(domain)
if server == nil {
// 找不到或者此时的服务没有配置证书需要搜索所有的Server通过SSL证书内容中的DNSName匹配
@@ -138,7 +138,7 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
if server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
// 找不到或者此时的服务没有配置证书需要搜索所有的Server通过SSL证书内容中的DNSName匹配
// 此功能仅为了兼容以往版本v1.0.4),不应该作为常态启用
if globalServerConfig != nil && globalServerConfig.HTTPAll.MatchCertFromAllServers {
if globalServerConfig != nil && globalServerConfig.HTTPAll.MatchCertFromAllServers {
for _, searchingServer := range group.Servers() {
if searchingServer.SSLPolicy() == nil || !searchingServer.SSLPolicy().IsOn {
continue
@@ -174,19 +174,26 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf
return
}
var matchDomainStrictly = sharedNodeConfig.GlobalServerConfig != nil && sharedNodeConfig.GlobalServerConfig.HTTPAll.MatchDomainStrictly
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
var matchDomainStrictly = globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly
if sharedNodeConfig.GlobalServerConfig != nil &&
len(sharedNodeConfig.GlobalServerConfig.HTTPAll.DefaultDomain) > 0 &&
(!matchDomainStrictly || configutils.MatchDomains(sharedNodeConfig.GlobalServerConfig.HTTPAll.AllowMismatchDomains, name) || (sharedNodeConfig.GlobalServerConfig.HTTPAll.AllowNodeIP && net.ParseIP(name) != nil)) {
var defaultDomain = sharedNodeConfig.GlobalServerConfig.HTTPAll.DefaultDomain
serverConfig, serverName = this.findNamedServerMatched(defaultDomain)
if serverConfig != nil {
if globalServerConfig != nil &&
len(globalServerConfig.HTTPAll.DefaultDomain) > 0 &&
(!matchDomainStrictly || configutils.MatchDomains(globalServerConfig.HTTPAll.AllowMismatchDomains, name) || (globalServerConfig.HTTPAll.AllowNodeIP && net.ParseIP(name) != nil)) {
if globalServerConfig.HTTPAll.AllowNodeIP &&
globalServerConfig.HTTPAll.NodeIPShowPage &&
net.ParseIP(name) != nil {
return
} else {
var defaultDomain = globalServerConfig.HTTPAll.DefaultDomain
serverConfig, serverName = this.findNamedServerMatched(defaultDomain)
if serverConfig != nil {
return
}
}
}
if matchDomainStrictly && !configutils.MatchDomains(sharedNodeConfig.GlobalServerConfig.HTTPAll.AllowMismatchDomains, name) && (!sharedNodeConfig.GlobalServerConfig.HTTPAll.AllowNodeIP || net.ParseIP(name) == nil) {
if matchDomainStrictly && !configutils.MatchDomains(globalServerConfig.HTTPAll.AllowMismatchDomains, name) && (!globalServerConfig.HTTPAll.AllowNodeIP || net.ParseIP(name) == nil) {
return
}

View File

@@ -32,6 +32,7 @@ type HTTPListener struct {
addr string
isHTTP bool
isHTTPS bool
isHTTP3 bool
httpServer *http.Server
}
@@ -199,6 +200,7 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
ServerAddr: this.addr,
IsHTTP: this.isHTTP,
IsHTTPS: this.isHTTPS,
IsHTTP3: this.isHTTP3,
nodeConfig: sharedNodeConfig,
}

View File

@@ -36,9 +36,11 @@ func init() {
// ListenerManager 端口监听管理器
type ListenerManager struct {
listenersMap map[string]*Listener // addr => *Listener
locker sync.Mutex
lastConfig *nodeconfigs.NodeConfig
listenersMap map[string]*Listener // addr => *Listener
http3Listener *HTTPListener
locker sync.Mutex
lastConfig *nodeconfigs.NodeConfig
retryListenerMap map[string]*Listener // 需要重试的监听器 addr => Listener
ticker *time.Ticker
@@ -73,7 +75,7 @@ func NewListenerManager() *ListenerManager {
}
// Start 启动监听
func (this *ListenerManager) Start(node *nodeconfigs.NodeConfig) error {
func (this *ListenerManager) Start(nodeConfig *nodeconfigs.NodeConfig) error {
this.locker.Lock()
defer this.locker.Unlock()
@@ -84,12 +86,12 @@ func (this *ListenerManager) Start(node *nodeconfigs.NodeConfig) error {
/**if this.lastConfig != nil && this.lastConfig.Version == node.Version {
return nil
}**/
this.lastConfig = node
this.lastConfig = nodeConfig
// 所有的新地址
groupAddrs := []string{}
availableServerGroups := node.AvailableGroups()
if !node.IsOn {
var groupAddrs = []string{}
var availableServerGroups = nodeConfig.AvailableGroups()
if !nodeConfig.IsOn {
availableServerGroups = []*serverconfigs.ServerAddressGroup{}
}
@@ -98,13 +100,13 @@ func (this *ListenerManager) Start(node *nodeconfigs.NodeConfig) error {
}
for _, group := range availableServerGroups {
addr := group.FullAddr()
var addr = group.FullAddr()
groupAddrs = append(groupAddrs, addr)
}
// 停掉老的
for listenerKey, listener := range this.listenersMap {
addr := listener.FullAddr()
var addr = listener.FullAddr()
if !lists.ContainsString(groupAddrs, addr) {
remotelogs.Println("LISTENER_MANAGER", "close '"+addr+"'")
_ = listener.Close()
@@ -115,7 +117,7 @@ func (this *ListenerManager) Start(node *nodeconfigs.NodeConfig) error {
// 启动新的或修改老的
for _, group := range availableServerGroups {
addr := group.FullAddr()
var addr = group.FullAddr()
listener, ok := this.listenersMap[addr]
if ok {
// 不需要打印reload信息防止日志数量过多
@@ -129,7 +131,7 @@ func (this *ListenerManager) Start(node *nodeconfigs.NodeConfig) error {
// 放入到重试队列中
this.retryListenerMap[addr] = listener
firstServer := group.FirstServer()
var firstServer = group.FirstServer()
if firstServer == nil {
remotelogs.Error("LISTENER_MANAGER", err.Error())
} else {
@@ -167,10 +169,15 @@ func (this *ListenerManager) TotalActiveConnections() int {
this.locker.Lock()
defer this.locker.Unlock()
total := 0
var total = 0
for _, listener := range this.listenersMap {
total += listener.listener.CountActiveConnections()
}
if this.http3Listener != nil {
total += this.http3Listener.CountActiveConnections()
}
return total
}
@@ -239,6 +246,17 @@ func (this *ListenerManager) addToFirewalld(groupAddrs []string) {
return
}
// HTTP/3相关端口
var http3Ports = sharedNodeConfig.FindHTTP3Ports()
if len(http3Ports) > 0 {
for _, port := range http3Ports {
var groupAddr = "udp://:" + types.String(port)
if !lists.ContainsString(groupAddrs, groupAddr) {
groupAddrs = append(groupAddrs, groupAddr)
}
}
}
// 组合端口号
var portStrings = []string{}
var udpPorts = []int{}
@@ -272,7 +290,9 @@ func (this *ListenerManager) addToFirewalld(groupAddrs []string) {
if newPortStrings == this.lastPortStrings {
return
}
this.locker.Lock()
this.lastPortStrings = newPortStrings
this.locker.Unlock()
remotelogs.Println("FIREWALLD", "opening ports automatically ...")
defer func() {
@@ -284,8 +304,10 @@ func (this *ListenerManager) addToFirewalld(groupAddrs []string) {
var udpPortRanges = utils.MergePorts(udpPorts)
defer func() {
this.locker.Lock()
this.lastTCPPortRanges = tcpPortRanges
this.lastUDPPortRanges = udpPortRanges
this.locker.Unlock()
}()
// 删除老的不存在的端口
@@ -321,3 +343,28 @@ func (this *ListenerManager) addToFirewalld(groupAddrs []string) {
_ = this.firewalld.AllowPortRangesPermanently(tcpPortRanges, "tcp")
_ = this.firewalld.AllowPortRangesPermanently(udpPortRanges, "udp")
}
func (this *ListenerManager) reloadFirewalld() {
this.locker.Lock()
defer this.locker.Unlock()
var nodeConfig = sharedNodeConfig
// 所有的新地址
var groupAddrs = []string{}
var availableServerGroups = nodeConfig.AvailableGroups()
if !nodeConfig.IsOn {
availableServerGroups = []*serverconfigs.ServerAddressGroup{}
}
if len(availableServerGroups) == 0 {
remotelogs.Println("LISTENER_MANAGER", "no available servers to startup")
}
for _, group := range availableServerGroups {
var addr = group.FullAddr()
groupAddrs = append(groupAddrs, addr)
}
go this.addToFirewalld(groupAddrs)
}

View File

@@ -693,6 +693,7 @@ func (this *Node) listenSock() error {
var lastReadAt int64
var lastWriteAt int64
var lastErrString = ""
var protocol = "tcp"
clientConn, ok := conn.(*ClientConn)
if ok {
createdAt = clientConn.CreatedAt()
@@ -703,6 +704,8 @@ func (this *Node) listenSock() error {
if lastErr != nil {
lastErrString = lastErr.Error()
}
} else {
protocol = "udp"
}
var age int64 = -1
var lastReadAge int64 = -1
@@ -719,6 +722,7 @@ func (this *Node) listenSock() error {
}
connMaps = append(connMaps, maps.Map{
"protocol": protocol,
"addr": conn.RemoteAddr().String(),
"age": age,
"readAge": lastReadAge,

View File

@@ -33,7 +33,10 @@ type NodeStatusExecutor struct {
cpuLogicalCount int
cpuPhysicalCount int
lastIOCounterStat net.IOCountersStat
// 流量统计
lastIOCounterStat net.IOCountersStat
lastUDPInDatagrams int64
lastUDPOutDatagrams int64
apiCallStat *rpc.CallStat
@@ -44,6 +47,9 @@ func NewNodeStatusExecutor() *NodeStatusExecutor {
return &NodeStatusExecutor{
ticker: time.NewTicker(30 * time.Second),
apiCallStat: rpc.NewCallStat(10),
lastUDPInDatagrams: -1,
lastUDPOutDatagrams: -1,
}
}
@@ -292,14 +298,14 @@ func (this *NodeStatusExecutor) updateCacheSpace(status *nodeconfigs.NodeStatus)
// 流量
func (this *NodeStatusExecutor) updateAllTraffic(status *nodeconfigs.NodeStatus) {
counters, err := net.IOCounters(true)
trafficCounters, err := net.IOCounters(true)
if err != nil {
remotelogs.Warn("NODE_STATUS_EXECUTOR", err.Error())
return
}
var allCounter = net.IOCountersStat{}
for _, counter := range counters {
for _, counter := range trafficCounters {
// 跳过lo
if counter.Name == "lo" {
continue
@@ -319,11 +325,49 @@ func (this *NodeStatusExecutor) updateAllTraffic(status *nodeconfigs.NodeStatus)
var bytesSent = allCounter.BytesSent - this.lastIOCounterStat.BytesSent
var bytesRecv = allCounter.BytesRecv - this.lastIOCounterStat.BytesRecv
// UDP
var udpInDatagrams int64 = 0
var udpOutDatagrams int64 = 0
protoStats, protoErr := net.ProtoCounters([]string{"udp"})
if protoErr == nil {
for _, protoStat := range protoStats {
if protoStat.Protocol == "udp" {
udpInDatagrams = protoStat.Stats["InDatagrams"]
udpOutDatagrams = protoStat.Stats["OutDatagrams"]
if udpInDatagrams < 0 {
udpInDatagrams = 0
}
if udpOutDatagrams < 0 {
udpOutDatagrams = 0
}
}
}
}
var avgUDPInDatagrams int64 = 0
var avgUDPOutDatagrams int64 = 0
if this.lastUDPInDatagrams >= 0 && this.lastUDPOutDatagrams >= 0 {
avgUDPInDatagrams = (udpInDatagrams - this.lastUDPInDatagrams) / int64(costSeconds)
avgUDPOutDatagrams = (udpOutDatagrams - this.lastUDPOutDatagrams) / int64(costSeconds)
if avgUDPInDatagrams < 0 {
avgUDPInDatagrams = 0
}
if avgUDPOutDatagrams < 0 {
avgUDPOutDatagrams = 0
}
}
this.lastUDPInDatagrams = udpInDatagrams
this.lastUDPOutDatagrams = udpOutDatagrams
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemAllTraffic, maps.Map{
"inBytes": bytesRecv,
"outBytes": bytesSent,
"avgInBytes": bytesRecv / uint64(costSeconds),
"avgOutBytes": bytesSent / uint64(costSeconds),
"avgUDPInDatagrams": avgUDPInDatagrams,
"avgUDPOutDatagrams": avgUDPOutDatagrams,
})
}
}

View File

@@ -82,6 +82,8 @@ func (this *Node) execTask(rpcClient *rpc.RPCClient, task *pb.NodeTask) error {
err = this.execUAMPolicyChangedTask(rpcClient)
case "httpCCPolicyChanged":
err = this.execHTTPCCPolicyChangedTask(rpcClient)
case "http3PolicyChanged":
err = this.execHTTP3PolicyChangedTask(rpcClient)
case "httpPagesPolicyChanged":
err = this.execHTTPPagesPolicyChangedTask(rpcClient)
case "updatingServers":
@@ -128,15 +130,6 @@ func (this *Node) execNodeVersionChangedTask() error {
return nil
}
// 脚本库变更
func (this *Node) execScriptsChangedTask() error {
err := this.reloadCommonScripts()
if err != nil {
return errors.New("reload common scripts failed: " + err.Error())
}
return nil
}
// 节点级别变更
func (this *Node) execNodeLevelChangedTask(rpcClient *rpc.RPCClient) error {
levelInfoResp, err := rpcClient.NodeRPC.FindNodeLevelInfo(rpcClient.Context(), &pb.FindNodeLevelInfoRequest{})
@@ -163,90 +156,6 @@ func (this *Node) execNodeLevelChangedTask(rpcClient *rpc.RPCClient) error {
return nil
}
// UAM策略变更
func (this *Node) execUAMPolicyChangedTask(rpcClient *rpc.RPCClient) error {
remotelogs.Println("NODE", "updating uam policies ...")
resp, err := rpcClient.NodeRPC.FindNodeUAMPolicies(rpcClient.Context(), &pb.FindNodeUAMPoliciesRequest{})
if err != nil {
return err
}
var uamPolicyMap = map[int64]*nodeconfigs.UAMPolicy{}
for _, policy := range resp.UamPolicies {
if len(policy.UamPolicyJSON) > 0 {
var uamPolicy = &nodeconfigs.UAMPolicy{}
err = json.Unmarshal(policy.UamPolicyJSON, uamPolicy)
if err != nil {
remotelogs.Error("NODE", "decode uam policy failed: "+err.Error())
continue
}
err = uamPolicy.Init()
if err != nil {
remotelogs.Error("NODE", "initialize uam policy failed: "+err.Error())
continue
}
uamPolicyMap[policy.NodeClusterId] = uamPolicy
}
}
sharedNodeConfig.UpdateUAMPolicies(uamPolicyMap)
return nil
}
// HTTP CC策略变更
func (this *Node) execHTTPCCPolicyChangedTask(rpcClient *rpc.RPCClient) error {
remotelogs.Println("NODE", "updating http cc policies ...")
resp, err := rpcClient.NodeRPC.FindNodeHTTPCCPolicies(rpcClient.Context(), &pb.FindNodeHTTPCCPoliciesRequest{})
if err != nil {
return err
}
var httpCCPolicyMap = map[int64]*nodeconfigs.HTTPCCPolicy{}
for _, policy := range resp.HttpCCPolicies {
if len(policy.HttpCCPolicyJSON) > 0 {
var httpCCPolicy = nodeconfigs.NewHTTPCCPolicy()
err = json.Unmarshal(policy.HttpCCPolicyJSON, httpCCPolicy)
if err != nil {
remotelogs.Error("NODE", "decode http cc policy failed: "+err.Error())
continue
}
err = httpCCPolicy.Init()
if err != nil {
remotelogs.Error("NODE", "initialize http cc policy failed: "+err.Error())
continue
}
httpCCPolicyMap[policy.NodeClusterId] = httpCCPolicy
}
}
sharedNodeConfig.UpdateHTTPCCPolicies(httpCCPolicyMap)
return nil
}
// 自定义页面策略变更
func (this *Node) execHTTPPagesPolicyChangedTask(rpcClient *rpc.RPCClient) error {
remotelogs.Println("NODE", "updating http pages policies ...")
resp, err := rpcClient.NodeRPC.FindNodeHTTPPagesPolicies(rpcClient.Context(), &pb.FindNodeHTTPPagesPoliciesRequest{})
if err != nil {
return err
}
var httpPagesPolicyMap = map[int64]*nodeconfigs.HTTPPagesPolicy{}
for _, policy := range resp.HttpPagesPolicies {
if len(policy.HttpPagesPolicyJSON) > 0 {
var httpPagesPolicy = nodeconfigs.NewHTTPPagesPolicy()
err = json.Unmarshal(policy.HttpPagesPolicyJSON, httpPagesPolicy)
if err != nil {
remotelogs.Error("NODE", "decode http pages policy failed: "+err.Error())
continue
}
err = httpPagesPolicy.Init()
if err != nil {
remotelogs.Error("NODE", "initialize http pages policy failed: "+err.Error())
continue
}
httpPagesPolicyMap[policy.NodeClusterId] = httpPagesPolicy
}
}
sharedNodeConfig.UpdateHTTPPagesPolicies(httpPagesPolicyMap)
return nil
}
// DDoS配置变更
func (this *Node) execDDoSProtectionChangedTask(rpcClient *rpc.RPCClient) error {
resp, err := rpcClient.NodeRPC.FindNodeDDoSProtection(rpcClient.Context(), &pb.FindNodeDDoSProtectionRequest{})

View File

@@ -0,0 +1,31 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package nodes
import "github.com/TeaOSLab/EdgeNode/internal/rpc"
func (this *Node) execScriptsChangedTask() error {
// stub
return nil
}
func (this *Node) execUAMPolicyChangedTask(rpcClient *rpc.RPCClient) error {
// stub
return nil
}
func (this *Node) execHTTPCCPolicyChangedTask(rpcClient *rpc.RPCClient) error {
// stub
return nil
}
func (this *Node) execHTTP3PolicyChangedTask(rpcClient *rpc.RPCClient) error {
// stub
return nil
}
func (this *Node) execHTTPPagesPolicyChangedTask(rpcClient *rpc.RPCClient) error {
// stub
return nil
}

View File

@@ -289,6 +289,10 @@ Loop:
for {
select {
case log := <-logChan:
if log.NodeId <= 0 {
continue
}
// 是否已存在
var hash = xxhash.Sum64String(types.String(log.ServerId) + "_" + log.Description)
var found = false
@@ -312,6 +316,7 @@ Loop:
break Loop
}
}
if len(logList) == 0 {
return nil
}

View File

@@ -0,0 +1,15 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package testutils
import "os"
// IsSingleTesting 判断当前测试环境是否为单个函数测试
func IsSingleTesting() bool {
for _, arg := range os.Args {
if arg == "-test.run" {
return true
}
}
return false
}

View File

@@ -82,8 +82,10 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
// output response
if this.StatusCode > 0 {
request.ProcessResponseHeaders(writer.Header(), this.StatusCode)
writer.WriteHeader(this.StatusCode)
} else {
request.ProcessResponseHeaders(writer.Header(), http.StatusForbidden)
writer.WriteHeader(http.StatusForbidden)
}
if len(this.URL) > 0 {

View File

@@ -134,6 +134,7 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
// 占用一次失败次数
CaptchaIncreaseFails(req, this, waf.Id, group.Id, set.Id, CaptchaPageCodeInit)
req.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(writer, req.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect)
return false, false

View File

@@ -67,6 +67,7 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
return true, false
}
request.ProcessResponseHeaders(writer.Header(), http.StatusFound)
http.Redirect(writer, request.WAFRaw(), Get302Path+"?info="+url.QueryEscape(info), http.StatusFound)
flusher, ok := writer.(http.Flusher)

View File

@@ -75,14 +75,15 @@ func (this *JSCookieAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
}
}
req.ProcessResponseHeaders(writer.Header(), http.StatusOK)
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.Header().Set("Cache-Control", "no-cache")
var timestamp = types.String(time.Now().Unix())
var cookieValue = timestamp + "@" + types.String(set.Id) + "@" + fmt.Sprintf("%x", md5.Sum([]byte(timestamp+"@"+types.String(set.Id)+"@"+nodeConfig.NodeId)))
_, _ = writer.Write([]byte(`<!DOCTYPE html>
var respHTML = `<!DOCTYPE html>
<html>
<head>
<title></title>
@@ -94,7 +95,10 @@ window.location.reload();
</head>
<body>
</body>
</html>`))
</html>`
writer.Header().Set("Content-Length", types.String(len(respHTML)))
writer.WriteHeader(http.StatusOK)
_, _ = writer.Write([]byte(respHTML))
// 记录失败次数
this.increaseFails(req, waf.Id, group.Id, set.Id)

View File

@@ -36,6 +36,7 @@ func (this *PageAction) WillChange() bool {
// Perform the action
func (this *PageAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
request.ProcessResponseHeaders(writer.Header(), this.Status)
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(this.Status)
_, _ = writer.Write([]byte(request.Format(this.Body)))

View File

@@ -92,6 +92,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
Value: info,
})
request.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusTemporaryRedirect)
flusher, ok := writer.(http.Flusher)

View File

@@ -146,6 +146,7 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
var expiresAt = time.Now().Unix() + int64(timeout)
if this.Type == "black" {
request.ProcessResponseHeaders(writer.Header(), http.StatusForbidden)
writer.WriteHeader(http.StatusForbidden)
request.WAFClose()

View File

@@ -36,6 +36,7 @@ func (this *RedirectAction) WillChange() bool {
// Perform the action
func (this *RedirectAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
request.ProcessResponseHeaders(writer.Header(), this.Status)
writer.Header().Set("Location", this.URL)
writer.WriteHeader(this.Status)

View File

@@ -26,18 +26,22 @@ func NewCaptchaValidator() *CaptchaValidator {
func (this *CaptchaValidator) Run(req requests.Request, writer http.ResponseWriter) {
var info = req.WAFRaw().URL.Query().Get("info")
if len(info) == 0 {
req.ProcessResponseHeaders(writer.Header(), http.StatusBadRequest)
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid request"))
return
}
m, err := utils.SimpleDecryptMap(info)
if err != nil {
req.ProcessResponseHeaders(writer.Header(), http.StatusBadRequest)
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid request"))
return
}
var timestamp = m.GetInt64("timestamp")
if timestamp < time.Now().Unix()-600 { // 10分钟之后信息过期
req.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(writer, req.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect)
return
}
@@ -50,16 +54,19 @@ func (this *CaptchaValidator) Run(req requests.Request, writer http.ResponseWrit
var waf = SharedWAFManager.FindWAF(policyId)
if waf == nil {
req.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(writer, req.WAFRaw(), originURL, http.StatusTemporaryRedirect)
return
}
var actionConfig = waf.FindAction(actionId)
if actionConfig == nil {
req.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(writer, req.WAFRaw(), originURL, http.StatusTemporaryRedirect)
return
}
captchaActionConfig, ok := actionConfig.(*CaptchaAction)
if !ok {
req.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(writer, req.WAFRaw(), originURL, http.StatusTemporaryRedirect)
return
}
@@ -183,8 +190,7 @@ func (this *CaptchaValidator) show(actionConfig *CaptchaAction, req requests.Req
}
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
_, _ = writer.Write([]byte(`<!DOCTYPE html>
var msgHTML = `<!DOCTYPE html>
<html>
<head>
<title>` + msgTitle + `</title>
@@ -206,7 +212,13 @@ func (this *CaptchaValidator) show(actionConfig *CaptchaAction, req requests.Req
</head>
<body>` + body + `
</body>
</html>`))
</html>`
req.ProcessResponseHeaders(writer.Header(), http.StatusOK)
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.Header().Set("Content-Length", types.String(len(msgHTML)))
writer.WriteHeader(http.StatusOK)
_, _ = writer.Write([]byte(msgHTML))
}
func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, policyId int64, groupId int64, setId int64, originURL string, req requests.Request, writer http.ResponseWriter) (allow bool) {
@@ -226,6 +238,7 @@ func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, policyId int
// 加入到白名单
SharedIPWhiteList.RecordIP("set:"+strconv.FormatInt(setId, 10), actionConfig.Scope, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(life), policyId, false, groupId, setId, "")
req.ProcessResponseHeaders(writer.Header(), http.StatusSeeOther)
http.Redirect(writer, req.WAFRaw(), originURL, http.StatusSeeOther)
return false
@@ -235,6 +248,7 @@ func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, policyId int
return false
}
req.ProcessResponseHeaders(writer.Header(), http.StatusSeeOther)
http.Redirect(writer, req.WAFRaw(), req.WAFRaw().URL.String(), http.StatusSeeOther)
}
}

View File

@@ -22,18 +22,22 @@ func NewGet302Validator() *Get302Validator {
func (this *Get302Validator) Run(request requests.Request, writer http.ResponseWriter) {
var info = request.WAFRaw().URL.Query().Get("info")
if len(info) == 0 {
request.ProcessResponseHeaders(writer.Header(), http.StatusBadRequest)
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid request"))
return
}
m, err := utils.SimpleDecryptMap(info)
if err != nil {
request.ProcessResponseHeaders(writer.Header(), http.StatusBadRequest)
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid request"))
return
}
var timestamp = m.GetInt64("timestamp")
if time.Now().Unix()-timestamp > 5 { // 超过5秒认为失效
request.ProcessResponseHeaders(writer.Header(), http.StatusBadRequest)
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid request"))
return
@@ -49,5 +53,7 @@ func (this *Get302Validator) Run(request requests.Request, writer http.ResponseW
// 返回原始URL
var url = m.GetString("url")
request.ProcessResponseHeaders(writer.Header(), http.StatusFound)
http.Redirect(writer, request.WAFRaw(), url, http.StatusFound)
}

View File

@@ -38,6 +38,9 @@ type Request interface {
// Format 格式化变量
Format(string) string
// ProcessResponseHeaders 处理响应Header
ProcessResponseHeaders(headers http.Header, status int)
// DisableAccessLog 在当前请求中不使用访问日志
DisableAccessLog()
}