Compare commits

..

52 Commits

Author SHA1 Message Date
刘祥超
d40bc4e72b URL跳转可以设置是否保留参数 2022-02-20 09:17:50 +08:00
刘祥超
94d0fc7e88 当压缩格式不在Accept-Encoding中时自动解压 2022-02-18 11:05:09 +08:00
刘祥超
ceaeba7089 修复文件句柄缓存可能重复加入的Bug 2022-02-17 17:38:56 +08:00
刘祥超
a1e868bf29 读取缓存错误更详细 2022-02-17 17:24:35 +08:00
刘祥超
e60af85819 修复从缓存文件中读取压缩内容时可能失败的Bug 2022-02-17 16:56:13 +08:00
刘祥超
7bd24fcc81 检查是否压缩的时候,如果content-type为空,则默认为text/html 2022-02-15 18:31:37 +08:00
刘祥超
4331223916 优化代码 2022-02-15 16:44:39 +08:00
刘祥超
f50113517a 重构对HTTP请求的处理方法:缓存、压缩、WebP、限速 2022-02-15 14:55:49 +08:00
刘祥超
6d6e25f298 WAF规则提示错误时增加分组ID、规则集ID、规则描述 2022-01-29 21:43:42 +08:00
刘祥超
69c89fda48 支持单个服务更新配置 2022-01-19 22:16:46 +08:00
刘祥超
9a56671457 修改版本为v0.4.1 2022-01-17 10:53:23 +08:00
刘祥超
7dde0deb25 TCP源站也支持证书 2022-01-16 19:58:07 +08:00
刘祥超
b4647b1baa 优化验证码在窄屏上的展示 2022-01-16 16:57:25 +08:00
刘祥超
952e3ca572 CAPTCHA增加多个选项 2022-01-16 16:54:13 +08:00
刘祥超
238973a5e2 删除缓存数据库版本切换时的错误提示 2022-01-14 11:48:39 +08:00
刘祥超
9b6ab2fa8b 优化代码 2022-01-14 11:21:28 +08:00
刘祥超
9591004b70 优化代码 2022-01-13 15:49:42 +08:00
刘祥超
00cd86a8b3 优化open file cache,现在能缓存header 2022-01-13 15:18:49 +08:00
刘祥超
2a1cc63989 优化代码 2022-01-13 11:46:42 +08:00
刘祥超
8177768cf6 优化代码 2022-01-13 11:45:51 +08:00
刘祥超
14d156d42d 改进SYN Flood检测 2022-01-13 11:36:05 +08:00
刘祥超
63992bb2a0 优化代码 2022-01-12 21:41:05 +08:00
刘祥超
d02f9f9a0e 实现open file cache 2022-01-12 21:09:00 +08:00
刘祥超
91fab59a18 优化代码 2022-01-12 20:31:04 +08:00
刘祥超
76c82b431a 优化代码 2022-01-11 17:17:58 +08:00
刘祥超
2f6414fc55 优化代码 2022-01-11 16:02:41 +08:00
刘祥超
e23f4aaee2 优化代码 2022-01-11 09:25:34 +08:00
刘祥超
443660ac38 实现自动SYN Flood防护 2022-01-10 19:54:10 +08:00
刘祥超
488430bbef 优化代码 2022-01-10 15:38:53 +08:00
刘祥超
344de90bff 部分请求增加User-Agent 2022-01-10 10:02:15 +08:00
刘祥超
2f02827cb7 优化编译脚本 2022-01-09 20:12:59 +08:00
刘祥超
03e774cc44 自动使用本地防火墙/增加edge-node [ip.drop|ip.reject|ip.remove]等命令 2022-01-09 17:07:37 +08:00
刘祥超
ff2826ab47 优化一处错误提示 2022-01-09 15:32:02 +08:00
刘祥超
ecaa45db34 优化代码 2022-01-09 10:53:21 +08:00
刘祥超
b6cc826a54 优化正则表达式/修复一些测试用例 2022-01-08 12:20:18 +08:00
刘祥超
b8d7e3f5b4 提升WAF正则表达式性能(提升20%以上) 2022-01-08 11:45:14 +08:00
刘祥超
390be7f6c6 增加${browser.xxx}相关变量 2022-01-06 17:05:04 +08:00
刘祥超
ac4e240912 国家、省份数据不再每个小时更新一次;WAF增加国家/地区、省份、城市、ISP等参数 2022-01-06 16:27:39 +08:00
刘祥超
be7267211b 统计数据上传时如果遇到invalid utf-8,则自动过滤非法字符/统计数据上传失败时,仍然丢弃已有的统计数据,防止数据堆积 2022-01-05 16:05:58 +08:00
刘祥超
88fa75acb5 优化代码 2022-01-03 21:50:51 +08:00
刘祥超
d62fccf0a4 如果源站返回的内容长度为0,则不再尝试读取数据 2022-01-03 18:10:02 +08:00
刘祥超
258ffef0c2 尝试自动在firewalld中开放端口 2022-01-03 16:27:34 +08:00
刘祥超
a41f834192 可以打印服务相关日志信息 2022-01-03 15:53:59 +08:00
刘祥超
00500cb6a3 优化代码 2022-01-02 22:45:37 +08:00
刘祥超
32a3400138 优化代码 2022-01-01 22:02:46 +08:00
刘祥超
5ae4ef665e 优化UserAgent解析 2022-01-01 21:47:59 +08:00
刘祥超
336db828ad 优化代码 2022-01-01 20:15:39 +08:00
刘祥超
a1212804bb 增加edge-node gc命令 2022-01-01 17:18:34 +08:00
刘祥超
763ab4ac98 优化代码 2021-12-31 19:51:56 +08:00
刘祥超
4ec6ae4301 优化文字提示 2021-12-31 19:46:33 +08:00
刘祥超
4f292c5003 如果没有设置节点CPU线程数,则默认为4倍的CPU线程数 2021-12-31 19:45:54 +08:00
刘祥超
a00325f41a 修改版本为0.4.0 2021-12-31 15:19:20 +08:00
125 changed files with 4476 additions and 1370 deletions

View File

@@ -6,6 +6,9 @@ function build() {
VERSION=$(lookup-version $ROOT/../internal/const/const.go)
DIST=$ROOT/"../dist/${NAME}"
MUSL_DIR="/usr/local/opt/musl-cross/bin"
GCC_X86_64_DIR="/usr/local/Cellar/x86_64-unknown-linux-gnu/10.3.0/bin"
GCC_ARM64_DIR="/usr/local/Cellar/aarch64-unknown-linux-gnu/10.3.0/bin"
OS=${1}
ARCH=${2}
TAG=${3}
@@ -56,19 +59,39 @@ function build() {
CC_PATH=""
CXX_PATH=""
BUILD_TAG=$TAG
if [[ `uname -a` == *"Darwin"* && "${OS}" == "linux" ]]; then
# /usr/local/opt/musl-cross/bin/
if [ "${ARCH}" == "amd64" ]; then
CC_PATH="x86_64-linux-musl-gcc"
CXX_PATH="x86_64-linux-musl-g++"
# build with script support
if [ -d $GCC_X86_64_DIR ]; then
MUSL_DIR=$GCC_X86_64_DIR
CC_PATH="x86_64-unknown-linux-gnu-gcc"
CXX_PATH="x86_64-unknown-linux-gnu-g++"
if [ "$TAG" = "plus" ]; then
BUILD_TAG="plus,script"
fi
else
CC_PATH="x86_64-linux-musl-gcc"
CXX_PATH="x86_64-linux-musl-g++"
fi
fi
if [ "${ARCH}" == "386" ]; then
CC_PATH="i486-linux-musl-gcc"
CXX_PATH="i486-linux-musl-g++"
fi
if [ "${ARCH}" == "arm64" ]; then
CC_PATH="aarch64-linux-musl-gcc"
CXX_PATH="aarch64-linux-musl-g++"
# build with script support
if [ -d $GCC_ARM64_DIR ]; then
MUSL_DIR=$GCC_ARM64_DIR
CC_PATH="aarch64-unknown-linux-gnu-gcc"
CXX_PATH="aarch64-unknown-linux-gnu-g++"
if [ "$TAG" = "plus" ]; then
BUILD_TAG="plus,script"
fi
else
CC_PATH="aarch64-linux-musl-gcc"
CXX_PATH="aarch64-linux-musl-g++"
fi
fi
if [ "${ARCH}" == "arm" ]; then
CC_PATH="arm-linux-musleabi-gcc"
@@ -84,7 +107,7 @@ function build() {
fi
fi
if [ ! -z $CC_PATH ]; then
env CC=$MUSL_DIR/$CC_PATH CXX=$MUSL_DIR/$CXX_PATH GOOS=${OS} GOARCH=${ARCH} CGO_ENABLED=1 go build -tags $TAG -o $DIST/bin/${NAME} -ldflags "-linkmode external -extldflags -static -s -w" $ROOT/../cmd/edge-node/main.go
env CC=$MUSL_DIR/$CC_PATH CXX=$MUSL_DIR/$CXX_PATH GOOS=${OS} GOARCH=${ARCH} CGO_ENABLED=1 go build -tags $BUILD_TAG -o $DIST/bin/${NAME} -ldflags "-linkmode external -extldflags -static -s -w" $ROOT/../cmd/edge-node/main.go
else
env GOOS=${OS} GOARCH=${ARCH} CGO_ENABLED=1 go build -tags $TAG -o $DIST/bin/${NAME} -ldflags="-s -w" $ROOT/../cmd/edge-node/main.go
fi

View File

@@ -8,7 +8,10 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/nodes"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"github.com/iwind/gosock/pkg/gosock"
"net"
"net/http"
_ "net/http/pprof"
"os"
@@ -117,6 +120,122 @@ func main() {
}
}
})
app.On("gc", func() {
var sock = gosock.NewTmpSock(teaconst.ProcessName)
_, err := sock.Send(&gosock.Command{Code: "gc"})
if err != nil {
fmt.Println("[ERROR]" + err.Error())
} else {
fmt.Println("ok")
}
})
app.On("ip.drop", func() {
var args = os.Args[2:]
if len(args) == 0 {
fmt.Println("Usage: edge-node ip.drop IP [--timeout=SECONDS]")
return
}
var ip = args[0]
if len(net.ParseIP(ip)) == 0 {
fmt.Println("IP '" + ip + "' is invalid")
return
}
var timeoutSeconds = 0
var options = app.ParseOptions(args[1:])
timeout, ok := options["timeout"]
if ok {
timeoutSeconds = types.Int(timeout[0])
}
fmt.Println("drop ip '" + ip + "' for '" + types.String(timeoutSeconds) + "' seconds")
var sock = gosock.NewTmpSock(teaconst.ProcessName)
reply, err := sock.Send(&gosock.Command{
Code: "dropIP",
Params: map[string]interface{}{
"ip": ip,
"timeoutSeconds": timeoutSeconds,
},
})
if err != nil {
fmt.Println("[ERROR]" + err.Error())
} else {
var errString = maps.NewMap(reply.Params).GetString("error")
if len(errString) > 0 {
fmt.Println("[ERROR]" + errString)
} else {
fmt.Println("ok")
}
}
})
app.On("ip.reject", func() {
var args = os.Args[2:]
if len(args) == 0 {
fmt.Println("Usage: edge-node ip.reject IP [--timeout=SECONDS]")
return
}
var ip = args[0]
if len(net.ParseIP(ip)) == 0 {
fmt.Println("IP '" + ip + "' is invalid")
return
}
var timeoutSeconds = 0
var options = app.ParseOptions(args[1:])
timeout, ok := options["timeout"]
if ok {
timeoutSeconds = types.Int(timeout[0])
}
fmt.Println("reject ip '" + ip + "' for '" + types.String(timeoutSeconds) + "' seconds")
var sock = gosock.NewTmpSock(teaconst.ProcessName)
reply, err := sock.Send(&gosock.Command{
Code: "rejectIP",
Params: map[string]interface{}{
"ip": ip,
"timeoutSeconds": timeoutSeconds,
},
})
if err != nil {
fmt.Println("[ERROR]" + err.Error())
} else {
var errString = maps.NewMap(reply.Params).GetString("error")
if len(errString) > 0 {
fmt.Println("[ERROR]" + errString)
} else {
fmt.Println("ok")
}
}
})
app.On("ip.remove", func() {
var args = os.Args[2:]
if len(args) == 0 {
fmt.Println("Usage: edge-node ip.remove IP")
return
}
var ip = args[0]
if len(net.ParseIP(ip)) == 0 {
fmt.Println("IP '" + ip + "' is invalid")
return
}
var sock = gosock.NewTmpSock(teaconst.ProcessName)
reply, err := sock.Send(&gosock.Command{
Code: "removeIP",
Params: map[string]interface{}{
"ip": ip,
},
})
if err != nil {
fmt.Println("[ERROR]" + err.Error())
} else {
var errString = maps.NewMap(reply.Params).GetString("error")
if len(errString) > 0 {
fmt.Println("[ERROR]" + errString)
} else {
fmt.Println("ok")
}
}
})
app.Run(func() {
node := nodes.NewNode()
node.Start()

7
go.mod
View File

@@ -2,7 +2,9 @@ module github.com/TeaOSLab/EdgeNode
go 1.15
replace github.com/TeaOSLab/EdgeCommon => ../EdgeCommon
replace (
github.com/TeaOSLab/EdgeCommon => ../EdgeCommon
)
require (
github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000
@@ -11,6 +13,7 @@ require (
github.com/cespare/xxhash v1.1.0
github.com/chai2010/webp v1.1.0 // indirect
github.com/dchest/captcha v0.0.0-20200903113550-03f5f0333e1f
github.com/fsnotify/fsnotify v1.5.1 // indirect
github.com/go-yaml/yaml v2.1.0+incompatible
github.com/golang/protobuf v1.5.2
github.com/iwind/TeaGo v0.0.0-20211026123858-7de7a21cad24
@@ -28,7 +31,7 @@ require (
github.com/yusufpapurcu/wmi v1.2.2 // indirect
golang.org/x/image v0.0.0-20211028202545-6944b10bf410
golang.org/x/net v0.0.0-20211215060638-4ddde0e984e9
golang.org/x/sys v0.0.0-20211214234402-4825e8c3871d
golang.org/x/sys v0.0.0-20220111092808-5a964db01320
golang.org/x/text v0.3.7
google.golang.org/genproto v0.0.0-20211208223120-3a66f561d7aa // indirect
google.golang.org/grpc v1.43.0

7
go.sum
View File

@@ -43,6 +43,8 @@ github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/fsnotify/fsnotify v1.5.1 h1:mZcQUHVQUQWoPXXtuf9yuEXKudkV2sx1E06UadKWpgI=
github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5Ai1i3InKU=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
@@ -198,8 +200,11 @@ golang.org/x/sys v0.0.0-20210316164454-77fc1eacc6aa/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211214234402-4825e8c3871d h1:1oIt9o40TWWI9FUaveVpUvBe13FNqBNVXy3ue2fcfkw=
golang.org/x/sys v0.0.0-20211214234402-4825e8c3871d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220111092808-5a964db01320 h1:0jf+tOCoZ3LyutmCOWpVni1chK4VfFLhRsDK7MhqGRY=
golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
@@ -271,5 +276,3 @@ gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 h1:tQIYjPdBoyREyB9XMu+nnTclp
gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
rogchap.com/v8go v0.7.0 h1:kgjbiO4zE5itA962ze6Hqmbs4HgZbGzmueCXsZtremg=
rogchap.com/v8go v0.7.0/go.mod h1:MxgP3pL2MW4dpme/72QRs8sgNMmM0pRc8DPhcuLWPAs=

View File

@@ -11,6 +11,7 @@ import (
"os/exec"
"runtime"
"strconv"
"strings"
"syscall"
"time"
)
@@ -245,3 +246,19 @@ func (this *AppCmd) getPID() int {
}
return maps.NewMap(reply.Params).GetInt("pid")
}
// ParseOptions 分析参数中的选项
func (this *AppCmd) ParseOptions(args []string) map[string][]string {
var result = map[string][]string{}
for _, arg := range args {
var pieces = strings.SplitN(arg, "=", 2)
var key = strings.TrimLeft(pieces[0], "- ")
key = strings.TrimSpace(key)
var value = ""
if len(pieces) == 2 {
value = strings.TrimSpace(pieces[1])
}
result[key] = append(result[key], value)
}
return result
}

View File

@@ -206,6 +206,7 @@ func (this *FileList) Add(hash string, item *Item) error {
return err
}
this.memoryCache.Write(hash, 1, item.ExpiredAt)
atomic.AddInt64(&this.total, 1)
if this.onAdd != nil {
@@ -258,9 +259,10 @@ func (this *FileList) CleanPrefix(prefix string) error {
}()
var count = int64(10000)
var staleLife = 600 // TODO 需要可以设置
var staleLife = 600 // TODO 需要可以设置
var unixTime = utils.UnixTime() // 只删除当前的,不删除新的
for {
result, err := this.db.Exec(`UPDATE "`+this.itemsTableName+`" SET expiredAt=0,staleAt=? WHERE id IN (SELECT id FROM "`+this.itemsTableName+`" WHERE expiredAt>0 AND createdAt<=? AND INSTR("key", ?)=1 LIMIT `+types.String(count)+`)`, utils.UnixTime()+int64(staleLife), utils.UnixTime(), prefix)
result, err := this.db.Exec(`UPDATE "`+this.itemsTableName+`" SET expiredAt=0,staleAt=? WHERE id IN (SELECT id FROM "`+this.itemsTableName+`" WHERE expiredAt>0 AND createdAt<=? AND INSTR("key", ?)=1 LIMIT `+types.String(count)+`)`, unixTime+int64(staleLife), unixTime, prefix)
if err != nil {
return err
}
@@ -558,9 +560,7 @@ ON "` + this.itemsTableName + `" (
// v2 => v3
remotelogs.Println("CACHE", "transferring old data from v2 to v3 ...")
result, err := db.Exec(`INSERT INTO "` + this.itemsTableName + `" ("id", "hash", "key", "headerSize", "bodySize", "metaSize", "expiredAt", "createdAt", "host", "serverId", "staleAt") SELECT "id", "hash", "key", "headerSize", "bodySize", "metaSize", "expiredAt", "createdAt", "host", "serverId", "expiredAt"+600 FROM cacheItems_v2`)
if err != nil {
remotelogs.Println("CACHE", "transfer old data from v2 to v3 failed: "+err.Error())
} else {
if err == nil {
count, _ := result.RowsAffected()
remotelogs.Println("CACHE", "transfer old data from v2 to v3 finished, "+types.String(count)+" rows transferred")
}

View File

@@ -85,7 +85,7 @@ func (this *Manager) UpdatePolicies(newPolicies []*serverconfigs.HTTPCachePolicy
for _, policy := range this.policyMap {
storage, ok := this.storageMap[policy.Id]
if !ok {
storage := this.NewStorageWithPolicy(policy)
storage = this.NewStorageWithPolicy(policy)
if storage == nil {
remotelogs.Error("CACHE", "can not find storage type '"+policy.Type+"'")
continue
@@ -106,7 +106,7 @@ func (this *Manager) UpdatePolicies(newPolicies []*serverconfigs.HTTPCachePolicy
delete(this.storageMap, policy.Id)
// 启动新的
storage := this.NewStorageWithPolicy(policy)
storage = this.NewStorageWithPolicy(policy)
if storage == nil {
remotelogs.Error("CACHE", "can not find storage type '"+policy.Type+"'")
continue

View File

@@ -0,0 +1,33 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package caches
import (
"io"
"os"
)
type OpenFile struct {
fp *os.File
meta []byte
header []byte
version int64
}
func NewOpenFile(fp *os.File, meta []byte, header []byte) *OpenFile {
return &OpenFile{
fp: fp,
meta: meta,
header: header,
}
}
func (this *OpenFile) SeekStart() error {
_, err := this.fp.Seek(0, io.SeekStart)
return err
}
func (this *OpenFile) Close() error {
this.meta = nil
return this.fp.Close()
}

View File

@@ -0,0 +1,161 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package caches
import (
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/utils/linkedlist"
"github.com/fsnotify/fsnotify"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/types"
"path/filepath"
"sync"
"time"
)
type OpenFileCache struct {
poolMap map[string]*OpenFilePool // file path => Pool
poolList *linkedlist.List
watcher *fsnotify.Watcher
locker sync.Mutex
maxSize int
count int
}
func NewOpenFileCache(maxSize int) (*OpenFileCache, error) {
if maxSize <= 0 {
maxSize = 16384
}
var cache = &OpenFileCache{
maxSize: maxSize,
poolMap: map[string]*OpenFilePool{},
poolList: linkedlist.NewList(),
}
watcher, err := fsnotify.NewWatcher()
if err != nil {
return nil, err
}
cache.watcher = watcher
goman.New(func() {
for event := range watcher.Events {
if event.Op&fsnotify.Chmod != fsnotify.Chmod {
cache.Close(event.Name)
}
}
})
return cache, nil
}
func (this *OpenFileCache) Get(filename string) *OpenFile {
this.locker.Lock()
defer this.locker.Unlock()
pool, ok := this.poolMap[filename]
if ok {
file, consumed := pool.Get()
if consumed {
this.count--
}
return file
}
return nil
}
func (this *OpenFileCache) Put(filename string, file *OpenFile) {
this.locker.Lock()
defer this.locker.Unlock()
pool, ok := this.poolMap[filename]
var success bool
if ok {
success = pool.Put(file)
} else {
_ = this.watcher.Add(filename)
pool = NewOpenFilePool(filename)
this.poolMap[filename] = pool
success = pool.Put(file)
}
this.poolList.Push(pool.linkItem)
// 检查长度
if success {
this.count++
// 如果超过当前容量,则关闭最早的
if this.count > this.maxSize {
var delta = this.maxSize / 100 // 清理1%
if delta == 0 {
delta = 1
}
for i := 0; i < delta; i++ {
var head = this.poolList.Head()
if head == nil {
break
}
var headPool = head.Value.(*OpenFilePool)
headFile, consumed := headPool.Get()
if consumed {
this.count--
if headFile != nil {
_ = headFile.Close()
}
}
if headPool.Len() == 0 {
delete(this.poolMap, headPool.filename)
this.poolList.Remove(head)
_ = this.watcher.Remove(headPool.filename)
}
}
}
}
}
func (this *OpenFileCache) Close(filename string) {
this.locker.Lock()
pool, ok := this.poolMap[filename]
if ok {
delete(this.poolMap, filename)
this.poolList.Remove(pool.linkItem)
_ = this.watcher.Remove(filename)
this.count -= pool.Len()
}
this.locker.Unlock()
// 在locker之外提升性能
if ok {
pool.Close()
}
}
func (this *OpenFileCache) CloseAll() {
this.locker.Lock()
for _, pool := range this.poolMap {
pool.Close()
}
this.poolMap = map[string]*OpenFilePool{}
this.poolList.Reset()
_ = this.watcher.Close()
this.locker.Unlock()
}
func (this *OpenFileCache) Debug() {
var ticker = time.NewTicker(5 * time.Second)
goman.New(func() {
for range ticker.C {
logs.Println("==== " + types.String(this.count) + " ====")
this.poolList.Range(func(item *linkedlist.Item) (goNext bool) {
logs.Println(filepath.Base(item.Value.(*OpenFilePool).Filename()), item.Value.(*OpenFilePool).Len())
return true
})
}
})
}

View File

@@ -0,0 +1,76 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package caches
import (
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/linkedlist"
)
type OpenFilePool struct {
c chan *OpenFile
linkItem *linkedlist.Item
filename string
version int64
}
func NewOpenFilePool(filename string) *OpenFilePool {
var pool = &OpenFilePool{
filename: filename,
c: make(chan *OpenFile, 1024),
version: utils.UnixTimeMilli(),
}
pool.linkItem = linkedlist.NewItem(pool)
return pool
}
func (this *OpenFilePool) Filename() string {
return this.filename
}
func (this *OpenFilePool) Get() (*OpenFile, bool) {
select {
case file := <-this.c:
err := file.SeekStart()
if err != nil {
_ = file.Close()
return nil, true
}
file.version = this.version
return file, true
default:
return nil, false
}
}
func (this *OpenFilePool) Put(file *OpenFile) bool {
if file.version > 0 && file.version != this.version {
_ = file.Close()
return false
}
select {
case this.c <- file:
return true
default:
// 多余的直接关闭
_ = file.Close()
return false
}
}
func (this *OpenFilePool) Len() int {
return len(this.c)
}
func (this *OpenFilePool) Close() {
Loop:
for {
select {
case file := <-this.c:
_ = file.Close()
default:
break Loop
}
}
}

View File

@@ -0,0 +1,17 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package caches_test
import (
"github.com/TeaOSLab/EdgeNode/internal/caches"
"testing"
)
func TestOpenFilePool_Get(t *testing.T) {
var pool = caches.NewOpenFilePool("a")
t.Log(pool.Filename())
t.Log(pool.Get())
t.Log(pool.Put(caches.NewOpenFile(nil, nil, []byte{})))
t.Log(pool.Get())
t.Log(pool.Get())
}

View File

@@ -11,6 +11,12 @@ import (
type FileReader struct {
fp *os.File
openFile *OpenFile
openFileCache *OpenFileCache
meta []byte
header []byte
expiresAt int64
status int
headerOffset int64
@@ -18,8 +24,7 @@ type FileReader struct {
bodySize int64
bodyOffset int64
bodyBufLen int
bodyBuf []byte
isClosed bool
}
func NewFileReader(fp *os.File) *FileReader {
@@ -27,6 +32,11 @@ func NewFileReader(fp *os.File) *FileReader {
}
func (this *FileReader) Init() error {
if this.openFile != nil {
this.meta = this.openFile.meta
this.header = this.openFile.header
}
isOk := false
defer func() {
@@ -35,13 +45,17 @@ func (this *FileReader) Init() error {
}
}()
var buf = make([]byte, SizeMeta)
ok, err := this.readToBuff(this.fp, buf)
if err != nil {
return err
}
if !ok {
return ErrNotFound
var buf = this.meta
if len(buf) == 0 {
buf = make([]byte, SizeMeta)
ok, err := this.readToBuff(this.fp, buf)
if err != nil {
return err
}
if !ok {
return ErrNotFound
}
this.meta = buf
}
this.expiresAt = int64(binary.BigEndian.Uint32(buf[:SizeExpiresAt]))
@@ -72,6 +86,21 @@ func (this *FileReader) Init() error {
this.bodySize = int64(bodySize)
this.bodyOffset = this.headerOffset + int64(headerSize)
// read header
if this.openFileCache != nil && len(this.header) == 0 {
if headerSize > 0 && headerSize <= 512 {
this.header = make([]byte, headerSize)
_, err := this.fp.Seek(this.headerOffset, io.SeekStart)
if err != nil {
return err
}
_, err = this.readToBuff(this.fp, this.header)
if err != nil {
return err
}
}
}
isOk = true
return nil
@@ -106,6 +135,22 @@ func (this *FileReader) BodySize() int64 {
}
func (this *FileReader) ReadHeader(buf []byte, callback ReaderFunc) error {
// 使用缓存
if len(this.header) > 0 && len(buf) >= len(this.header) {
copy(buf, this.header)
_, err := callback(len(this.header))
if err != nil {
return err
}
// 移动到Body位置
_, err = this.fp.Seek(this.bodyOffset, io.SeekStart)
if err != nil {
return err
}
return nil
}
isOk := false
defer func() {
@@ -135,10 +180,6 @@ func (this *FileReader) ReadHeader(buf []byte, callback ReaderFunc) error {
}
headerSize -= n
} else {
if n > headerSize {
this.bodyBuf = buf[headerSize:]
this.bodyBufLen = n - headerSize
}
_, e := callback(headerSize)
if e != nil {
isOk = true
@@ -157,6 +198,12 @@ func (this *FileReader) ReadHeader(buf []byte, callback ReaderFunc) error {
isOk = true
// 移动到Body位置
_, err = this.fp.Seek(this.bodyOffset, io.SeekStart)
if err != nil {
return err
}
return nil
}
@@ -169,27 +216,7 @@ func (this *FileReader) ReadBody(buf []byte, callback ReaderFunc) error {
}
}()
offset := this.bodyOffset
// 直接返回从Header中剩余的
if this.bodyBufLen > 0 && len(buf) >= this.bodyBufLen {
offset += int64(this.bodyBufLen)
copy(buf, this.bodyBuf)
isOk = true
goNext, err := callback(this.bodyBufLen)
if err != nil {
return err
}
if !goNext {
return nil
}
if this.bodySize <= int64(this.bodyBufLen) {
return nil
}
}
var offset = this.bodyOffset
// 开始读Body部分
_, err := this.fp.Seek(offset, io.SeekStart)
@@ -223,32 +250,9 @@ func (this *FileReader) ReadBody(buf []byte, callback ReaderFunc) error {
}
func (this *FileReader) Read(buf []byte) (n int, err error) {
var isOk = false
defer func() {
if !isOk {
_ = this.discard()
}
}()
// 直接返回从Header中剩余的
if this.bodyBufLen > 0 && len(buf) >= this.bodyBufLen {
copy(buf, this.bodyBuf)
isOk = true
n = this.bodyBufLen
if this.bodySize <= int64(this.bodyBufLen) {
err = io.EOF
return
}
this.bodyBufLen = 0
return
}
n, err = this.fp.Read(buf)
if err == nil || err == io.EOF {
isOk = true
if err != nil && err != io.EOF {
_ = this.discard()
}
return
}
@@ -323,6 +327,19 @@ func (this *FileReader) ReadBodyRange(buf []byte, start int64, end int64, callba
}
func (this *FileReader) Close() error {
if this.openFileCache != nil {
if this.isClosed {
return nil
}
this.isClosed = true
if this.openFile != nil {
this.openFileCache.Put(this.fp.Name(), this.openFile)
} else {
this.openFileCache.Put(this.fp.Name(), NewOpenFile(this.fp, this.meta, this.header))
}
return nil
}
return this.fp.Close()
}
@@ -337,5 +354,6 @@ func (this *FileReader) readToBuff(fp *os.File, buf []byte) (ok bool, err error)
func (this *FileReader) discard() error {
_ = this.fp.Close()
this.isClosed = true
return os.Remove(this.fp.Name())
}

View File

@@ -14,12 +14,12 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/zero"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
"golang.org/x/text/language"
"golang.org/x/text/message"
"io"
"math"
"os"
"path/filepath"
@@ -67,6 +67,8 @@ type FileStorage struct {
hotMapLocker sync.Mutex
lastHotSize int
hotTicker *utils.Ticker
openFileCache *OpenFileCache
}
func NewFileStorage(policy *serverconfigs.HTTPCachePolicy) *FileStorage {
@@ -100,13 +102,13 @@ func (this *FileStorage) Init() error {
return err
}
this.cacheConfig = cacheConfig
cacheDir := cacheConfig.Dir
if !filepath.IsAbs(this.cacheConfig.Dir) {
this.cacheConfig.Dir = Tea.Root + Tea.DS + this.cacheConfig.Dir
}
dir := this.cacheConfig.Dir
this.cacheConfig.Dir = filepath.Clean(this.cacheConfig.Dir)
var dir = this.cacheConfig.Dir
if len(dir) == 0 {
return errors.New("[CACHE]cache storage dir can not be empty")
@@ -159,7 +161,7 @@ func (this *FileStorage) Init() error {
} else if size > 1024 {
sizeMB = fmt.Sprintf("%.3f K", float64(size)/1024)
}
remotelogs.Println("CACHE", "init policy "+strconv.FormatInt(this.policy.Id, 10)+" from '"+cacheDir+"', cost: "+fmt.Sprintf("%.2f", cost)+" ms, count: "+message.NewPrinter(language.English).Sprintf("%d", count)+", size: "+sizeMB)
remotelogs.Println("CACHE", "init policy "+strconv.FormatInt(this.policy.Id, 10)+" from '"+this.cacheConfig.Dir+"', cost: "+fmt.Sprintf("%.2f", cost)+" ms, count: "+message.NewPrinter(language.English).Sprintf("%d", count)+", size: "+sizeMB)
}()
// 初始化list
@@ -202,6 +204,15 @@ func (this *FileStorage) Init() error {
}
}
// open file cache
if this.cacheConfig.OpenFileCache != nil && this.cacheConfig.OpenFileCache.IsOn && this.cacheConfig.OpenFileCache.Max > 0 {
this.openFileCache, err = NewOpenFileCache(this.cacheConfig.OpenFileCache.Max)
logs.Println("start open file cache")
if err != nil {
remotelogs.Error("CACHE", "open file cache failed: "+err.Error())
}
}
return nil
}
@@ -238,12 +249,22 @@ func (this *FileStorage) openReader(key string, allowMemory bool, useStale bool)
// TODO 尝试使用mmap加快读取速度
var isOk = false
fp, err := os.OpenFile(path, os.O_RDONLY, 0444)
if err != nil {
if !os.IsNotExist(err) {
return nil, err
var openFile *OpenFile
if this.openFileCache != nil {
openFile = this.openFileCache.Get(path)
}
var fp *os.File
var err error
if openFile == nil {
fp, err = os.OpenFile(path, os.O_RDONLY, 0444)
if err != nil {
if !os.IsNotExist(err) {
return nil, err
}
return nil, ErrNotFound
}
return nil, ErrNotFound
} else {
fp = openFile.fp
}
defer func() {
if !isOk {
@@ -252,10 +273,9 @@ func (this *FileStorage) openReader(key string, allowMemory bool, useStale bool)
}
}()
reader := NewFileReader(fp)
if err != nil {
return nil, err
}
var reader = NewFileReader(fp)
reader.openFile = openFile
reader.openFileCache = this.openFileCache
err = reader.Init()
if err != nil {
return nil, err
@@ -305,10 +325,11 @@ func (this *FileStorage) openReader(key string, allowMemory bool, useStale bool)
}
// OpenWriter 打开缓存文件等待写入
func (this *FileStorage) OpenWriter(key string, expiredAt int64, status int) (Writer, error) {
func (this *FileStorage) OpenWriter(key string, expiredAt int64, status int, size int64) (Writer, error) {
// 先尝试内存缓存
if this.memoryStorage != nil {
writer, err := this.memoryStorage.OpenWriter(key, expiredAt, status)
// 我们限定仅小文件优先存在内存中
if this.memoryStorage != nil && size > 0 && size < 32*1024*1024 {
writer, err := this.memoryStorage.OpenWriter(key, expiredAt, status, size)
if err == nil {
return writer, nil
}
@@ -623,6 +644,8 @@ func (this *FileStorage) Purge(keys []string, urlType string) error {
// Stop 停止
func (this *FileStorage) Stop() {
events.Remove(this)
this.locker.Lock()
defer this.locker.Unlock()
@@ -640,6 +663,10 @@ func (this *FileStorage) Stop() {
}
_ = this.list.Close()
if this.openFileCache != nil {
this.openFileCache.CloseAll()
}
}
// TotalDiskSize 消耗的磁盘尺寸
@@ -702,11 +729,20 @@ func (this *FileStorage) initList() error {
}
}
this.purgeTicker = utils.NewTicker(time.Duration(autoPurgeInterval) * time.Second)
events.On(events.EventQuit, func() {
events.OnKey(events.EventQuit, this, func() {
remotelogs.Println("CACHE", "quit clean timer")
var ticker = this.purgeTicker
if ticker != nil {
ticker.Stop()
{
var ticker = this.purgeTicker
if ticker != nil {
ticker.Stop()
}
}
{
var ticker = this.hotTicker
if ticker != nil {
ticker.Stop()
}
}
})
goman.New(func() {
@@ -733,98 +769,6 @@ func (this *FileStorage) initList() error {
return nil
}
// 解析文件信息
func (this *FileStorage) decodeFile(path string) (*Item, error) {
fp, err := os.OpenFile(path, os.O_RDONLY, 0444)
if err != nil {
return nil, err
}
isAllOk := false
defer func() {
_ = fp.Close()
if !isAllOk {
_ = os.Remove(path)
}
}()
item := &Item{
Type: ItemTypeFile,
MetaSize: SizeMeta,
}
bytes4 := make([]byte, 4)
// 过期时间
ok, err := this.readToBuff(fp, bytes4)
if err != nil {
return nil, err
}
if !ok {
return nil, ErrNotFound
}
item.ExpiredAt = int64(binary.BigEndian.Uint32(bytes4))
// 是否已过期
if item.ExpiredAt < time.Now().Unix() {
return nil, ErrNotFound
}
// URL Size
_, err = fp.Seek(int64(SizeExpiresAt+SizeStatus), io.SeekStart)
if err != nil {
return nil, err
}
ok, err = this.readToBuff(fp, bytes4)
if err != nil {
return nil, err
}
if !ok {
return nil, ErrNotFound
}
urlSize := binary.BigEndian.Uint32(bytes4)
// Header Size
ok, err = this.readToBuff(fp, bytes4)
if err != nil {
return nil, err
}
if !ok {
return nil, ErrNotFound
}
item.HeaderSize = int64(binary.BigEndian.Uint32(bytes4))
// Body Size
bytes8 := make([]byte, 8)
ok, err = this.readToBuff(fp, bytes8)
if err != nil {
return nil, err
}
if !ok {
return nil, ErrNotFound
}
item.BodySize = int64(binary.BigEndian.Uint64(bytes8))
// URL
if urlSize > 0 {
data := utils.BytePool1k.Get()
result, ok, err := this.readN(fp, data, int(urlSize))
utils.BytePool1k.Put(data)
if err != nil {
return nil, err
}
if !ok {
return nil, ErrNotFound
}
item.Key = string(result)
}
isAllOk = true
return item, nil
}
// 清理任务
func (this *FileStorage) purgeLoop() {
// 计算是否应该开启LFU清理
@@ -961,7 +905,7 @@ func (this *FileStorage) hotLoop() {
continue
}
writer, err := this.memoryStorage.openWriter(item.Key, item.ExpiresAt, item.Status, false)
writer, err := this.memoryStorage.openWriter(item.Key, item.ExpiresAt, item.Status, reader.BodySize(), false)
if err != nil {
if !CanIgnoreErr(err) {
remotelogs.Error("CACHE", "transfer hot item failed: "+err.Error())
@@ -1008,34 +952,6 @@ func (this *FileStorage) hotLoop() {
}
}
func (this *FileStorage) readToBuff(fp *os.File, buf []byte) (ok bool, err error) {
n, err := fp.Read(buf)
if err != nil {
return false, err
}
ok = n == len(buf)
return
}
func (this *FileStorage) readN(fp *os.File, buf []byte, total int) (result []byte, ok bool, err error) {
for {
n, err := fp.Read(buf)
if err != nil {
return nil, false, err
}
if n > 0 {
if n >= total {
result = append(result, buf[:total]...)
ok = true
return result, ok, nil
} else {
total -= n
result = append(result, buf[:n]...)
}
}
}
}
func (this *FileStorage) diskCapacityBytes() int64 {
c1 := this.policy.CapacityBytes()
if SharedManager.MaxDiskCapacity != nil {

View File

@@ -62,7 +62,7 @@ func TestFileStorage_OpenWriter(t *testing.T) {
header := []byte("Header")
body := []byte("This is Body")
writer, err := storage.OpenWriter("my-key", time.Now().Unix()+86400, 200)
writer, err := storage.OpenWriter("my-key", time.Now().Unix()+86400, 200, -1)
if err != nil {
t.Fatal(err)
}
@@ -104,7 +104,7 @@ func TestFileStorage_OpenWriter_HTTP(t *testing.T) {
t.Log(time.Since(now).Seconds()*1000, "ms")
}()
writer, err := storage.OpenWriter("my-http-response", time.Now().Unix()+86400, 200)
writer, err := storage.OpenWriter("my-http-response", time.Now().Unix()+86400, 200, -1)
if err != nil {
t.Fatal(err)
}
@@ -177,7 +177,7 @@ func TestFileStorage_Concurrent_Open_DifferentFile(t *testing.T) {
go func(i int) {
defer wg.Done()
writer, err := storage.OpenWriter("abc"+strconv.Itoa(i), time.Now().Unix()+3600, 200)
writer, err := storage.OpenWriter("abc"+strconv.Itoa(i), time.Now().Unix()+3600, 200, -1)
if err != nil {
if err != ErrFileIsWriting {
t.Fatal(err)
@@ -229,7 +229,7 @@ func TestFileStorage_Concurrent_Open_SameFile(t *testing.T) {
go func(i int) {
defer wg.Done()
writer, err := storage.OpenWriter("abc"+strconv.Itoa(0), time.Now().Unix()+3600, 200)
writer, err := storage.OpenWriter("abc"+strconv.Itoa(0), time.Now().Unix()+3600, 200, -1)
if err != nil {
if err != ErrFileIsWriting {
t.Fatal(err)
@@ -482,11 +482,7 @@ func TestFileStorage_DecodeFile(t *testing.T) {
t.Fatal(err)
}
_, path := storage.keyPath("my-key")
item, err := storage.decodeFile(path)
if err != nil {
t.Fatal(err)
}
logs.PrintAsJSON(item, t)
t.Log(path)
}
func BenchmarkFileStorage_Read(b *testing.B) {

View File

@@ -13,7 +13,7 @@ type StorageInterface interface {
OpenReader(key string, useStale bool) (reader Reader, err error)
// OpenWriter 打开缓存写入器等待写入
OpenWriter(key string, expiredAt int64, status int) (Writer, error)
OpenWriter(key string, expiredAt int64, status int, size int64) (Writer, error)
// Delete 删除某个键值对应的缓存
Delete(key string) error

View File

@@ -145,11 +145,11 @@ func (this *MemoryStorage) OpenReader(key string, useStale bool) (Reader, error)
}
// OpenWriter 打开缓存写入器等待写入
func (this *MemoryStorage) OpenWriter(key string, expiredAt int64, status int) (Writer, error) {
return this.openWriter(key, expiredAt, status, true)
func (this *MemoryStorage) OpenWriter(key string, expiredAt int64, status int, size int64) (Writer, error) {
return this.openWriter(key, expiredAt, status, size, true)
}
func (this *MemoryStorage) openWriter(key string, expiredAt int64, status int, isDirty bool) (Writer, error) {
func (this *MemoryStorage) openWriter(key string, expiredAt int64, status int, size int64, isDirty bool) (Writer, error) {
this.locker.Lock()
defer this.locker.Unlock()
@@ -182,7 +182,10 @@ func (this *MemoryStorage) openWriter(key string, expiredAt int64, status int, i
return nil, NewCapacityError("write memory cache failed: too many keys in cache storage")
}
capacityBytes := this.memoryCapacityBytes()
if capacityBytes > 0 && capacityBytes <= this.totalSize {
if size < 0 {
size = 0
}
if capacityBytes > 0 && capacityBytes <= this.totalSize+size {
return nil, NewCapacityError("write memory cache failed: over memory size: " + strconv.FormatInt(capacityBytes, 10) + ", current size: " + strconv.FormatInt(this.totalSize, 10) + " bytes")
}
@@ -384,7 +387,7 @@ func (this *MemoryStorage) flushItem(key string) {
return
}
writer, err := this.parentStorage.OpenWriter(key, item.ExpiredAt, item.Status)
writer, err := this.parentStorage.OpenWriter(key, item.ExpiredAt, item.Status, -1)
if err != nil {
if !CanIgnoreErr(err) {
remotelogs.Error("CACHE", "flush items failed: open writer failed: "+err.Error())

View File

@@ -15,7 +15,7 @@ import (
func TestMemoryStorage_OpenWriter(t *testing.T) {
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
writer, err := storage.OpenWriter("abc", time.Now().Unix()+60, 200)
writer, err := storage.OpenWriter("abc", time.Now().Unix()+60, 200, -1)
if err != nil {
t.Fatal(err)
}
@@ -62,7 +62,7 @@ func TestMemoryStorage_OpenWriter(t *testing.T) {
}
}
writer, err = storage.OpenWriter("abc", time.Now().Unix()+60, 200)
writer, err = storage.OpenWriter("abc", time.Now().Unix()+60, 200, -1)
if err != nil {
t.Fatal(err)
}
@@ -103,7 +103,7 @@ func TestMemoryStorage_OpenReaderLock(t *testing.T) {
func TestMemoryStorage_Delete(t *testing.T) {
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
{
writer, err := storage.OpenWriter("abc", time.Now().Unix()+60, 200)
writer, err := storage.OpenWriter("abc", time.Now().Unix()+60, 200, -1)
if err != nil {
t.Fatal(err)
}
@@ -111,7 +111,7 @@ func TestMemoryStorage_Delete(t *testing.T) {
t.Log(len(storage.valuesMap))
}
{
writer, err := storage.OpenWriter("abc1", time.Now().Unix()+60, 200)
writer, err := storage.OpenWriter("abc1", time.Now().Unix()+60, 200, -1)
if err != nil {
t.Fatal(err)
}
@@ -126,7 +126,7 @@ func TestMemoryStorage_Stat(t *testing.T) {
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
expiredAt := time.Now().Unix() + 60
{
writer, err := storage.OpenWriter("abc", expiredAt, 200)
writer, err := storage.OpenWriter("abc", expiredAt, 200, -1)
if err != nil {
t.Fatal(err)
}
@@ -139,7 +139,7 @@ func TestMemoryStorage_Stat(t *testing.T) {
})
}
{
writer, err := storage.OpenWriter("abc1", expiredAt, 200)
writer, err := storage.OpenWriter("abc1", expiredAt, 200, -1)
if err != nil {
t.Fatal(err)
}
@@ -163,7 +163,7 @@ func TestMemoryStorage_CleanAll(t *testing.T) {
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
expiredAt := time.Now().Unix() + 60
{
writer, err := storage.OpenWriter("abc", expiredAt, 200)
writer, err := storage.OpenWriter("abc", expiredAt, 200, -1)
if err != nil {
t.Fatal(err)
}
@@ -175,7 +175,7 @@ func TestMemoryStorage_CleanAll(t *testing.T) {
})
}
{
writer, err := storage.OpenWriter("abc1", expiredAt, 200)
writer, err := storage.OpenWriter("abc1", expiredAt, 200, -1)
if err != nil {
t.Fatal(err)
}
@@ -198,7 +198,7 @@ func TestMemoryStorage_Purge(t *testing.T) {
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
expiredAt := time.Now().Unix() + 60
{
writer, err := storage.OpenWriter("abc", expiredAt, 200)
writer, err := storage.OpenWriter("abc", expiredAt, 200, -1)
if err != nil {
t.Fatal(err)
}
@@ -210,7 +210,7 @@ func TestMemoryStorage_Purge(t *testing.T) {
})
}
{
writer, err := storage.OpenWriter("abc1", expiredAt, 200)
writer, err := storage.OpenWriter("abc1", expiredAt, 200, -1)
if err != nil {
t.Fatal(err)
}
@@ -241,7 +241,7 @@ func TestMemoryStorage_Expire(t *testing.T) {
for i := 0; i < 1000; i++ {
expiredAt := time.Now().Unix() + int64(rands.Int(0, 60))
key := "abc" + strconv.Itoa(i)
writer, err := storage.OpenWriter(key, expiredAt, 200)
writer, err := storage.OpenWriter(key, expiredAt, 200, -1)
if err != nil {
t.Fatal(err)
}

View File

@@ -1,76 +0,0 @@
package caches
import (
"github.com/TeaOSLab/EdgeNode/internal/compressions"
)
type compressionWriter struct {
rawWriter Writer
writer compressions.Writer
key string
expiredAt int64
}
func NewCompressionWriter(gw Writer, cpWriter compressions.Writer, key string, expiredAt int64) Writer {
return &compressionWriter{
rawWriter: gw,
writer: cpWriter,
key: key,
expiredAt: expiredAt,
}
}
func (this *compressionWriter) WriteHeader(data []byte) (n int, err error) {
return this.writer.Write(data)
}
// WriteHeaderLength 写入Header长度数据
func (this *compressionWriter) WriteHeaderLength(headerLength int) error {
return nil
}
// WriteBodyLength 写入Body长度数据
func (this *compressionWriter) WriteBodyLength(bodyLength int64) error {
return nil
}
func (this *compressionWriter) Write(data []byte) (n int, err error) {
return this.writer.Write(data)
}
func (this *compressionWriter) Close() error {
err := this.writer.Close()
if err != nil {
return err
}
return this.rawWriter.Close()
}
func (this *compressionWriter) Discard() error {
err := this.writer.Close()
if err != nil {
return err
}
return this.rawWriter.Discard()
}
func (this *compressionWriter) Key() string {
return this.key
}
func (this *compressionWriter) ExpiredAt() int64 {
return this.expiredAt
}
func (this *compressionWriter) HeaderSize() int64 {
return this.rawWriter.HeaderSize()
}
func (this *compressionWriter) BodySize() int64 {
return this.rawWriter.BodySize()
}
// ItemType 内容类型
func (this *compressionWriter) ItemType() ItemType {
return this.rawWriter.ItemType()
}

View File

@@ -0,0 +1,68 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package compressions
import (
"bytes"
"errors"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"io"
"os"
"testing"
)
func TestGzipReader(t *testing.T) {
fp, err := os.Open("/Users/WorkSpace/EdgeProject/EdgeCache/p43/36/7e/367e02720713fe05b66573a1d69b4f0a.cache")
if err != nil {
// not fatal
t.Log(err)
return
}
defer func() {
_ = fp.Close()
}()
var buf = make([]byte, 32*1024)
cacheReader := caches.NewFileReader(fp)
err = cacheReader.Init()
if err != nil {
t.Fatal(err)
}
var headerBuf = []byte{}
err = cacheReader.ReadHeader(buf, func(n int) (goNext bool, err error) {
headerBuf = append(headerBuf, buf[:n]...)
for {
nIndex := bytes.Index(headerBuf, []byte{'\n'})
if nIndex >= 0 {
row := headerBuf[:nIndex]
spaceIndex := bytes.Index(row, []byte{':'})
if spaceIndex <= 0 {
return false, errors.New("invalid header '" + string(row) + "'")
}
headerBuf = headerBuf[nIndex+1:]
} else {
break
}
}
return true, nil
})
reader, err := NewGzipReader(cacheReader)
if err != nil {
t.Fatal(err)
}
for {
n, err := reader.Read(buf)
if err != nil {
if err != io.EOF {
t.Fatal(err)
} else {
break
}
}
t.Log(string(buf[:n]))
_ = n
}
}

View File

@@ -1,7 +1,7 @@
package teaconst
const (
Version = "0.3.8"
Version = "0.4.1"
ProductName = "Edge Node"
ProcessName = "edge-node"

View File

@@ -2,6 +2,8 @@
package teaconst
import "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
var (
// 流量统计
@@ -10,4 +12,6 @@ var (
NodeId int64 = 0
NodeIdString = ""
GlobalProductName = nodeconfigs.DefaultProductName
)

View File

@@ -1,27 +1,71 @@
package events
import "sync"
import (
"sync"
)
var eventsMap = map[string][]func(){} // event => []callbacks
type Callbacks = []func()
var eventsMap = map[Event]map[interface{}]Callbacks{} // event => map[event key][]callback
var locker = sync.Mutex{}
// 增加事件回调
func On(event string, callback func()) {
var eventKeyId = 0
func NewKey() interface{} {
locker.Lock()
defer locker.Unlock()
eventKeyId++
return eventKeyId
}
// On 增加事件回调
func On(event Event, callback func()) {
OnKey(event, nil, callback)
}
// OnKey 使用Key增加事件回调
func OnKey(event Event, key interface{}, callback func()) {
if key == nil {
key = NewKey()
}
locker.Lock()
defer locker.Unlock()
callbacks, _ := eventsMap[event]
callbacks = append(callbacks, callback)
eventsMap[event] = callbacks
m, ok := eventsMap[event]
if !ok {
m = map[interface{}]Callbacks{}
eventsMap[event] = m
}
m[key] = append(m[key], callback)
}
// 通知事件
func Notify(event string) {
// Remove 删除事件回调
func Remove(key interface{}) {
if key == nil {
return
}
locker.Lock()
callbacks, _ := eventsMap[event]
for k, m := range eventsMap {
_, ok := m[key]
if ok {
delete(m, key)
eventsMap[k] = m
}
}
locker.Unlock()
}
// Notify 通知事件
func Notify(event Event) {
locker.Lock()
m := eventsMap[event]
locker.Unlock()
for _, callback := range callbacks {
callback()
for _, callbacks := range m {
for _, callback := range callbacks {
callback()
}
}
}

View File

@@ -1,16 +1,33 @@
package events
package events_test
import "testing"
import (
"github.com/TeaOSLab/EdgeNode/internal/events"
"testing"
)
func TestOn(t *testing.T) {
On("hello", func() {
type User struct {
name string
}
var u = &User{}
var u2 = &User{}
events.On("hello", func() {
t.Log("world")
})
On("hello", func() {
events.On("hello", func() {
t.Log("world2")
})
On("hello2", func() {
events.OnKey("hello", u, func() {
t.Log("world3")
})
events.OnKey("hello", u, func() {
t.Log("world4")
})
events.Remove(u)
events.Remove(u2)
events.OnKey("hello2", nil, func() {
t.Log("world2")
})
Notify("hello")
events.Notify("hello")
}

View File

@@ -0,0 +1,42 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package firewalls
import (
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
)
var currentFirewall FirewallInterface
// 初始化
func init() {
events.On(events.EventLoaded, func() {
var firewall = Firewall()
if firewall.Name() == "mock" {
remotelogs.Warn("FIREWALL", "'firewalld' on this system should be enabled to block attackers more effectively")
} else {
remotelogs.Println("FIREWALL", "found local firewall '"+firewall.Name()+"'")
}
})
}
// Firewall 查找当前系统中最适合的防火墙
func Firewall() FirewallInterface {
if currentFirewall != nil {
return currentFirewall
}
// firewalld
{
var firewalld = NewFirewalld()
if firewalld.IsReady() {
currentFirewall = firewalld
return currentFirewall
}
}
// 至少返回一个
currentFirewall = NewMockFirewall()
return currentFirewall
}

View File

@@ -0,0 +1,135 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package firewalls
import (
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/types"
"os/exec"
"strings"
)
type Firewalld struct {
isReady bool
exe string
cmdQueue chan *exec.Cmd
}
func NewFirewalld() *Firewalld {
var firewalld = &Firewalld{
cmdQueue: make(chan *exec.Cmd, 4096),
}
path, err := exec.LookPath("firewall-cmd")
if err == nil && len(path) > 0 {
var cmd = exec.Command(path, "-V")
err := cmd.Run()
if err == nil {
firewalld.exe = path
firewalld.isReady = true
firewalld.init()
}
}
return firewalld
}
func (this *Firewalld) init() {
goman.New(func() {
for cmd := range this.cmdQueue {
err := cmd.Run()
if err != nil {
if strings.HasPrefix(err.Error(), "Warning:") {
continue
}
remotelogs.Warn("FIREWALL", "run command failed '"+cmd.String()+"': "+err.Error())
}
}
})
}
// Name 名称
func (this *Firewalld) Name() string {
return "firewalld"
}
func (this *Firewalld) IsReady() bool {
return this.isReady
}
func (this *Firewalld) AllowPort(port int, protocol string) error {
if !this.isReady {
return nil
}
var cmd = exec.Command(this.exe, "--add-port="+types.String(port)+"/"+protocol)
this.pushCmd(cmd)
return nil
}
func (this *Firewalld) RemovePort(port int, protocol string) error {
if !this.isReady {
return nil
}
var cmd = exec.Command(this.exe, "--remove-port="+types.String(port)+"/"+protocol)
this.pushCmd(cmd)
return nil
}
func (this *Firewalld) RejectSourceIP(ip string, timeoutSeconds int) error {
if !this.isReady {
return nil
}
var family = "ipv4"
if strings.Contains(ip, ":") {
family = "ipv6"
}
var args = []string{"--add-rich-rule=rule family='" + family + "' source address='" + ip + "' reject"}
if timeoutSeconds > 0 {
args = append(args, "--timeout="+types.String(timeoutSeconds)+"s")
}
var cmd = exec.Command(this.exe, args...)
this.pushCmd(cmd)
return nil
}
func (this *Firewalld) DropSourceIP(ip string, timeoutSeconds int) error {
if !this.isReady {
return nil
}
var family = "ipv4"
if strings.Contains(ip, ":") {
family = "ipv6"
}
var args = []string{"--add-rich-rule=rule family='" + family + "' source address='" + ip + "' drop"}
if timeoutSeconds > 0 {
args = append(args, "--timeout="+types.String(timeoutSeconds)+"s")
}
var cmd = exec.Command(this.exe, args...)
this.pushCmd(cmd)
return nil
}
func (this *Firewalld) RemoveSourceIP(ip string) error {
if !this.isReady {
return nil
}
var family = "ipv4"
if strings.Contains(ip, ":") {
family = "ipv6"
}
for _, action := range []string{"reject", "drop"} {
var args = []string{"--remove-rich-rule=rule family='" + family + "' source address='" + ip + "' " + action}
var cmd = exec.Command(this.exe, args...)
this.pushCmd(cmd)
}
return nil
}
func (this *Firewalld) pushCmd(cmd *exec.Cmd) {
select {
case this.cmdQueue <- cmd:
default:
// we discard the command
}
}

View File

@@ -0,0 +1,27 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package firewalls
// FirewallInterface 防火墙接口
type FirewallInterface interface {
// Name 名称
Name() string
// IsReady 是否已准备被调用
IsReady() bool
// AllowPort 允许端口
AllowPort(port int, protocol string) error
// RemovePort 删除端口
RemovePort(port int, protocol string) error
// RejectSourceIP 拒绝某个源IP连接
RejectSourceIP(ip string, timeoutSeconds int) error
// DropSourceIP 丢弃某个源IP数据
DropSourceIP(ip string, timeoutSeconds int) error
// RemoveSourceIP 删除某个源IP
RemoveSourceIP(ip string) error
}

View File

@@ -0,0 +1,55 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package firewalls
// MockFirewall 模拟防火墙
type MockFirewall struct {
}
func NewMockFirewall() *MockFirewall {
return &MockFirewall{}
}
// Name 名称
func (this *MockFirewall) Name() string {
return "mock"
}
// IsReady 是否已准备被调用
func (this *MockFirewall) IsReady() bool {
return true
}
// AllowPort 允许端口
func (this *MockFirewall) AllowPort(port int, protocol string) error {
_ = port
_ = protocol
return nil
}
// RemovePort 删除端口
func (this *MockFirewall) RemovePort(port int, protocol string) error {
_ = port
_ = protocol
return nil
}
// RejectSourceIP 拒绝某个源IP连接
func (this *MockFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
_ = ip
_ = timeoutSeconds
return nil
}
// DropSourceIP 丢弃某个源IP数据
func (this *MockFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
_ = ip
_ = timeoutSeconds
return nil
}
// RemoveSourceIP 删除某个源IP
func (this *MockFirewall) RemoveSourceIP(ip string) error {
_ = ip
return nil
}

View File

@@ -71,7 +71,7 @@ func (this *HTTPAPIAction) runAction(action string, listType IPListType, item *p
if err != nil {
return err
}
req.Header.Set("User-Agent", "GoEdge-Node/"+teaconst.Version)
req.Header.Set("User-Agent", teaconst.GlobalProductName+"-Node/"+teaconst.Version)
resp, err := httpAPIClient.Do(req)
if err != nil {
return err

View File

@@ -7,6 +7,7 @@ import (
)
// AllowIP 检查IP是否被允许访问
// 如果一个IP不在任何名单中则允许访问
func AllowIP(ip string, serverId int64) bool {
var ipLong = utils.IP2Long(ip)
if ipLong == 0 {
@@ -40,6 +41,17 @@ func AllowIP(ip string, serverId int64) bool {
return true
}
// IsInWhiteList 检查IP是否在白名单中
func IsInWhiteList(ip string) bool {
var ipLong = utils.IP2Long(ip)
if ipLong == 0 {
return false
}
// check white lists
return GlobalWhiteIPList.Contains(ipLong)
}
// AllowIPStrings 检查一组IP是否被允许访问
func AllowIPStrings(ipStrings []string, serverId int64) bool {
if len(ipStrings) == 0 {

View File

@@ -0,0 +1,156 @@
package iplibrary
import (
"crypto/md5"
"encoding/json"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/iwind/TeaGo/Tea"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/types"
"io/ioutil"
"os"
"sync"
"time"
)
var SharedCityManager = NewCityManager()
func init() {
events.On(events.EventLoaded, func() {
goman.New(func() {
SharedCityManager.Start()
})
})
events.On(events.EventQuit, func() {
SharedCityManager.Stop()
})
}
// CityManager 中国省份信息管理
type CityManager struct {
ticker *time.Ticker
cacheFile string
cityMap map[string]int64 // provinceName_cityName => cityName
dataHash string // 国家JSON的md5
locker sync.RWMutex
isUpdated bool
}
func NewCityManager() *CityManager {
return &CityManager{
cacheFile: Tea.Root + "/configs/region_city.json.cache",
cityMap: map[string]int64{},
}
}
func (this *CityManager) Start() {
// 从缓存中读取
err := this.load()
if err != nil {
remotelogs.ErrorObject("CITY_MANAGER", err)
}
// 第一次更新
err = this.loop()
if err != nil {
remotelogs.ErrorObject("City_MANAGER", err)
}
// 定时更新
this.ticker = time.NewTicker(4 * time.Hour)
for range this.ticker.C {
err := this.loop()
if err != nil {
remotelogs.ErrorObject("CITY_MANAGER", err)
}
}
}
func (this *CityManager) Stop() {
if this.ticker != nil {
this.ticker.Stop()
}
}
func (this *CityManager) Lookup(provinceId int64, cityName string) (cityId int64) {
this.locker.RLock()
cityId, _ = this.cityMap[types.String(provinceId)+"_"+cityName]
this.locker.RUnlock()
return
}
// 从缓存中读取
func (this *CityManager) load() error {
data, err := ioutil.ReadFile(this.cacheFile)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
m := map[string]int64{}
err = json.Unmarshal(data, &m)
if err != nil {
return err
}
if m != nil && len(m) > 0 {
this.cityMap = m
}
return nil
}
// 更新城市信息
func (this *CityManager) loop() error {
if this.isUpdated {
return nil
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
resp, err := rpcClient.RegionCityRPC().FindAllEnabledRegionCities(rpcClient.Context(), &pb.FindAllEnabledRegionCitiesRequest{})
if err != nil {
return err
}
m := map[string]int64{}
for _, city := range resp.RegionCities {
for _, code := range city.Codes {
m[types.String(city.RegionProvinceId)+"_"+code] = city.Id
}
}
// 检查是否有更新
data, err := json.Marshal(m)
if err != nil {
return err
}
hash := md5.New()
hash.Write(data)
dataHash := fmt.Sprintf("%x", hash.Sum(nil))
if this.dataHash == dataHash {
return nil
}
this.dataHash = dataHash
this.locker.Lock()
this.cityMap = m
this.isUpdated = true
this.locker.Unlock()
// 保存到本地缓存
err = ioutil.WriteFile(this.cacheFile, data, 0666)
return err
}

View File

@@ -0,0 +1,14 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package iplibrary
import "testing"
func TestNewCityManager(t *testing.T) {
var manager = NewCityManager()
err := manager.loop()
if err != nil {
t.Fatal(err)
}
t.Log(manager.Lookup(16, "许昌市"))
}

View File

@@ -9,7 +9,6 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/Tea"
_ "github.com/iwind/TeaGo/bootstrap"
"io/ioutil"
@@ -26,16 +25,23 @@ func init() {
SharedCountryManager.Start()
})
})
events.On(events.EventQuit, func() {
SharedCountryManager.Stop()
})
}
// CountryManager 国家/地区信息管理
type CountryManager struct {
ticker *time.Ticker
cacheFile string
countryMap map[string]int64 // countryName => countryId
dataHash string // 国家JSON的md5
locker sync.RWMutex
isUpdated bool
}
func NewCountryManager() *CountryManager {
@@ -59,11 +65,8 @@ func (this *CountryManager) Start() {
}
// 定时更新
ticker := utils.NewTicker(1 * time.Hour)
events.On(events.EventQuit, func() {
ticker.Stop()
})
for ticker.Next() {
this.ticker = time.NewTicker(4 * time.Hour)
for range this.ticker.C {
err := this.loop()
if err != nil {
remotelogs.ErrorObject("COUNTRY_MANAGER", err)
@@ -71,6 +74,12 @@ func (this *CountryManager) Start() {
}
}
func (this *CountryManager) Stop() {
if this.ticker != nil {
this.ticker.Stop()
}
}
func (this *CountryManager) Lookup(countryName string) (countryId int64) {
this.locker.RLock()
countryId, _ = this.countryMap[countryName]
@@ -101,6 +110,10 @@ func (this *CountryManager) load() error {
// 更新国家信息
func (this *CountryManager) loop() error {
if this.isUpdated {
return nil
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
@@ -111,7 +124,7 @@ func (this *CountryManager) loop() error {
}
m := map[string]int64{}
for _, country := range resp.Countries {
for _, country := range resp.RegionCountries {
for _, code := range country.Codes {
m[code] = country.Id
}
@@ -132,6 +145,7 @@ func (this *CountryManager) loop() error {
this.locker.Lock()
this.countryMap = m
this.isUpdated = true
this.locker.Unlock()
// 保存到本地缓存

View File

@@ -23,10 +23,15 @@ func init() {
SharedIPListManager.Start()
})
})
events.On(events.EventQuit, func() {
SharedIPListManager.Stop()
})
}
// IPListManager IP名单管理
type IPListManager struct {
ticker *time.Ticker
db *IPListDB
version int64
@@ -52,17 +57,14 @@ func (this *IPListManager) Start() {
remotelogs.ErrorObject("IP_LIST_MANAGER", err)
}
ticker := time.NewTicker(60 * time.Second)
this.ticker = time.NewTicker(60 * time.Second)
if Tea.IsTesting() {
ticker = time.NewTicker(10 * time.Second)
this.ticker = time.NewTicker(10 * time.Second)
}
events.On(events.EventQuit, func() {
ticker.Stop()
})
countErrors := 0
for {
select {
case <-ticker.C:
case <-this.ticker.C:
case <-IPListUpdateNotify:
}
err := this.loop()
@@ -84,6 +86,12 @@ func (this *IPListManager) Start() {
}
}
func (this *IPListManager) Stop() {
if this.ticker != nil {
this.ticker.Stop()
}
}
func (this *IPListManager) init() {
// 从数据库中当中读取数据
db, err := NewIPListDB()
@@ -197,7 +205,7 @@ func (this *IPListManager) processItems(items []*pb.IPItem, shouldExecute bool)
list.Delete(item.Id)
// 从WAF名单中删除
waf.SharedIPBlackList.RemoveIP(item.IpFrom, item.ServerId)
waf.SharedIPBlackList.RemoveIP(item.IpFrom, item.ServerId, shouldExecute)
// 操作事件
if shouldExecute {

View File

@@ -0,0 +1,155 @@
package iplibrary
import (
"crypto/md5"
"encoding/json"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/iwind/TeaGo/Tea"
_ "github.com/iwind/TeaGo/bootstrap"
"io/ioutil"
"os"
"sync"
"time"
)
var SharedProviderManager = NewProviderManager()
func init() {
events.On(events.EventLoaded, func() {
goman.New(func() {
SharedProviderManager.Start()
})
})
events.On(events.EventQuit, func() {
SharedProviderManager.Stop()
})
}
// ProviderManager 中国省份信息管理
type ProviderManager struct {
ticker *time.Ticker
cacheFile string
providerMap map[string]int64 // name => id
dataHash string // 国家JSON的md5
locker sync.RWMutex
isUpdated bool
}
func NewProviderManager() *ProviderManager {
return &ProviderManager{
cacheFile: Tea.Root + "/configs/region_provider.json.cache",
providerMap: map[string]int64{},
}
}
func (this *ProviderManager) Start() {
// 从缓存中读取
err := this.load()
if err != nil {
remotelogs.ErrorObject("PROVIDER_MANAGER", err)
}
// 第一次更新
err = this.loop()
if err != nil {
remotelogs.ErrorObject("PROVIDER_MANAGER", err)
}
// 定时更新
this.ticker = time.NewTicker(4 * time.Hour)
for range this.ticker.C {
err := this.loop()
if err != nil {
remotelogs.ErrorObject("PROVIDER_MANAGER", err)
}
}
}
func (this *ProviderManager) Stop() {
if this.ticker != nil {
this.ticker.Stop()
}
}
func (this *ProviderManager) Lookup(providerName string) (providerId int64) {
this.locker.RLock()
providerId, _ = this.providerMap[providerName]
this.locker.RUnlock()
return
}
// 从缓存中读取
func (this *ProviderManager) load() error {
data, err := ioutil.ReadFile(this.cacheFile)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
m := map[string]int64{}
err = json.Unmarshal(data, &m)
if err != nil {
return err
}
if m != nil && len(m) > 0 {
this.providerMap = m
}
return nil
}
// 更新服务商信息
func (this *ProviderManager) loop() error {
if this.isUpdated {
return nil
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
resp, err := rpcClient.RegionProviderRPC().FindAllEnabledRegionProviders(rpcClient.Context(), &pb.FindAllEnabledRegionProvidersRequest{})
if err != nil {
return err
}
m := map[string]int64{}
for _, provider := range resp.RegionProviders {
for _, code := range provider.Codes {
m[code] = provider.Id
}
}
// 检查是否有更新
data, err := json.Marshal(m)
if err != nil {
return err
}
hash := md5.New()
hash.Write(data)
dataHash := fmt.Sprintf("%x", hash.Sum(nil))
if this.dataHash == dataHash {
return nil
}
this.dataHash = dataHash
this.locker.Lock()
this.providerMap = m
this.isUpdated = true
this.locker.Unlock()
// 保存到本地缓存
err = ioutil.WriteFile(this.cacheFile, data, 0666)
return err
}

View File

@@ -0,0 +1,15 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package iplibrary
import "testing"
func TestNewProviderManager(t *testing.T) {
var manager = NewProviderManager()
err := manager.loop()
if err != nil {
t.Fatal(err)
}
t.Log(manager.Lookup("阿里云"))
t.Log(manager.Lookup("阿里云2"))
}

View File

@@ -9,7 +9,6 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/Tea"
_ "github.com/iwind/TeaGo/bootstrap"
"io/ioutil"
@@ -30,16 +29,23 @@ func init() {
SharedProvinceManager.Start()
})
})
events.On(events.EventQuit, func() {
SharedProvinceManager.Stop()
})
}
// ProvinceManager 中国省份信息管理
type ProvinceManager struct {
ticker *time.Ticker
cacheFile string
provinceMap map[string]int64 // provinceName => provinceId
dataHash string // 国家JSON的md5
locker sync.RWMutex
isUpdated bool
}
func NewProvinceManager() *ProvinceManager {
@@ -63,11 +69,8 @@ func (this *ProvinceManager) Start() {
}
// 定时更新
ticker := utils.NewTicker(1 * time.Hour)
events.On(events.EventQuit, func() {
ticker.Stop()
})
for ticker.Next() {
this.ticker = time.NewTicker(4 * time.Hour)
for range this.ticker.C {
err := this.loop()
if err != nil {
remotelogs.ErrorObject("PROVINCE_MANAGER", err)
@@ -75,6 +78,12 @@ func (this *ProvinceManager) Start() {
}
}
func (this *ProvinceManager) Stop() {
if this.ticker != nil {
this.ticker.Stop()
}
}
func (this *ProvinceManager) Lookup(provinceName string) (provinceId int64) {
this.locker.RLock()
provinceId, _ = this.provinceMap[provinceName]
@@ -103,21 +112,25 @@ func (this *ProvinceManager) load() error {
return nil
}
// 更新国家信息
// 更新省份信息
func (this *ProvinceManager) loop() error {
if this.isUpdated {
return nil
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
resp, err := rpcClient.RegionProvinceRPC().FindAllEnabledRegionProvincesWithCountryId(rpcClient.Context(), &pb.FindAllEnabledRegionProvincesWithCountryIdRequest{
CountryId: ChinaCountryId,
RegionCountryId: ChinaCountryId,
})
if err != nil {
return err
}
m := map[string]int64{}
for _, province := range resp.Provinces {
for _, province := range resp.RegionProvinces {
for _, code := range province.Codes {
m[code] = province.Id
}
@@ -138,6 +151,7 @@ func (this *ProvinceManager) loop() error {
this.locker.Lock()
this.provinceMap = m
this.isUpdated = true
this.locker.Unlock()
// 保存到本地缓存

View File

@@ -15,15 +15,22 @@ import (
"time"
)
var SharedUpdater = NewUpdater()
func init() {
events.On(events.EventStart, func() {
updater := NewUpdater()
updater.Start()
goman.New(func() {
SharedUpdater.Start()
})
})
events.On(events.EventQuit, func() {
SharedUpdater.Stop()
})
}
// Updater IP库更新程序
type Updater struct {
ticker *time.Ticker
}
// NewUpdater 获取新对象
@@ -34,15 +41,19 @@ func NewUpdater() *Updater {
// Start 开始更新
func (this *Updater) Start() {
// 这里不需要太频繁检查更新因为通常不需要更新IP库
ticker := time.NewTicker(1 * time.Hour)
goman.New(func() {
for range ticker.C {
err := this.loop()
if err != nil {
remotelogs.ErrorObject("IP_LIBRARY", err)
}
this.ticker = time.NewTicker(1 * time.Hour)
for range this.ticker.C {
err := this.loop()
if err != nil {
remotelogs.ErrorObject("IP_LIBRARY", err)
}
})
}
}
func (this *Updater) Stop() {
if this.ticker != nil {
this.ticker.Stop()
}
}
// 单次任务

View File

@@ -41,7 +41,7 @@ func NewAPIStream() *APIStream {
}
func (this *APIStream) Start() {
events.On(events.EventQuit, func() {
events.OnKey(events.EventQuit, this, func() {
this.isQuiting = true
if this.cancelFunc != nil {
this.cancelFunc()
@@ -182,7 +182,7 @@ func (this *APIStream) handleWriteCache(message *pb.NodeStreamMessage) error {
}
expiredAt := time.Now().Unix() + msg.LifeSeconds
writer, err := storage.OpenWriter(msg.Key, expiredAt, 200)
writer, err := storage.OpenWriter(msg.Key, expiredAt, 200, int64(len(msg.Value)))
if err != nil {
this.replyFail(message.RequestId, "prepare writing failed: "+err.Error())
return err
@@ -462,7 +462,7 @@ func (this *APIStream) handlePreheatCache(message *pb.NodeStreamMessage) error {
}
expiredAt := time.Now().Unix() + 8600
writer, err := storage.OpenWriter(key, expiredAt, 200) // TODO 可以设置缓存过期时间
writer, err := storage.OpenWriter(key, expiredAt, 200, resp.ContentLength) // TODO 可以设置缓存过期时间
if err != nil {
locker.Lock()
errorMessages = append(errorMessages, "open cache writer failed: "+key+": "+err.Error())

View File

@@ -4,9 +4,16 @@ package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/types"
"net"
"os"
"sync"
"sync/atomic"
"time"
@@ -17,8 +24,11 @@ type ClientConn struct {
once sync.Once
globalLimiter *ratelimit.Counter
isTLS bool
hasRead bool
isTLS bool
hasDeadline bool
hasRead bool
hasResetSYNFlood bool
BaseClientConn
}
@@ -38,9 +48,9 @@ func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ra
func (this *ClientConn) Read(b []byte) (n int, err error) {
if this.isTLS {
if !this.hasRead {
if !this.hasDeadline {
_ = this.rawConn.SetReadDeadline(time.Now().Add(time.Duration(nodeconfigs.DefaultTLSHandshakeTimeout) * time.Second)) // TODO 握手超时时间可以设置
this.hasRead = true
this.hasDeadline = true
defer func() {
_ = this.rawConn.SetReadDeadline(time.Time{})
}()
@@ -50,7 +60,26 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
n, err = this.rawConn.Read(b)
if n > 0 {
atomic.AddUint64(&teaconst.InTrafficBytes, uint64(n))
if !this.hasRead {
this.hasRead = true
}
}
// SYN Flood检测
var isHandshakeError = err != nil && os.IsTimeout(err) && !this.hasRead
if isHandshakeError {
_ = this.SetLinger(0)
}
var synFloodConfig = sharedNodeConfig.SYNFloodConfig()
if synFloodConfig != nil && synFloodConfig.IsOn {
if isHandshakeError {
this.increaseSYNFlood(synFloodConfig)
} else if err == nil && !this.hasResetSYNFlood {
this.hasResetSYNFlood = true
this.resetSYNFlood()
}
}
return
}
@@ -99,3 +128,30 @@ func (this *ClientConn) SetReadDeadline(t time.Time) error {
func (this *ClientConn) SetWriteDeadline(t time.Time) error {
return this.rawConn.SetWriteDeadline(t)
}
func (this *ClientConn) resetSYNFlood() {
ttlcache.SharedCache.Delete("SYN_FLOOD:" + this.RawIP())
}
func (this *ClientConn) increaseSYNFlood(synFloodConfig *firewallconfigs.SYNFloodConfig) {
var ip = this.RawIP()
if len(ip) > 0 && !iplibrary.IsInWhiteList(ip) && (!synFloodConfig.IgnoreLocal || !utils.IsLocalIP(ip)) {
var timestamp = utils.NextMinuteUnixTime()
var result = ttlcache.SharedCache.IncreaseInt64("SYN_FLOOD:"+ip, 1, timestamp)
var minAttempts = synFloodConfig.MinAttempts
if minAttempts < 5 {
minAttempts = 5
}
if !this.isTLS {
// 非TLS设置为两倍防止误封
minAttempts = 2 * minAttempts
}
if result >= int64(minAttempts) {
var timeout = synFloodConfig.TimeoutSeconds
if timeout <= 0 {
timeout = 600
}
waf.SharedIPBlackList.RecordIP(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip, time.Now().Unix()+int64(timeout), 0, true, 0, 0, "疑似SYN Flood攻击当前1分钟"+types.String(result)+"次空连接")
}
}
}

View File

@@ -36,3 +36,23 @@ func (this *BaseClientConn) Bind(serverId int64, remoteAddr string, maxConnsPerS
return sharedClientConnLimiter.Add(this.rawConn.RemoteAddr().String(), serverId, remoteAddr, maxConnsPerServer, maxConnsPerIP)
}
// RawIP 原本IP
func (this *BaseClientConn) RawIP() string {
ip, _, _ := net.SplitHostPort(this.rawConn.RemoteAddr().String())
return ip
}
// TCPConn 转换为TCPConn
func (this *BaseClientConn) TCPConn() (*net.TCPConn, bool) {
conn, ok := this.rawConn.(*net.TCPConn)
return conn, ok
}
// SetLinger 设置Linger
func (this *BaseClientConn) SetLinger(seconds int) error {
tcpConn, ok := this.TCPConn()
if ok {
return tcpConn.SetLinger(seconds)
}
return nil
}

View File

@@ -16,7 +16,7 @@ import (
// 发送监控流量
func init() {
events.On(events.EventStart, func() {
ticker := time.NewTicker(1 * time.Minute)
var ticker = time.NewTicker(1 * time.Minute)
goman.New(func() {
for range ticker.C {
// 加入到数据队列中

View File

@@ -6,6 +6,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"strings"
"time"
)
@@ -108,37 +109,33 @@ Loop:
}
func (this *HTTPAccessLogQueue) toValidUTF8(accessLog *pb.HTTPAccessLog) {
accessLog.RemoteUser = this.toValidUTF8string(accessLog.RemoteUser)
accessLog.RequestURI = this.toValidUTF8string(accessLog.RequestURI)
accessLog.RequestPath = this.toValidUTF8string(accessLog.RequestPath)
accessLog.RequestFilename = this.toValidUTF8string(accessLog.RequestFilename)
accessLog.RemoteUser = utils.ToValidUTF8string(accessLog.RemoteUser)
accessLog.RequestURI = utils.ToValidUTF8string(accessLog.RequestURI)
accessLog.RequestPath = utils.ToValidUTF8string(accessLog.RequestPath)
accessLog.RequestFilename = utils.ToValidUTF8string(accessLog.RequestFilename)
accessLog.RequestBody = bytes.ToValidUTF8(accessLog.RequestBody, []byte{})
for _, v := range accessLog.SentHeader {
for index, s := range v.Values {
v.Values[index] = this.toValidUTF8string(s)
v.Values[index] = utils.ToValidUTF8string(s)
}
}
accessLog.Referer = this.toValidUTF8string(accessLog.Referer)
accessLog.UserAgent = this.toValidUTF8string(accessLog.UserAgent)
accessLog.Request = this.toValidUTF8string(accessLog.Request)
accessLog.ContentType = this.toValidUTF8string(accessLog.ContentType)
accessLog.Referer = utils.ToValidUTF8string(accessLog.Referer)
accessLog.UserAgent = utils.ToValidUTF8string(accessLog.UserAgent)
accessLog.Request = utils.ToValidUTF8string(accessLog.Request)
accessLog.ContentType = utils.ToValidUTF8string(accessLog.ContentType)
for k, c := range accessLog.Cookie {
accessLog.Cookie[k] = this.toValidUTF8string(c)
accessLog.Cookie[k] = utils.ToValidUTF8string(c)
}
accessLog.Args = this.toValidUTF8string(accessLog.Args)
accessLog.QueryString = this.toValidUTF8string(accessLog.QueryString)
accessLog.Args = utils.ToValidUTF8string(accessLog.Args)
accessLog.QueryString = utils.ToValidUTF8string(accessLog.QueryString)
for _, v := range accessLog.Header {
for index, s := range v.Values {
v.Values[index] = this.toValidUTF8string(s)
v.Values[index] = utils.ToValidUTF8string(s)
}
}
}
func (this *HTTPAccessLogQueue) toValidUTF8string(v string) string {
return strings.ToValidUTF8(v, "")
}

View File

@@ -84,14 +84,13 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.Origi
if idleConns <= 0 {
idleConns = numberCPU * 8
}
//logs.Println("[ORIGIN]max connections:", maxConnections)
// TLS通讯
tlsConfig := &tls.Config{
var tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
if origin.Cert != nil {
obj := origin.Cert.CertObject()
var obj = origin.Cert.CertObject()
if obj != nil {
tlsConfig.InsecureSkipVerify = false
tlsConfig.Certificates = []tls.Certificate{*obj}
@@ -101,37 +100,16 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.Origi
}
}
transport := &http.Transport{
var transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
// 支持TOA的连接
toaConfig := sharedTOAManager.Config()
if toaConfig != nil && toaConfig.IsOn {
retries := 3
for i := 1; i <= retries; i++ {
port := int(toaConfig.RandLocalPort())
// TODO 思考是否支持X-Real-IP/X-Forwarded-IP
err := sharedTOAManager.SendMsg("add:" + strconv.Itoa(port) + ":" + req.requestRemoteAddr(true))
if err != nil {
remotelogs.Error("TOA", "add failed: "+err.Error())
} else {
dialer := net.Dialer{
Timeout: connectionTimeout,
KeepAlive: 1 * time.Minute,
LocalAddr: &net.TCPAddr{
Port: port,
},
}
conn, err := dialer.DialContext(ctx, network, originAddr)
// TODO 需要在合适的时机删除TOA记录
if err == nil || i == retries {
return conn, err
}
}
}
conn, err := this.handleTOA(req, ctx, network, originAddr, connectionTimeout)
if conn != nil || err != nil {
return conn, err
}
// 普通的连接
conn, err := (&net.Dialer{
conn, err = (&net.Dialer{
Timeout: connectionTimeout,
KeepAlive: 1 * time.Minute,
}).DialContext(ctx, network, originAddr)
@@ -139,32 +117,10 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.Origi
return nil, err
}
if proxyProtocol != nil && proxyProtocol.IsOn && (proxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || proxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) {
var remoteAddr = req.requestRemoteAddr(true)
var transportProtocol = proxyproto.TCPv4
if strings.Contains(remoteAddr, ":") {
transportProtocol = proxyproto.TCPv6
}
var destAddr = conn.RemoteAddr()
var reqConn = req.RawReq.Context().Value(HTTPConnContextKey)
if reqConn != nil {
destAddr = reqConn.(net.Conn).LocalAddr()
}
header := proxyproto.Header{
Version: byte(proxyProtocol.Version),
Command: proxyproto.PROXY,
TransportProtocol: transportProtocol,
SourceAddr: &net.TCPAddr{
IP: net.ParseIP(remoteAddr),
Port: req.requestRemotePort(),
},
DestinationAddr: destAddr,
}
_, err = header.WriteTo(conn)
if err != nil {
_ = conn.Close()
return nil, err
}
// 处理PROXY protocol
err = this.handlePROXYProtocol(conn, req, proxyProtocol)
if err != nil {
return nil, err
}
return conn, nil
@@ -174,7 +130,7 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.Origi
MaxConnsPerHost: maxConnections,
IdleConnTimeout: idleTimeout,
ExpectContinueTimeout: 1 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
TLSHandshakeTimeout: 3 * time.Second,
TLSClientConfig: tlsConfig,
Proxy: nil,
}
@@ -208,3 +164,69 @@ func (this *HTTPClientPool) cleanClients() {
this.locker.Unlock()
}
}
// 支持TOA
func (this *HTTPClientPool) handleTOA(req *HTTPRequest, ctx context.Context, network string, originAddr string, connectionTimeout time.Duration) (net.Conn, error) {
// TODO 每个服务读取自身所属集群的TOA设置
toaConfig := sharedTOAManager.Config()
if toaConfig != nil && toaConfig.IsOn {
retries := 3
for i := 1; i <= retries; i++ {
port := int(toaConfig.RandLocalPort())
// TODO 思考是否支持X-Real-IP/X-Forwarded-IP
err := sharedTOAManager.SendMsg("add:" + strconv.Itoa(port) + ":" + req.requestRemoteAddr(true))
if err != nil {
remotelogs.Error("TOA", "add failed: "+err.Error())
} else {
dialer := net.Dialer{
Timeout: connectionTimeout,
KeepAlive: 1 * time.Minute,
LocalAddr: &net.TCPAddr{
Port: port,
},
}
conn, err := dialer.DialContext(ctx, network, originAddr)
// TODO 需要在合适的时机删除TOA记录
if err == nil || i == retries {
return conn, err
}
}
}
}
return nil, nil
}
// 支持PROXY Protocol
func (this *HTTPClientPool) handlePROXYProtocol(conn net.Conn, req *HTTPRequest, proxyProtocol *serverconfigs.ProxyProtocolConfig) error {
if proxyProtocol != nil && proxyProtocol.IsOn && (proxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || proxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) {
var remoteAddr = req.requestRemoteAddr(true)
var transportProtocol = proxyproto.TCPv4
if strings.Contains(remoteAddr, ":") {
transportProtocol = proxyproto.TCPv6
}
var destAddr = conn.RemoteAddr()
var reqConn = req.RawReq.Context().Value(HTTPConnContextKey)
if reqConn != nil {
destAddr = reqConn.(net.Conn).LocalAddr()
}
header := proxyproto.Header{
Version: byte(proxyProtocol.Version),
Command: proxyproto.PROXY,
TransportProtocol: transportProtocol,
SourceAddr: &net.TCPAddr{
IP: net.ParseIP(remoteAddr),
Port: req.requestRemotePort(),
},
DestinationAddr: destAddr,
}
_, err := header.WriteTo(conn)
if err != nil {
_ = conn.Close()
return err
}
return nil
}
return nil
}

View File

@@ -8,10 +8,12 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/metrics"
"github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"io"
"io/ioutil"
@@ -38,8 +40,8 @@ type HTTPRequest struct {
// 外部参数
RawReq *http.Request
RawWriter http.ResponseWriter
Server *serverconfigs.ServerConfig
host string // 请求的Host
ReqServer *serverconfigs.ServerConfig
ReqHost string // 请求的Host
ServerName string // 实际匹配到的Host
ServerAddr string // 实际启动的服务器监听地址
IsHTTP bool
@@ -63,6 +65,7 @@ type HTTPRequest struct {
rewriteRule *serverconfigs.HTTPRewriteRule // 匹配到的重写规则
rewriteReplace string // 重写规则的目标
rewriteIsExternalURL bool // 重写目标是否为外部URL
remoteAddr string // 计算后的RemoteAddr
cacheRef *serverconfigs.HTTPCacheRef // 缓存设置
cacheKey string // 缓存使用的Key
@@ -98,7 +101,7 @@ func (this *HTTPRequest) init() {
// this.uri = this.RawReq.URL.RequestURI()
// 之所以不使用RequestURI()是不想让URL中的Path被Encode
var urlPath = this.RawReq.URL.Path
if this.Server.Web != nil && this.Server.Web.MergeSlashes {
if this.ReqServer.Web != nil && this.ReqServer.Web.MergeSlashes {
urlPath = utils.CleanPath(urlPath)
this.web.MergeSlashes = true
}
@@ -129,13 +132,13 @@ func (this *HTTPRequest) Do() {
this.init()
// 当前服务的反向代理配置
if this.Server.ReverseProxyRef != nil && this.Server.ReverseProxy != nil {
this.reverseProxyRef = this.Server.ReverseProxyRef
this.reverseProxy = this.Server.ReverseProxy
if this.ReqServer.ReverseProxyRef != nil && this.ReqServer.ReverseProxy != nil {
this.reverseProxyRef = this.ReqServer.ReverseProxyRef
this.reverseProxy = this.ReqServer.ReverseProxy
}
// Web配置
err := this.configureWeb(this.Server.Web, true, 0)
err := this.configureWeb(this.ReqServer.Web, true, 0)
if err != nil {
this.write50x(err, http.StatusInternalServerError, false)
this.doEnd()
@@ -161,14 +164,14 @@ func (this *HTTPRequest) Do() {
}
// 套餐
if this.Server.UserPlan != nil && !this.Server.UserPlan.IsAvailable() {
if this.ReqServer.UserPlan != nil && !this.ReqServer.UserPlan.IsAvailable() {
this.doPlanExpires()
this.doEnd()
return
}
// 流量限制
if this.Server.TrafficLimit != nil && this.Server.TrafficLimit.IsOn && !this.Server.TrafficLimit.IsEmpty() && this.Server.TrafficLimitStatus != nil && this.Server.TrafficLimitStatus.IsValid() {
if this.ReqServer.TrafficLimit != nil && this.ReqServer.TrafficLimit.IsOn && !this.ReqServer.TrafficLimit.IsEmpty() && this.ReqServer.TrafficLimitStatus != nil && this.ReqServer.TrafficLimitStatus.IsValid() {
this.doTrafficLimit()
this.doEnd()
return
@@ -310,14 +313,14 @@ func (this *HTTPRequest) doEnd() {
// 流量统计
// TODO 增加是否开启开关
// TODO 增加Header统计考虑从Conn中读取
if this.Server != nil {
if this.ReqServer != nil {
if this.isCached {
stats.SharedTrafficStatManager.Add(this.Server.Id, this.host, this.writer.sentBodyBytes, this.writer.sentBodyBytes, 1, 1, 0, 0, this.Server.ShouldCheckTrafficLimit(), this.Server.PlanId())
stats.SharedTrafficStatManager.Add(this.ReqServer.Id, this.ReqHost, this.writer.SentBodyBytes(), this.writer.SentBodyBytes(), 1, 1, 0, 0, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId())
} else {
if this.isAttack {
stats.SharedTrafficStatManager.Add(this.Server.Id, this.host, this.writer.sentBodyBytes, 0, 1, 0, 1, this.writer.sentBodyBytes, this.Server.ShouldCheckTrafficLimit(), this.Server.PlanId())
stats.SharedTrafficStatManager.Add(this.ReqServer.Id, this.ReqHost, this.writer.SentBodyBytes(), 0, 1, 0, 1, this.writer.SentBodyBytes(), this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId())
} else {
stats.SharedTrafficStatManager.Add(this.Server.Id, this.host, this.writer.sentBodyBytes, 0, 1, 0, 0, 0, this.Server.ShouldCheckTrafficLimit(), this.Server.PlanId())
stats.SharedTrafficStatManager.Add(this.ReqServer.Id, this.ReqHost, this.writer.SentBodyBytes(), 0, 1, 0, 0, 0, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId())
}
}
}
@@ -467,11 +470,17 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
if this.web.RequestScripts == nil {
this.web.RequestScripts = web.RequestScripts
} else {
if web.RequestScripts.OnInitScript != nil && (web.RequestScripts.OnInitScript.IsPrior || isTop) {
this.web.RequestScripts.OnInitScript = web.RequestScripts.OnInitScript
if web.RequestScripts.InitGroup != nil && (web.RequestScripts.InitGroup.IsPrior || isTop) {
if this.web.RequestScripts == nil {
this.web.RequestScripts = &serverconfigs.HTTPRequestScriptsConfig{}
}
this.web.RequestScripts.InitGroup = web.RequestScripts.InitGroup
}
if web.RequestScripts.OnRequestScript != nil && (web.RequestScripts.OnRequestScript.IsPrior || isTop) {
this.web.RequestScripts.OnRequestScript = web.RequestScripts.OnRequestScript
if web.RequestScripts.RequestGroup != nil && (web.RequestScripts.RequestGroup.IsPrior || isTop) {
if this.web.RequestScripts == nil {
this.web.RequestScripts = &serverconfigs.HTTPRequestScriptsConfig{}
}
this.web.RequestScripts.RequestGroup = web.RequestScripts.RequestGroup
}
}
}
@@ -545,7 +554,7 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
}
if varMapping, isMatched := location.Match(rawPath, this.Format); isMatched {
// 检查专属域名
if len(location.Domains) > 0 && !configutils.MatchDomains(location.Domains, this.host) {
if len(location.Domains) > 0 && !configutils.MatchDomains(location.Domains, this.ReqHost) {
continue
}
@@ -627,7 +636,7 @@ func (this *HTTPRequest) Format(source string) string {
if this.IsHTTPS {
scheme = "https"
}
return scheme + "://" + this.host + this.rawURI
return scheme + "://" + this.ReqHost + this.rawURI
case "requestPath":
return this.Path()
case "requestPathExtension":
@@ -674,7 +683,7 @@ func (this *HTTPRequest) Format(source string) string {
case "timestamp":
return strconv.FormatInt(this.requestFromTime.Unix(), 10)
case "host":
return this.host
return this.ReqHost
case "referer":
return this.RawReq.Referer()
case "referer.host":
@@ -792,7 +801,7 @@ func (this *HTTPRequest) Format(source string) string {
// host
if prefix == "host" {
pieces := strings.Split(this.host, ".")
pieces := strings.Split(this.ReqHost, ".")
switch suffix {
case "first":
if len(pieces) > 0 {
@@ -857,6 +866,104 @@ func (this *HTTPRequest) Format(source string) string {
}
}
// geo
if prefix == "geo" {
result, _ := iplibrary.SharedLibrary.Lookup(this.requestRemoteAddr(true))
switch suffix {
case "country.name":
if result != nil {
return result.Country
}
return ""
case "country.id":
if result != nil {
return types.String(iplibrary.SharedCountryManager.Lookup(result.Country))
}
return "0"
case "province.name":
if result != nil {
return result.Province
}
return ""
case "province.id":
if result != nil {
return types.String(iplibrary.SharedProvinceManager.Lookup(result.Province))
}
return "0"
case "city.name":
if result != nil {
return result.City
}
return ""
case "city.id":
if result != nil {
var provinceId = iplibrary.SharedProvinceManager.Lookup(result.Province)
if provinceId > 0 {
return types.String(iplibrary.SharedCityManager.Lookup(provinceId, result.City))
} else {
return "0"
}
}
return "0"
}
}
// ips
if prefix == "isp" {
result, _ := iplibrary.SharedLibrary.Lookup(this.requestRemoteAddr(true))
switch suffix {
case "name":
if result != nil {
return result.ISP
}
case "id":
if result != nil {
return types.String(iplibrary.SharedProviderManager.Lookup(result.ISP))
}
return "0"
}
return ""
}
// browser
if prefix == "browser" {
var result = stats.SharedUserAgentParser.Parse(this.RawReq.UserAgent())
switch suffix {
case "os.name":
return result.OS.Name
case "os.version":
return result.OS.Version
case "name":
return result.BrowserName
case "version":
return result.BrowserVersion
case "isMobile":
if result.IsMobile {
return "1"
} else {
return "0"
}
}
}
// product
if prefix == "product" {
switch suffix {
case "name":
if sharedNodeConfig.ProductConfig != nil && len(sharedNodeConfig.ProductConfig.Name) > 0 {
return sharedNodeConfig.ProductConfig.Name
}
return teaconst.GlobalProductName
case "version":
if sharedNodeConfig.ProductConfig != nil && len(sharedNodeConfig.ProductConfig.Version) > 0 {
return sharedNodeConfig.ProductConfig.Version
}
return teaconst.Version
}
}
return "${" + varName + "}"
})
}
@@ -870,12 +977,17 @@ func (this *HTTPRequest) addVarMapping(varMapping map[string]string) {
// 获取请求的客户端地址
func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string {
if supportVar && len(this.remoteAddr) > 0 {
return this.remoteAddr
}
if supportVar &&
this.web.RemoteAddr != nil &&
this.web.RemoteAddr.IsOn &&
!this.web.RemoteAddr.IsEmpty() {
var remoteAddr = this.Format(this.web.RemoteAddr.Value)
if net.ParseIP(remoteAddr) != nil {
this.remoteAddr = remoteAddr
return remoteAddr
}
}
@@ -888,6 +1000,9 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string {
forwardedFor = forwardedFor[:commaIndex]
}
if net.ParseIP(forwardedFor) != nil {
if supportVar {
this.remoteAddr = forwardedFor
}
return forwardedFor
}
}
@@ -897,6 +1012,9 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string {
realIP, ok := this.RawReq.Header["X-Real-IP"]
if ok && len(realIP) > 0 {
if net.ParseIP(realIP[0]) != nil {
if supportVar {
this.remoteAddr = realIP[0]
}
return realIP[0]
}
}
@@ -907,6 +1025,9 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string {
realIP, ok := this.RawReq.Header["X-Real-Ip"]
if ok && len(realIP) > 0 {
if net.ParseIP(realIP[0]) != nil {
if supportVar {
this.remoteAddr = realIP[0]
}
return realIP[0]
}
}
@@ -916,6 +1037,9 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string {
remoteAddr := this.RawReq.RemoteAddr
host, _, err := net.SplitHostPort(remoteAddr)
if err == nil {
if supportVar {
this.remoteAddr = host
}
return host
} else {
return remoteAddr
@@ -1089,14 +1213,22 @@ func (this *HTTPRequest) Id() string {
return this.requestId
}
func (this *HTTPRequest) Server() maps.Map {
return maps.Map{"id": this.ReqServer.Id}
}
func (this *HTTPRequest) Node() maps.Map {
return maps.Map{"id": teaconst.NodeId}
}
// URL 获取完整的URL
func (this *HTTPRequest) URL() string {
return this.requestScheme() + "://" + this.host + this.uri
return this.requestScheme() + "://" + this.ReqHost + this.uri
}
// Host 获取Host
func (this *HTTPRequest) Host() string {
return this.host
return this.ReqHost
}
func (this *HTTPRequest) Proto() string {
@@ -1151,6 +1283,7 @@ func (this *HTTPRequest) Method() string {
return this.RawReq.Method
}
// TransferEncoding 获取传输编码
func (this *HTTPRequest) TransferEncoding() string {
if len(this.RawReq.TransferEncoding) > 0 {
return this.RawReq.TransferEncoding[0]
@@ -1158,6 +1291,15 @@ func (this *HTTPRequest) TransferEncoding() string {
return ""
}
// Cookie 获取Cookie
func (this *HTTPRequest) Cookie(name string) string {
c, err := this.RawReq.Cookie(name)
if err != nil {
return ""
}
return c.Value
}
// DeleteHeader 删除Header
func (this *HTTPRequest) DeleteHeader(name string) {
this.RawReq.Header.Del(name)
@@ -1173,10 +1315,12 @@ func (this *HTTPRequest) Header() http.Header {
return this.RawReq.Header
}
// URI 获取当前请求的URI
func (this *HTTPRequest) URI() string {
return this.uri
}
// SetURI 设置当前请求的URI
func (this *HTTPRequest) SetURI(uri string) {
this.uri = uri
}
@@ -1186,6 +1330,29 @@ func (this *HTTPRequest) Done() {
this.isDone = true
}
// Close 关闭连接
func (this *HTTPRequest) Close() {
this.Done()
requestConn := this.RawReq.Context().Value(HTTPConnContextKey)
if requestConn == nil {
return
}
conn, ok := requestConn.(net.Conn)
if ok {
_ = conn.Close()
return
}
return
}
// Allow 放行
func (this *HTTPRequest) Allow() {
this.web.FirewallRef = nil
}
// 设置代理相关头部信息
// 参考https://tools.ietf.org/html/rfc7239
func (this *HTTPRequest) setForwardHeaders(header http.Header) {
@@ -1226,9 +1393,9 @@ func (this *HTTPRequest) setForwardHeaders(header http.Header) {
/**{
forwarded, ok := header["Forwarded"]
if ok {
header["Forwarded"] = []string{strings.Join(forwarded, ", ") + ", by=" + this.serverAddr + "; for=" + remoteAddr + "; host=" + this.host + "; proto=" + this.rawScheme}
header["Forwarded"] = []string{strings.Join(forwarded, ", ") + ", by=" + this.serverAddr + "; for=" + remoteAddr + "; host=" + this.ReqHost + "; proto=" + this.rawScheme}
} else {
header["Forwarded"] = []string{"by=" + this.serverAddr + "; for=" + remoteAddr + "; host=" + this.host + "; proto=" + this.rawScheme}
header["Forwarded"] = []string{"by=" + this.serverAddr + "; for=" + remoteAddr + "; host=" + this.ReqHost + "; proto=" + this.rawScheme}
}
}**/
@@ -1239,7 +1406,7 @@ func (this *HTTPRequest) setForwardHeaders(header http.Header) {
if this.reverseProxy != nil && this.reverseProxy.ShouldAddXForwardedHostHeader() {
if _, ok := header["X-Forwarded-Host"]; !ok {
this.RawReq.Header.Set("X-Forwarded-Host", this.host)
this.RawReq.Header.Set("X-Forwarded-Host", this.ReqHost)
}
}
@@ -1279,7 +1446,7 @@ func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) {
}
// 域名
if len(header.Domains) > 0 && !configutils.MatchDomains(header.Domains, this.host) {
if len(header.Domains) > 0 && !configutils.MatchDomains(header.Domains, this.ReqHost) {
continue
}
@@ -1363,7 +1530,7 @@ func (this *HTTPRequest) processResponseHeaders(statusCode int) {
}
// 域名
if len(header.Domains) > 0 && !configutils.MatchDomains(header.Domains, this.host) {
if len(header.Domains) > 0 && !configutils.MatchDomains(header.Domains, this.ReqHost) {
continue
}
@@ -1397,13 +1564,13 @@ func (this *HTTPRequest) processResponseHeaders(statusCode int) {
// HSTS
if this.IsHTTPS &&
this.Server.HTTPS != nil &&
this.Server.HTTPS.SSLPolicy != nil &&
this.Server.HTTPS.SSLPolicy.IsOn &&
this.Server.HTTPS.SSLPolicy.HSTS != nil &&
this.Server.HTTPS.SSLPolicy.HSTS.IsOn &&
this.Server.HTTPS.SSLPolicy.HSTS.Match(this.host) {
responseHeader.Set(this.Server.HTTPS.SSLPolicy.HSTS.HeaderKey(), this.Server.HTTPS.SSLPolicy.HSTS.HeaderValue())
this.ReqServer.HTTPS != nil &&
this.ReqServer.HTTPS.SSLPolicy != nil &&
this.ReqServer.HTTPS.SSLPolicy.IsOn &&
this.ReqServer.HTTPS.SSLPolicy.HSTS != nil &&
this.ReqServer.HTTPS.SSLPolicy.HSTS.IsOn &&
this.ReqServer.HTTPS.SSLPolicy.HSTS.Match(this.ReqHost) {
responseHeader.Set(this.ReqServer.HTTPS.SSLPolicy.HSTS.HeaderKey(), this.ReqServer.HTTPS.SSLPolicy.HSTS.HeaderValue())
}
}
@@ -1464,22 +1631,6 @@ func (this *HTTPRequest) canIgnore(err error) bool {
return false
}
// 关闭当前连接
func (this *HTTPRequest) closeConn() {
requestConn := this.RawReq.Context().Value(HTTPConnContextKey)
if requestConn == nil {
return
}
conn, ok := requestConn.(net.Conn)
if ok {
_ = conn.Close()
return
}
return
}
// 检查连接是否已关闭
func (this *HTTPRequest) isConnClosed() bool {
requestConn := this.RawReq.Context().Value(HTTPConnContextKey)

View File

@@ -45,7 +45,7 @@ func (this *HTTPRequest) doAuth() (shouldStop bool) {
if len(method.Realm) > 0 {
headerValue += method.Realm
} else {
headerValue += this.host
headerValue += this.ReqHost
}
headerValue += "\""
if len(method.Charset) > 0 {

View File

@@ -5,7 +5,6 @@ import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/compressions"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
@@ -22,7 +21,7 @@ import (
func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
this.cacheCanTryStale = false
cachePolicy := this.Server.HTTPCachePolicy
var cachePolicy = this.ReqServer.HTTPCachePolicy
if cachePolicy == nil || !cachePolicy.IsOn {
return
}
@@ -138,7 +137,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
if err == nil {
for _, rpcServerService := range rpcClient.ServerRPCList() {
_, err = rpcServerService.PurgeServerCache(rpcClient.Context(), &pb.PurgeServerCacheRequest{
Domains: []string{this.host},
Domains: []string{this.ReqHost},
Keys: []string{key},
Prefixes: nil,
})
@@ -162,11 +161,15 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
var err error
// 是否优先检查WebP
var isWebP = false
if this.web.WebP != nil &&
this.web.WebP.IsOn &&
this.web.WebP.MatchRequest(filepath.Ext(this.Path()), this.Format) &&
this.web.WebP.MatchAccept(this.requestHeader("Accept")) {
reader, _ = storage.OpenReader(key+webpSuffix, useStale)
if reader != nil {
isWebP = true
}
}
// 检查正常的文件
@@ -184,13 +187,16 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
}
if !this.canIgnore(err) {
remotelogs.Warn("HTTP_REQUEST_CACHE", "read from cache failed: "+err.Error())
remotelogs.Warn("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: open cache failed: "+err.Error())
}
return
}
}
defer func() {
_ = reader.Close()
if !this.writer.DelayRead() {
_ = reader.Close()
}
}()
if useStale {
@@ -231,7 +237,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
})
if err != nil {
if !this.canIgnore(err) {
remotelogs.Warn("HTTP_REQUEST_CACHE", "read from cache failed: "+err.Error())
remotelogs.Warn("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: read header failed: "+err.Error())
}
return
}
@@ -257,7 +263,11 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
var eTag = ""
var lastModifiedAt = reader.LastModified()
if lastModifiedAt > 0 {
eTag = "\"" + strconv.FormatInt(lastModifiedAt, 10) + "\""
if isWebP {
eTag = "\"" + strconv.FormatInt(lastModifiedAt, 10) + "_webp" + "\""
} else {
eTag = "\"" + strconv.FormatInt(lastModifiedAt, 10) + "\""
}
respHeader.Del("Etag")
respHeader["ETag"] = []string{eTag}
}
@@ -357,7 +367,6 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
}
}
respHeader := this.writer.Header()
if len(rangeSet) == 1 {
respHeader.Set("Content-Range", "bytes "+strconv.FormatInt(rangeSet[0][0], 10)+"-"+strconv.FormatInt(rangeSet[0][1], 10)+"/"+strconv.FormatInt(reader.BodySize(), 10))
respHeader.Set("Content-Length", strconv.FormatInt(rangeSet[0][1]-rangeSet[0][0]+1, 10))
@@ -379,7 +388,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
return true
}
if !this.canIgnore(err) {
remotelogs.Warn("HTTP_REQUEST_CACHE", "read from cache failed: "+err.Error())
remotelogs.Warn("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: "+err.Error())
}
return
}
@@ -425,7 +434,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
})
if err != nil {
if !this.canIgnore(err) {
remotelogs.Warn("HTTP_REQUEST_CACHE", "read from cache failed: "+err.Error())
remotelogs.Warn("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: "+err.Error())
}
return true
}
@@ -439,25 +448,11 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
return true
}
} else { // 没有Range
var body io.Reader = reader
var contentEncoding = this.writer.Header().Get("Content-Encoding")
if len(contentEncoding) > 0 && !httpAcceptEncoding(this.RawReq.Header.Get("Accept-Encoding"), contentEncoding) {
decompressReader, err := compressions.NewReader(body, contentEncoding)
if err == nil {
body = decompressReader
defer func() {
_ = decompressReader.Close()
}()
this.writer.Header().Del("Content-Encoding")
this.writer.Header().Del("Content-Length")
}
}
this.writer.PrepareCompression(reader.BodySize())
var resp = &http.Response{Body: reader}
this.writer.Prepare(resp, reader.BodySize(), reader.Status(), false)
this.writer.WriteHeader(reader.Status())
_, err = io.CopyBuffer(this.writer, body, buf)
_, err = io.CopyBuffer(this.writer, resp.Body, buf)
if err == io.EOF {
err = nil
}
@@ -465,7 +460,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
this.varMapping["cache.status"] = "MISS"
if !this.canIgnore(err) {
remotelogs.Warn("HTTP_REQUEST_CACHE", "read from cache failed: "+err.Error())
remotelogs.Warn("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: read body failed: "+err.Error())
}
return
}

View File

@@ -52,13 +52,13 @@ func (this *HTTPRequest) doFastcgi() (shouldStop bool) {
}
}
if !env.Has("SERVER_NAME") {
env["SERVER_NAME"] = this.host
env["SERVER_NAME"] = this.ReqHost
}
if !env.Has("REQUEST_URI") {
env["REQUEST_URI"] = this.uri
}
if !env.Has("HOST") {
env["HOST"] = this.host
env["HOST"] = this.ReqHost
}
if len(this.ServerAddr) > 0 {
@@ -149,7 +149,7 @@ func (this *HTTPRequest) doFastcgi() (shouldStop bool) {
host, found := params["HTTP_HOST"]
if !found || len(host) == 0 {
params["HTTP_HOST"] = this.host
params["HTTP_HOST"] = this.ReqHost
}
fcgiReq := fcgi.NewRequest()
@@ -190,7 +190,7 @@ func (this *HTTPRequest) doFastcgi() (shouldStop bool) {
this.processResponseHeaders(resp.StatusCode)
// 准备
this.writer.Prepare(resp.ContentLength, resp.StatusCode)
this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true)
// 设置响应代码
this.writer.WriteHeader(resp.StatusCode)

View File

@@ -13,7 +13,7 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
if this.web.MergeSlashes {
urlPath = utils.CleanPath(urlPath)
}
fullURL := this.requestScheme() + "://" + this.host + urlPath
fullURL := this.requestScheme() + "://" + this.ReqHost + urlPath
for _, u := range this.web.HostRedirects {
if !u.IsOn {
continue
@@ -73,6 +73,13 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
return false
}
if u.KeepArgs {
var qIndex = strings.Index(this.uri, "?")
if qIndex >= 0 {
afterURL += this.uri[qIndex:]
}
}
if u.Status <= 0 {
this.processResponseHeaders(http.StatusTemporaryRedirect)
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
@@ -88,12 +95,20 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
return false
}
var afterURL = u.AfterURL
if u.KeepArgs {
var qIndex = strings.Index(this.uri, "?")
if qIndex >= 0 {
afterURL += this.uri[qIndex:]
}
}
if u.Status <= 0 {
this.processResponseHeaders(http.StatusTemporaryRedirect)
http.Redirect(this.RawWriter, this.RawReq, u.AfterURL, http.StatusTemporaryRedirect)
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
} else {
this.processResponseHeaders(u.Status)
http.Redirect(this.RawWriter, this.RawReq, u.AfterURL, u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
}
return true
}

View File

@@ -19,9 +19,9 @@ func (this *HTTPRequest) doRequestLimit() (shouldStop bool) {
if requestConn != nil {
clientConn, ok := requestConn.(ClientConnInterface)
if ok && !clientConn.IsBound() {
if !clientConn.Bind(this.Server.Id, this.requestRemoteAddr(true), this.web.RequestLimit.MaxConns, this.web.RequestLimit.MaxConnsPerIP) {
if !clientConn.Bind(this.ReqServer.Id, this.requestRemoteAddr(true), this.web.RequestLimit.MaxConns, this.web.RequestLimit.MaxConnsPerIP) {
this.writeCode(http.StatusTooManyRequests)
this.closeConn()
this.Close()
return true
}
}

View File

@@ -93,7 +93,7 @@ func (this *HTTPRequest) log() {
accessLog := &pb.HTTPAccessLog{
RequestId: this.requestId,
NodeId: sharedNodeConfig.Id,
ServerId: this.Server.Id,
ServerId: this.ReqServer.Id,
RemoteAddr: this.requestRemoteAddr(true),
RawRemoteAddr: addr,
RemotePort: int32(this.requestRemotePort()),
@@ -114,7 +114,7 @@ func (this *HTTPRequest) log() {
TimeLocal: this.requestFromTime.Format("2/Jan/2006:15:04:05 -0700"),
Msec: float64(this.requestFromTime.Unix()) + float64(this.requestFromTime.Nanosecond())/1000000000,
Timestamp: this.requestFromTime.Unix(),
Host: this.host,
Host: this.ReqHost,
Referer: referer,
UserAgent: userAgent,
Request: this.requestString(),

View File

@@ -50,7 +50,7 @@ func (this *HTTPRequest) MetricValue(value string) (result int64, ok bool) {
}
func (this *HTTPRequest) MetricServerId() int64 {
return this.Server.Id
return this.ReqServer.Id
}
func (this *HTTPRequest) MetricCategory() string {

View File

@@ -61,11 +61,11 @@ func (this *HTTPRequest) doPage(status int) (shouldStop bool) {
if page.NewStatus > 0 {
// 自定义响应Headers
this.processResponseHeaders(page.NewStatus)
this.writer.Prepare(stat.Size(), page.NewStatus)
this.writer.Prepare(nil, stat.Size(), page.NewStatus, true)
this.writer.WriteHeader(page.NewStatus)
} else {
this.processResponseHeaders(status)
this.writer.Prepare(stat.Size(), status)
this.writer.Prepare(nil, stat.Size(), status, true)
this.writer.WriteHeader(status)
}
buf := utils.BytePool1k.Get()
@@ -100,11 +100,11 @@ func (this *HTTPRequest) doPage(status int) (shouldStop bool) {
if page.NewStatus > 0 {
// 自定义响应Headers
this.processResponseHeaders(page.NewStatus)
this.writer.Prepare(int64(len(content)), page.NewStatus)
this.writer.Prepare(nil, int64(len(content)), page.NewStatus, true)
this.writer.WriteHeader(page.NewStatus)
} else {
this.processResponseHeaders(status)
this.writer.Prepare(int64(len(content)), status)
this.writer.Prepare(nil, int64(len(content)), status, true)
this.writer.WriteHeader(status)
}

View File

@@ -5,7 +5,6 @@ import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/TeaOSLab/EdgeNode/internal/compressions"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"io"
@@ -36,12 +35,12 @@ func (this *HTTPRequest) doReverseProxy() {
requestCall := shared.NewRequestCall()
requestCall.Request = this.RawReq
requestCall.Formatter = this.Format
requestCall.Domain = this.host
requestCall.Domain = this.ReqHost
origin := this.reverseProxy.NextOrigin(requestCall)
requestCall.CallResponseCallbacks(this.writer)
if origin == nil {
err := errors.New(this.URL() + ": no available origin sites for reverse proxy")
remotelogs.ServerError(this.Server.Id, "HTTP_REQUEST_REVERSE_PROXY", err.Error(), "", nil)
remotelogs.ServerError(this.ReqServer.Id, "HTTP_REQUEST_REVERSE_PROXY", err.Error(), "", nil)
this.write50x(err, http.StatusBadGateway, true)
return
}
@@ -129,7 +128,7 @@ func (this *HTTPRequest) doReverseProxy() {
this.RawReq.Host = hostname
this.RawReq.URL.Host = this.RawReq.Host
} else {
this.RawReq.URL.Host = this.host
this.RawReq.URL.Host = this.ReqHost
}
// 重组请求URL
@@ -195,7 +194,9 @@ func (this *HTTPRequest) doReverseProxy() {
} else {
this.write50x(err, http.StatusBadGateway, true)
}
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.RawReq.URL.String()+"': "+err.Error())
if httpErr.Err != io.EOF {
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.RawReq.URL.String()+"': "+err.Error())
}
} else {
// 是否为客户端方面的错误
isClientError := false
@@ -260,39 +261,28 @@ func (this *HTTPRequest) doReverseProxy() {
}
}
// 解压
if !resp.Uncompressed {
var contentEncoding = resp.Header.Get("Content-Encoding")
if len(contentEncoding) > 0 && !httpAcceptEncoding(this.RawReq.Header.Get("Accept-Encoding"), contentEncoding) {
reader, err := compressions.NewReader(resp.Body, contentEncoding)
if err == nil {
var body = resp.Body
defer func() {
_ = body.Close()
}()
resp.Body = reader
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
}
}
}
// 响应Header
this.writer.AddHeaders(resp.Header)
this.processResponseHeaders(resp.StatusCode)
// 是否需要刷新
shouldAutoFlush := this.reverseProxy.AutoFlush || this.RawReq.Header.Get("Accept") == "text/event-stream"
var shouldAutoFlush = this.reverseProxy.AutoFlush || this.RawReq.Header.Get("Accept") == "text/event-stream"
// 准备
delayHeaders := this.writer.Prepare(resp.ContentLength, resp.StatusCode)
var delayHeaders = this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true)
// 设置响应代码
if !delayHeaders {
this.writer.WriteHeader(resp.StatusCode)
}
// 是否有内容
if resp.ContentLength == 0 && len(resp.TransferEncoding) == 0 {
_ = resp.Body.Close()
this.writer.SetOk()
return
}
// 输出到客户端
pool := this.bytePool(resp.ContentLength)
buf := pool.Get()

View File

@@ -15,7 +15,7 @@ func (this *HTTPRequest) doRewrite() (shouldShop bool) {
if this.rewriteRule.Mode == serverconfigs.HTTPRewriteModeProxy {
// 外部URL
if this.rewriteIsExternalURL {
host := this.host
host := this.ReqHost
if len(this.rewriteRule.ProxyHost) > 0 {
host = this.rewriteRule.ProxyHost
}

View File

@@ -302,7 +302,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
this.cacheRef = nil // 不支持缓存
}
this.writer.Prepare(fileSize, http.StatusOK)
this.writer.Prepare(nil, fileSize, http.StatusOK, true)
pool := this.bytePool(fileSize)
buf := pool.Get()

View File

@@ -6,11 +6,11 @@ import (
// 统计
func (this *HTTPRequest) doStat() {
if this.Server == nil {
if this.ReqServer == nil {
return
}
// 内置的统计
stats.SharedHTTPRequestStatManager.AddRemoteAddr(this.Server.Id, this.requestRemoteAddr(true), this.writer.SentBodyBytes(), this.isAttack)
stats.SharedHTTPRequestStatManager.AddUserAgent(this.Server.Id, this.requestHeader("User-Agent"))
stats.SharedHTTPRequestStatManager.AddRemoteAddr(this.ReqServer.Id, this.requestRemoteAddr(true), this.writer.SentBodyBytes(), this.isAttack)
stats.SharedHTTPRequestStatManager.AddUserAgent(this.ReqServer.Id, this.requestHeader("User-Agent"))
}

View File

@@ -10,8 +10,8 @@ func (this *HTTPRequest) doSubRequest(writer http.ResponseWriter, rawReq *http.R
req := &HTTPRequest{
RawReq: rawReq,
RawWriter: writer,
Server: this.Server,
host: this.host,
ReqServer: this.ReqServer,
ReqHost: this.ReqHost,
ServerName: this.ServerName,
ServerAddr: this.ServerAddr,
IsHTTP: this.IsHTTP,

View File

@@ -8,7 +8,7 @@ import (
// 流量限制
func (this *HTTPRequest) doTrafficLimit() {
var config = this.Server.TrafficLimit
var config = this.ReqServer.TrafficLimit
this.tags = append(this.tags, "bandwidth")

View File

@@ -54,9 +54,9 @@ func (this *HTTPRequest) doURL(method string, url string, host string, statusCod
}
this.writer.AddHeaders(resp.Header)
if statusCode <= 0 {
this.writer.Prepare(resp.ContentLength, resp.StatusCode)
this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true)
} else {
this.writer.Prepare(resp.ContentLength, statusCode)
this.writer.Prepare(resp, resp.ContentLength, statusCode, true)
}
// 设置响应代码

View File

@@ -17,6 +17,10 @@ import (
// 调用WAF
func (this *HTTPRequest) doWAFRequest() (blocked bool) {
if this.web.FirewallRef == nil || !this.web.FirewallRef.IsOn {
return
}
var remoteAddr = this.requestRemoteAddr(true)
// 检查是否为白名单直连
@@ -31,16 +35,16 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
}
// 是否在全局名单中
if !iplibrary.AllowIP(remoteAddr, this.Server.Id) {
if !iplibrary.AllowIP(remoteAddr, this.ReqServer.Id) {
this.disableLog = true
this.closeConn()
this.Close()
return true
}
// 检查是否在临时黑名单中
if waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeService, this.Server.Id, remoteAddr) || waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteAddr) {
if waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeService, this.ReqServer.Id, remoteAddr) || waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteAddr) {
this.disableLog = true
this.closeConn()
this.Close()
return true
}
@@ -57,8 +61,8 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
}
// 公用的防火墙设置
if this.Server.HTTPFirewallPolicy != nil && this.Server.HTTPFirewallPolicy.IsOn {
blocked, breakChecking := this.checkWAFRequest(this.Server.HTTPFirewallPolicy)
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
blocked, breakChecking := this.checkWAFRequest(this.ReqServer.HTTPFirewallPolicy)
if blocked {
return true
}
@@ -208,7 +212,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
}
// 添加统计
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Actions)
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, ruleSet.Actions)
}
this.firewallActions = append(ruleSet.ActionCodes(), firewallPolicy.Mode)
@@ -219,6 +223,10 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
// call response waf
func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) {
if this.web.FirewallRef == nil || !this.web.FirewallRef.IsOn {
return
}
// 当前服务的独立设置
if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
blocked := this.checkWAFResponse(this.web.FirewallPolicy, resp)
@@ -228,8 +236,8 @@ func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) {
}
// 公用的防火墙设置
if this.Server.HTTPFirewallPolicy != nil && this.Server.HTTPFirewallPolicy.IsOn {
blocked := this.checkWAFResponse(this.Server.HTTPFirewallPolicy, resp)
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
blocked := this.checkWAFResponse(this.ReqServer.HTTPFirewallPolicy, resp)
if blocked {
return true
}
@@ -266,7 +274,7 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi
}
// 添加统计
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Actions)
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, ruleSet.Actions)
}
this.firewallActions = append(ruleSet.ActionCodes(), firewallPolicy.Mode)
@@ -313,12 +321,12 @@ func (this *HTTPRequest) WAFRestoreBody(data []byte) {
// WAFServerId 服务ID
func (this *HTTPRequest) WAFServerId() int64 {
return this.Server.Id
return this.ReqServer.Id
}
// WAFClose 关闭连接
func (this *HTTPRequest) WAFClose() {
this.closeConn()
this.Close()
}
func (this *HTTPRequest) WAFOnAction(action interface{}) (goNext bool) {

File diff suppressed because it is too large Load Diff

View File

@@ -1,102 +0,0 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"bufio"
"github.com/iwind/TeaGo/types"
"net"
"net/http"
"time"
)
// HTTPRateWriter 限速写入
type HTTPRateWriter struct {
parentWriter http.ResponseWriter
rateBytes int
lastBytes int
timeCost time.Duration
}
func NewHTTPRateWriter(writer http.ResponseWriter, rateBytes int64) http.ResponseWriter {
return &HTTPRateWriter{
parentWriter: writer,
rateBytes: types.Int(rateBytes),
}
}
func (this *HTTPRateWriter) Header() http.Header {
return this.parentWriter.Header()
}
func (this *HTTPRateWriter) Write(data []byte) (int, error) {
if len(data) == 0 {
return 0, nil
}
var left = this.rateBytes - this.lastBytes
if left <= 0 {
if this.timeCost > 0 && this.timeCost < 1*time.Second {
time.Sleep(1*time.Second - this.timeCost)
}
this.lastBytes = 0
this.timeCost = 0
return this.Write(data)
}
var n = len(data)
// n <= left
if n <= left {
this.lastBytes += n
var before = time.Now()
defer func() {
this.timeCost += time.Since(before)
}()
return this.parentWriter.Write(data)
}
// n > left
var before = time.Now()
result, err := this.parentWriter.Write(data[:left])
this.timeCost += time.Since(before)
if err != nil {
return result, err
}
this.lastBytes += left
return this.Write(data[left:])
}
func (this *HTTPRateWriter) WriteHeader(statusCode int) {
this.parentWriter.WriteHeader(statusCode)
}
// Hijack Hijack
func (this *HTTPRateWriter) Hijack() (conn net.Conn, buf *bufio.ReadWriter, err error) {
if this.parentWriter == nil {
return
}
hijack, ok := this.parentWriter.(http.Hijacker)
if ok {
return hijack.Hijack()
}
return
}
// Flush Flush
func (this *HTTPRateWriter) Flush() {
if this.parentWriter == nil {
return
}
flusher, ok := this.parentWriter.(http.Flusher)
if ok {
flusher.Flush()
return
}
}

View File

@@ -61,7 +61,7 @@ func (this *Listener) listenTCP() error {
return err
}
var netListener = NewClientListener(tcpListener, protocol.IsHTTPFamily() || protocol.IsHTTPSFamily())
events.On(events.EventQuit, func() {
events.OnKey(events.EventQuit, this, func() {
remotelogs.Println("LISTENER", "quit "+this.group.FullAddr())
_ = netListener.Close()
})
@@ -122,7 +122,7 @@ func (this *Listener) listenUDP() error {
if err != nil {
return err
}
events.On(events.EventQuit, func() {
events.OnKey(events.EventQuit, this, func() {
remotelogs.Println("LISTENER", "quit "+this.group.FullAddr())
_ = listener.Close()
})
@@ -143,6 +143,8 @@ func (this *Listener) listenUDP() error {
}
func (this *Listener) Close() error {
events.Remove(this)
if this.listener == nil {
return nil
}

View File

@@ -25,8 +25,8 @@ func (this *BaseListener) Reset() {
}
// CountActiveListeners 获取当前活跃连接数
func (this *BaseListener) CountActiveListeners() int {
// CountActiveConnections 获取当前活跃连接数
func (this *BaseListener) CountActiveConnections() int {
return types.Int(this.countActiveConnections)
}

View File

@@ -208,8 +208,8 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
req := &HTTPRequest{
RawReq: rawReq,
RawWriter: rawWriter,
Server: server,
host: reqHost,
ReqServer: server,
ReqHost: reqHost,
ServerName: serverName,
ServerAddr: this.addr,
IsHTTP: this.isHTTP,

View File

@@ -16,6 +16,6 @@ type ListenerInterface interface {
// Reload 重载配置
Reload(serverGroup *serverconfigs.ServerAddressGroup)
// CountActiveListeners 获取当前活跃的连接数
CountActiveListeners() int
// CountActiveConnections 获取当前活跃的连接数
CountActiveConnections() int
}

View File

@@ -14,6 +14,7 @@ import (
"os/exec"
"regexp"
"runtime"
"sort"
"strings"
"sync"
"time"
@@ -29,6 +30,8 @@ type ListenerManager struct {
retryListenerMap map[string]*Listener // 需要重试的监听器 addr => Listener
ticker *time.Ticker
lastPortStrings string
}
// NewListenerManager 获取新对象
@@ -143,6 +146,9 @@ func (this *ListenerManager) Start(node *nodeconfigs.NodeConfig) error {
}
}
// 加入到firewalld
this.addToFirewalld(groupAddrs)
return nil
}
@@ -153,7 +159,7 @@ func (this *ListenerManager) TotalActiveConnections() int {
total := 0
for _, listener := range this.listenersMap {
total += listener.listener.CountActiveListeners()
total += listener.listener.CountActiveConnections()
}
return total
}
@@ -214,3 +220,63 @@ func (this *ListenerManager) findProcessNameWithPort(isUdp bool, port string) st
}
return ""
}
func (this *ListenerManager) addToFirewalld(groupAddrs []string) {
if !sharedNodeConfig.AutoOpenPorts {
return
}
// 组合端口号
var ports = []string{}
for _, addr := range groupAddrs {
var protocol = "tcp"
if strings.HasPrefix(addr, "udp") {
protocol = "udp"
}
var lastIndex = strings.LastIndex(addr, ":")
if lastIndex > 0 {
var portString = addr[lastIndex+1:]
ports = append(ports, portString+"/"+protocol)
}
}
if len(ports) == 0 {
return
}
// 检查是否有变化
sort.Strings(ports)
var newPortStrings = strings.Join(ports, ",")
if newPortStrings == this.lastPortStrings {
return
}
this.lastPortStrings = newPortStrings
firewallCmd, err := exec.LookPath("firewall-cmd")
if err != nil || len(firewallCmd) == 0 {
return
}
remotelogs.Println("FIREWALLD", "open ports automatically")
for _, port := range ports {
{
// TODO 需要支持sudo
var cmd = exec.Command(firewallCmd, "--add-port="+port, "--permanent")
err = cmd.Run()
if err != nil {
remotelogs.Warn("FIREWALLD", "'"+cmd.String()+"': "+err.Error())
return
}
}
{
// TODO 需要支持sudo
var cmd = exec.Command(firewallCmd, "--add-port="+port)
err = cmd.Run()
if err != nil {
remotelogs.Warn("FIREWALLD", "'"+cmd.String()+"': "+err.Error())
return
}
}
}
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/configs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/metrics"
@@ -55,12 +56,16 @@ type Node struct {
maxCPU int32
maxThreads int
timezone string
updatingServerMap map[int64]*serverconfigs.ServerConfig
}
func NewNode() *Node {
return &Node{
sock: gosock.NewTmpSock(teaconst.ProcessName),
maxThreads: -1,
sock: gosock.NewTmpSock(teaconst.ProcessName),
maxThreads: -1,
maxCPU: -1,
updatingServerMap: map[int64]*serverconfigs.ServerConfig{},
}
}
@@ -262,7 +267,7 @@ func (this *Node) loop() error {
defer tr.End()
// 检查api.yaml是否存在
apiConfigFile := Tea.ConfigFile("api.yaml")
var apiConfigFile = Tea.ConfigFile("api.yaml")
_, err := os.Stat(apiConfigFile)
if err != nil {
return nil
@@ -273,7 +278,7 @@ func (this *Node) loop() error {
return errors.New("create rpc client failed: " + err.Error())
}
nodeCtx := rpcClient.Context()
var nodeCtx = rpcClient.Context()
tasksResp, err := rpcClient.NodeTaskRPC().FindNodeTasks(nodeCtx, &pb.FindNodeTasksRequest{})
if err != nil {
return errors.New("read node tasks failed: " + err.Error())
@@ -293,11 +298,15 @@ func (this *Node) loop() error {
return err
}
case "configChanged":
if !task.IsPrimary {
// 我们等等主节点配置准备完毕
time.Sleep(2 * time.Second)
if task.ServerId > 0 {
err = this.syncServerConfig(task.ServerId)
} else {
if !task.IsPrimary {
// 我们等等主节点配置准备完毕
time.Sleep(2 * time.Second)
}
err = this.syncConfig(task.Version)
}
err := this.syncConfig(task.Version)
if err != nil {
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
NodeTaskId: task.Id,
@@ -314,6 +323,7 @@ func (this *Node) loop() error {
if err != nil {
return err
}
case "nodeVersionChanged":
goman.New(func() {
sharedUpgradeManager.Start()
@@ -417,22 +427,8 @@ func (this *Node) syncConfig(taskVersion int64) error {
remotelogs.Println("NODE", "loading config ...")
}
nodeconfigs.ResetNodeConfig(nodeConfig)
caches.SharedManager.MaxDiskCapacity = nodeConfig.MaxCacheDiskCapacity
caches.SharedManager.MaxMemoryCapacity = nodeConfig.MaxCacheMemoryCapacity
if len(nodeConfig.HTTPCachePolicies) > 0 {
caches.SharedManager.UpdatePolicies(nodeConfig.HTTPCachePolicies)
} else {
caches.SharedManager.UpdatePolicies([]*serverconfigs.HTTPCachePolicy{})
}
sharedWAFManager.UpdatePolicies(nodeConfig.FindAllFirewallPolicies())
iplibrary.SharedActionManager.UpdateActions(nodeConfig.FirewallActions)
sharedNodeConfig = nodeConfig
this.onReload(nodeConfig)
metrics.SharedManager.Update(nodeConfig.MetricItems)
// 发送事件
events.Notify(events.EventReload)
@@ -445,30 +441,96 @@ func (this *Node) syncConfig(taskVersion int64) error {
return nil
}
// 读取单个服务配置
func (this *Node) syncServerConfig(serverId int64) error {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
resp, err := rpcClient.ServerRPC().ComposeServerConfig(rpcClient.Context(), &pb.ComposeServerConfigRequest{ServerId: serverId})
if err != nil {
return err
}
this.locker.Lock()
defer this.locker.Unlock()
if len(resp.ServerConfigJSON) == 0 {
this.updatingServerMap[serverId] = nil
} else {
var config = &serverconfigs.ServerConfig{}
err = json.Unmarshal(resp.ServerConfigJSON, config)
if err != nil {
return err
}
this.updatingServerMap[serverId] = config
}
return nil
}
// 启动同步计时器
func (this *Node) startSyncTimer() {
// TODO 这个时间间隔可以自行设置
ticker := time.NewTicker(60 * time.Second)
events.On(events.EventQuit, func() {
var taskTicker = time.NewTicker(60 * time.Second)
var serverChangeTicker = time.NewTicker(5 * time.Second)
events.OnKey(events.EventQuit, this, func() {
remotelogs.Println("NODE", "quit sync timer")
ticker.Stop()
taskTicker.Stop()
serverChangeTicker.Stop()
})
goman.New(func() {
for {
select {
case <-ticker.C:
case <-taskTicker.C: // 定期执行
err := this.loop()
if err != nil {
remotelogs.Error("NODE", "sync config error: "+err.Error())
continue
}
case <-nodeTaskNotify:
case <-serverChangeTicker.C: // 服务变化
this.locker.Lock()
if len(this.updatingServerMap) > 0 {
var updatingServerMap = this.updatingServerMap
this.updatingServerMap = map[int64]*serverconfigs.ServerConfig{}
newNodeConfig, err := nodeconfigs.CloneNodeConfig(sharedNodeConfig)
if err != nil {
remotelogs.Error("NODE", "apply server config error: "+err.Error())
continue
}
for serverId, serverConfig := range updatingServerMap {
if serverConfig != nil {
newNodeConfig.AddServer(serverConfig)
} else {
newNodeConfig.RemoveServer(serverId)
}
}
err, serverErrors := newNodeConfig.Init()
if err != nil {
remotelogs.Error("NODE", "apply server config error: "+err.Error())
continue
}
if len(serverErrors) > 0 {
for _, serverErr := range serverErrors {
remotelogs.ServerError(serverErr.Id, "NODE", serverErr.Message, nodeconfigs.NodeLogTypeServerConfigInitFailed, maps.Map{})
}
}
this.onReload(newNodeConfig)
err = sharedListenerManager.Start(newNodeConfig)
if err != nil {
remotelogs.Error("NODE", "apply server config error: "+err.Error())
}
}
this.locker.Unlock()
case <-nodeTaskNotify: // 有新的更新任务
err := this.loop()
if err != nil {
remotelogs.Error("NODE", "sync config error: "+err.Error())
continue
}
case <-nodeConfigChangedNotify:
case <-nodeConfigChangedNotify: // 节点变化通知
err := this.syncConfig(0)
if err != nil {
remotelogs.Error("NODE", "sync config error: "+err.Error())
@@ -635,6 +697,51 @@ func (this *Node) listenSock() error {
"limiter": sharedConnectionsLimiter.Len(),
},
})
case "dropIP":
var m = maps.NewMap(cmd.Params)
var ip = m.GetString("ip")
var timeSeconds = m.GetInt("timeoutSeconds")
err := firewalls.Firewall().DropSourceIP(ip, timeSeconds)
if err != nil {
_ = cmd.Reply(&gosock.Command{
Params: map[string]interface{}{
"error": err.Error(),
},
})
} else {
_ = cmd.ReplyOk()
}
case "rejectIP":
var m = maps.NewMap(cmd.Params)
var ip = m.GetString("ip")
var timeSeconds = m.GetInt("timeoutSeconds")
err := firewalls.Firewall().RejectSourceIP(ip, timeSeconds)
if err != nil {
_ = cmd.Reply(&gosock.Command{
Params: map[string]interface{}{
"error": err.Error(),
},
})
} else {
_ = cmd.ReplyOk()
}
case "removeIP":
var m = maps.NewMap(cmd.Params)
var ip = m.GetString("ip")
err := firewalls.Firewall().RemoveSourceIP(ip)
if err != nil {
_ = cmd.Reply(&gosock.Command{
Params: map[string]interface{}{
"error": err.Error(),
},
})
} else {
_ = cmd.ReplyOk()
}
case "gc":
runtime.GC()
debug.FreeOSMemory()
_ = cmd.ReplyOk()
}
})
@@ -644,7 +751,7 @@ func (this *Node) listenSock() error {
}
})
events.On(events.EventQuit, func() {
events.OnKey(events.EventQuit, this, func() {
logs.Println("NODE", "quit unix sock")
_ = this.sock.Close()
})
@@ -654,14 +761,34 @@ func (this *Node) listenSock() error {
// 重载配置调用
func (this *Node) onReload(config *nodeconfigs.NodeConfig) {
nodeconfigs.ResetNodeConfig(config)
sharedNodeConfig = config
// 缓存策略
caches.SharedManager.MaxDiskCapacity = config.MaxCacheDiskCapacity
caches.SharedManager.MaxMemoryCapacity = config.MaxCacheMemoryCapacity
if len(config.HTTPCachePolicies) > 0 {
caches.SharedManager.UpdatePolicies(config.HTTPCachePolicies)
} else {
caches.SharedManager.UpdatePolicies([]*serverconfigs.HTTPCachePolicy{})
}
// WAF策略
sharedWAFManager.UpdatePolicies(config.FindAllFirewallPolicies())
iplibrary.SharedActionManager.UpdateActions(config.FirewallActions)
// 统计指标
metrics.SharedManager.Update(config.MetricItems)
// max cpu
if config.MaxCPU != this.maxCPU {
if config.MaxCPU > 0 && config.MaxCPU < int32(runtime.NumCPU()) {
runtime.GOMAXPROCS(int(config.MaxCPU))
remotelogs.Println("NODE", "[CPU]set max cpu to '"+types.String(config.MaxCPU)+"'")
} else {
runtime.GOMAXPROCS(runtime.NumCPU())
remotelogs.Println("NODE", "[CPU]set max cpu to '"+types.String(runtime.NumCPU())+"'")
var threads = runtime.NumCPU() * 4
runtime.GOMAXPROCS(threads)
remotelogs.Println("NODE", "[CPU]set max cpu to '"+types.String(threads)+"'")
}
this.maxCPU = config.MaxCPU
@@ -707,4 +834,9 @@ func (this *Node) onReload(config *nodeconfigs.NodeConfig) {
time.Local = location
this.timezone = timeZone
}
// product information
if config.ProductConfig != nil {
teaconst.GlobalProductName = config.ProductConfig.Name
}
}

View File

@@ -41,9 +41,9 @@ func (this *NodeStatusExecutor) Listen() {
this.update()
// TODO 这个时间间隔可以配置
ticker := time.NewTicker(30 * time.Second)
var ticker = time.NewTicker(30 * time.Second)
events.On(events.EventQuit, func() {
events.OnKey(events.EventQuit, this, func() {
remotelogs.Println("NODE_STATUS", "quit executor")
ticker.Stop()
})

View File

@@ -21,6 +21,9 @@ func init() {
SharedOriginStateManager.Start()
})
})
events.On(events.EventQuit, func() {
SharedOriginStateManager.Stop()
})
}
// OriginStateManager 源站状态管理
@@ -41,7 +44,7 @@ func NewOriginStateManager() *OriginStateManager {
// Start 启动
func (this *OriginStateManager) Start() {
events.On(events.EventReload, func() {
events.OnKey(events.EventReload, this, func() {
this.locker.Lock()
this.stateMap = map[int64]*OriginState{}
this.locker.Unlock()
@@ -58,6 +61,12 @@ func (this *OriginStateManager) Start() {
}
}
func (this *OriginStateManager) Stop() {
if this.ticker != nil {
this.ticker.Stop()
}
}
// Loop 单次循环检查
func (this *OriginStateManager) Loop() error {
if sharedNodeConfig == nil {

View File

@@ -44,10 +44,22 @@ func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.C
// TODO 支持TCP4/TCP6
// TODO 支持指定特定网卡
// TODO Addr支持端口范围如果有多个端口时随机一个端口使用
// TODO 支持使用证书
conn, err = tls.DialWithDialer(&dialer, "tcp", origin.Addr.Host+":"+origin.Addr.PortRange, &tls.Config{
var tlsConfig = &tls.Config{
InsecureSkipVerify: true,
})
}
if origin.Cert != nil {
var obj = origin.Cert.CertObject()
if obj != nil {
tlsConfig.InsecureSkipVerify = false
tlsConfig.Certificates = []tls.Certificate{*obj}
if len(origin.Cert.ServerName) > 0 {
tlsConfig.ServerName = origin.Cert.ServerName
}
}
}
conn, err = tls.DialWithDialer(&dialer, "tcp", origin.Addr.Host+":"+origin.Addr.PortRange, tlsConfig)
}
// TODO 需要在合适的时机删除TOA记录
@@ -69,10 +81,22 @@ func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.C
// TODO 支持TCP4/TCP6
// TODO 支持指定特定网卡
// TODO Addr支持端口范围如果有多个端口时随机一个端口使用
// TODO 支持使用证书
return tls.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, &tls.Config{
var tlsConfig = &tls.Config{
InsecureSkipVerify: true,
})
}
if origin.Cert != nil {
var obj = origin.Cert.CertObject()
if obj != nil {
tlsConfig.InsecureSkipVerify = false
tlsConfig.Certificates = []tls.Certificate{*obj}
if len(origin.Cert.ServerName) > 0 {
tlsConfig.ServerName = origin.Cert.ServerName
}
}
}
return tls.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, tlsConfig)
case serverconfigs.ProtocolUDP:
addr, err := net.ResolveUDPAddr("udp", origin.Addr.Host+":"+origin.Addr.PortRange)
if err != nil {

View File

@@ -28,7 +28,7 @@ func init() {
})
}
// 系统服务管理
// SystemServiceManager 系统服务管理
type SystemServiceManager struct {
}

View File

@@ -7,7 +7,6 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/configs"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/trackers"
"github.com/iwind/TeaGo/Tea"
@@ -21,17 +20,22 @@ import (
"time"
)
var sharedSyncAPINodesTask = NewSyncAPINodesTask()
func init() {
events.On(events.EventStart, func() {
task := NewSyncAPINodesTask()
goman.New(func() {
task.Start()
sharedSyncAPINodesTask.Start()
})
})
events.On(events.EventQuit, func() {
sharedSyncAPINodesTask.Stop()
})
}
// SyncAPINodesTask API节点同步任务
type SyncAPINodesTask struct {
ticker *time.Ticker
}
func NewSyncAPINodesTask() *SyncAPINodesTask {
@@ -39,16 +43,12 @@ func NewSyncAPINodesTask() *SyncAPINodesTask {
}
func (this *SyncAPINodesTask) Start() {
ticker := time.NewTicker(5 * time.Minute)
this.ticker = time.NewTicker(5 * time.Minute)
if Tea.IsTesting() {
// 快速测试
ticker = time.NewTicker(1 * time.Minute)
this.ticker = time.NewTicker(1 * time.Minute)
}
events.On(events.EventQuit, func() {
remotelogs.Println("SYNC_API_NODES_TASK", "quit task")
ticker.Stop()
})
for range ticker.C {
for range this.ticker.C {
err := this.Loop()
if err != nil {
logs.Println("[TASK][SYNC_API_NODES_TASK]" + err.Error())
@@ -56,6 +56,12 @@ func (this *SyncAPINodesTask) Start() {
}
}
func (this *SyncAPINodesTask) Stop() {
if this.ticker != nil {
this.ticker.Stop()
}
}
func (this *SyncAPINodesTask) Loop() error {
var tr = trackers.Begin("SYNC_API_NODES")
defer tr.End()

View File

@@ -32,14 +32,13 @@ func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallP
m := map[int64]*waf.WAF{}
for _, p := range policies {
w, err := this.convertWAF(p)
if w != nil {
m[p.Id] = w
}
if err != nil {
remotelogs.Error("WAF", "initialize policy '"+strconv.FormatInt(p.Id, 10)+"' failed: "+err.Error())
continue
}
if w == nil {
continue
}
m[p.Id] = w
}
this.mapping = m
}
@@ -61,10 +60,12 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
policy.Mode = firewallconfigs.FirewallModeDefend
}
w := &waf.WAF{
Id: policy.Id,
IsOn: policy.IsOn,
Name: policy.Name,
Mode: policy.Mode,
Id: policy.Id,
IsOn: policy.IsOn,
Name: policy.Name,
Mode: policy.Mode,
UseLocalFirewall: policy.UseLocalFirewall,
SYNFlood: policy.SYNFlood,
}
// inbound
@@ -181,9 +182,9 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
}
}
err := w.Init()
if err != nil {
return nil, err
errorList := w.Init()
if len(errorList) > 0 {
return w, errorList[0]
}
return w, nil

254
internal/re/regexp.go Normal file
View File

@@ -0,0 +1,254 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package re
import (
"regexp"
"strings"
)
var prefixReg = regexp.MustCompile(`^\(\?([\w\s]+)\)`) // (?x)
var prefixReg2 = regexp.MustCompile(`^\(\?([\w\s]*:)`) // (?x: ...
var braceZero = regexp.MustCompile(`^{\s*0*\s*}`) // {0}
var braceZero2 = regexp.MustCompile(`^{\s*0*\s*,`) // {0, x}
type Regexp struct {
exp string
rawRegexp *regexp.Regexp
isStrict bool
isCaseInsensitive bool
keywords []string
keywordsMap RuneMap
}
func MustCompile(exp string) *Regexp {
var reg = &Regexp{
exp: exp,
rawRegexp: regexp.MustCompile(exp),
}
reg.init()
return reg
}
func Compile(exp string) (*Regexp, error) {
reg, err := regexp.Compile(exp)
if err != nil {
return nil, err
}
return NewRegexp(reg), nil
}
func NewRegexp(rawRegexp *regexp.Regexp) *Regexp {
var reg = &Regexp{
exp: rawRegexp.String(),
rawRegexp: rawRegexp,
}
reg.init()
return reg
}
func (this *Regexp) init() {
if len(this.exp) == 0 {
return
}
//var keywords = []string{}
var exp = strings.TrimSpace(this.exp)
// 去掉前面的(?...)
if prefixReg.MatchString(exp) {
var matches = prefixReg.FindStringSubmatch(exp)
var modifiers = matches[1]
if strings.Contains(modifiers, "i") {
this.isCaseInsensitive = true
}
exp = exp[len(matches[0]):]
}
var keywords = this.ParseKeywords(exp)
this.keywords = keywords
if len(keywords) > 0 {
this.keywordsMap = NewRuneTree(keywords)
}
}
func (this *Regexp) Keywords() []string {
return this.keywords
}
func (this *Regexp) Raw() *regexp.Regexp {
return this.rawRegexp
}
func (this *Regexp) IsCaseInsensitive() bool {
return this.isCaseInsensitive
}
func (this *Regexp) MatchString(s string) bool {
if this.keywordsMap != nil {
var b = this.keywordsMap.Lookup(s, this.isCaseInsensitive)
if !b {
return false
}
if this.isStrict {
return true
}
}
return this.rawRegexp.MatchString(s)
}
func (this *Regexp) Match(s []byte) bool {
if this.keywordsMap != nil {
var b = this.keywordsMap.Lookup(string(s), this.isCaseInsensitive)
if !b {
return false
}
if this.isStrict {
return true
}
}
return this.rawRegexp.Match(s)
}
// ParseKeywords 提取表达式中的关键词
// TODO 支持嵌套,类似于 A(abc|bcd)
// TODO 支持 (?:xxx)
// TODO 支持 abc)(bcd)(efg)
func (this *Regexp) ParseKeywords(exp string) []string {
var keywords = []string{}
if len(exp) == 0 {
return nil
}
var runes = []rune(exp)
// (a|b|c)
reg, err := regexp.Compile(exp)
if err == nil {
var countSub = reg.NumSubexp()
if countSub == 1 {
beginIndex := this.indexOfSymbol(runes, '(')
if beginIndex >= 0 {
runes = runes[beginIndex+1:]
symbolIndex := this.indexOfSymbol(runes, ')')
if symbolIndex > 0 && this.isPlain(runes[symbolIndex+1:]) {
runes = runes[:symbolIndex]
if len(runes) == 0 {
return nil
}
}
}
}
}
var lastIndex = 0
for index, r := range runes {
if r == '|' {
if index > 0 && runes[index-1] != '\\' {
var ks = this.parseKeyword(runes[lastIndex:index])
if len(ks) > 0 {
keywords = append(keywords, string(ks))
} else {
return nil
}
lastIndex = index + 1
}
}
}
if lastIndex == 0 {
var ks = this.parseKeyword(runes)
if len(ks) > 0 {
keywords = append(keywords, string(ks))
} else {
return nil
}
} else if lastIndex > 0 {
var ks = this.parseKeyword(runes[lastIndex:])
if len(ks) > 0 {
keywords = append(keywords, string(ks))
} else {
return nil
}
}
return keywords
}
func (this *Regexp) parseKeyword(keyword []rune) (result []rune) {
if len(keyword) == 0 {
return
}
// remove first \b
for index, r := range keyword {
if r == '\b' {
keyword = keyword[index+1:]
break
} else if r != '\t' && r != '\r' && r != '\n' && r != ' ' {
break
}
}
if len(keyword) == 0 {
return
}
for index, r := range keyword {
if index == 0 && r == '^' {
continue
}
if r == '(' || r == ')' {
if index == 0 {
return nil
}
if keyword[index-1] != '\\' {
return nil
}
}
if r == '[' || r == '{' || r == '.' || r == '+' || r == '$' {
if index == 0 {
return nil
}
if keyword[index-1] != '\\' {
if r == '{' && (braceZero.MatchString(string(keyword[index:])) || braceZero2.MatchString(string(keyword[index:]))) { // r {0, ...}
return result[:len(result)-1]
}
return
}
}
if r == '?' || r == '*' {
if index == 0 {
return nil
}
return result[:len(result)-1]
}
if r == '\\' || r == '\b' {
// TODO 将来更精细的处理 \d, \s, \$等
break
}
result = append(result, r)
}
return
}
// 查找符号位置
func (this *Regexp) indexOfSymbol(runes []rune, symbol rune) int {
for index, c := range runes {
if c == symbol && (index == 0 || runes[index-1] != '\\') {
return index
}
}
return -1
}
// 是否可视为为普通字符
func (this *Regexp) isPlain(runes []rune) bool {
for _, r := range []rune{'|', '(', ')'} {
if this.indexOfSymbol(runes, r) >= 0 {
return false
}
}
return true
}

120
internal/re/regexp_test.go Normal file
View File

@@ -0,0 +1,120 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package re_test
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/re"
"github.com/iwind/TeaGo/assert"
"regexp"
"testing"
)
func TestRegexp(t *testing.T) {
for _, s := range []string{"(?i)(abc|efg)", "abc|efg", "abc(.+)"} {
var reg = regexp.MustCompile(s)
t.Log("===" + s + "===")
t.Log(reg.LiteralPrefix())
t.Log(reg.NumSubexp())
t.Log(reg.SubexpNames())
}
}
func TestRegexp_MatchString(t *testing.T) {
var a = assert.NewAssertion(t)
{
var r = re.MustCompile("abc")
a.IsTrue(r.MatchString("abc"))
a.IsFalse(r.MatchString("ab"))
}
{
var r = re.MustCompile("(?i)abc|def|ghi")
a.IsTrue(r.MatchString("DEF"))
a.IsFalse(r.MatchString("ab"))
}
}
func TestRegexp_Sub(t *testing.T) {
{
reg := regexp.MustCompile(`(a|b|c)(e|f|g)`)
for _, subName := range reg.SubexpNames() {
t.Log(subName)
}
}
}
func TestRegexp_ParseKeywords(t *testing.T) {
var a = assert.NewAssertion(t)
var r = re.MustCompile("")
a.IsTrue(testCompareStrings(r.ParseKeywords("(abc)def"), []string{"abc"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("(abc)|(?:def)"), []string{}))
a.IsTrue(testCompareStrings(r.ParseKeywords("(abc)|def"), []string{}))
a.IsTrue(testCompareStrings(r.ParseKeywords("(abc)"), []string{"abc"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("(?i:abc)"), []string{}))
a.IsTrue(testCompareStrings(r.ParseKeywords("\babc"), []string{"abc"}))
a.IsTrue(testCompareStrings(r.ParseKeywords(" \babc"), []string{"abc"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("\babc\b"), []string{"abc"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("\b(abc)"), []string{"abc"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("abc"), []string{"abc"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("abc|efg|hij"), []string{"abc", "efg", "hij"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("abc\\|efg|hij"), []string{"abc", "hij"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("abc\\|efg*|hij"), []string{"abc", "hij"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("abc\\|efg?|hij"), []string{"abc", "hij"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("abc\\|efg+|hij"), []string{"abc", "hij"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("abc\\|efg{2,10}|hij"), []string{"abc", "hij"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("abc\\|efg{0,10}|hij"), []string{"abc", "hij"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("abc\\|efg.+|hij"), []string{"abc", "hij"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("A(abc|bcd)"), []string{"abc", "bcd"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("^abc"), []string{"abc"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("abc$"), []string{"abc"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("abc\\$"), []string{"abc"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("abc\\d"), []string{"abc"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("abc{0,4}"), []string{"ab"}))
a.IsTrue(testCompareStrings(r.ParseKeywords("{0,4}"), []string{}))
a.IsTrue(testCompareStrings(r.ParseKeywords("{1,4}"), []string{}))
a.IsTrue(testCompareStrings(r.ParseKeywords("中文|北京|上海|golang"), []string{"中文", "北京", "上海", "golang"}))
}
func TestRegexp_ParseKeywords2(t *testing.T) {
var r = re.MustCompile("")
var policy = firewallconfigs.HTTPFirewallTemplate()
for _, group := range policy.Inbound.Groups {
for _, set := range group.Sets {
for _, rule := range set.Rules {
if rule.Operator == firewallconfigs.HTTPFirewallRuleOperatorMatch || rule.Operator == firewallconfigs.HTTPFirewallRuleOperatorNotMatch {
t.Log(set.Name+":", rule.Value, "=>", r.ParseKeywords(rule.Value))
}
}
}
}
}
func BenchmarkRegexp_MatchString(b *testing.B) {
var r = re.MustCompile("(?i)abc|def|ghi")
for i := 0; i < b.N; i++ {
r.MatchString("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36")
}
}
func BenchmarkRegexp_MatchString2(b *testing.B) {
var r = regexp.MustCompile("(?i)abc|def|ghi")
for i := 0; i < b.N; i++ {
r.MatchString("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36")
}
}
func testCompareStrings(s1 []string, s2 []string) bool {
if len(s1) != len(s2) {
return false
}
for index, s := range s1 {
if s != s2[index] {
return false
}
}
return true
}

74
internal/re/rune_tree.go Normal file
View File

@@ -0,0 +1,74 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package re
type RuneMap map[rune]*RuneTree
func (this *RuneMap) Lookup(s string, caseInsensitive bool) bool {
return this.lookup([]rune(s), caseInsensitive, 0)
}
func (this RuneMap) lookup(runes []rune, caseInsensitive bool, depth int) bool {
if len(runes) == 0 {
return false
}
for i, r := range runes {
tree, ok := this[r]
if !ok {
if caseInsensitive {
if r >= 'a' && r <= 'z' {
r -= 32
tree, ok = this[r]
} else if r >= 'A' && r <= 'Z' {
r += 32
tree, ok = this[r]
}
}
if !ok {
if depth > 0 {
return false
}
continue
}
}
if tree.IsEnd {
return true
}
b := tree.Children.lookup(runes[i+1:], caseInsensitive, depth+1)
if b {
return true
}
}
return false
}
type RuneTree struct {
Children RuneMap
IsEnd bool
}
func NewRuneTree(list []string) RuneMap {
var rootMap = RuneMap{}
for _, s := range list {
if len(s) == 0 {
continue
}
var lastMap = rootMap
var runes = []rune(s)
for index, r := range runes {
tree, ok := lastMap[r]
if !ok {
tree = &RuneTree{
Children: RuneMap{},
}
lastMap[r] = tree
}
if index == len(runes)-1 {
tree.IsEnd = true
}
lastMap = tree.Children
}
}
return rootMap
}

View File

@@ -0,0 +1,47 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package re_test
import (
"github.com/TeaOSLab/EdgeNode/internal/re"
"github.com/iwind/TeaGo/assert"
"regexp"
"testing"
)
func TestNewRuneTree(t *testing.T) {
var a = assert.NewAssertion(t)
var tree = re.NewRuneTree([]string{"abc", "abd", "def", "GHI", "中国", "@"})
a.IsTrue(tree.Lookup("ABC", true))
a.IsTrue(tree.Lookup("ABC1", true))
a.IsTrue(tree.Lookup("1ABC", true))
a.IsTrue(tree.Lookup("def", true))
a.IsTrue(tree.Lookup("ghI", true))
a.IsFalse(tree.Lookup("d ef", true))
a.IsFalse(tree.Lookup("de", true))
a.IsFalse(tree.Lookup("de f", true))
a.IsTrue(tree.Lookup("我是中国人", true))
a.IsTrue(tree.Lookup("iwind.liu@gmail.com", true))
}
func BenchmarkRuneMap_Lookup(b *testing.B) {
var tree = re.NewRuneTree([]string{"abc", "abd", "def", "ghi", "中国"})
for i := 0; i < b.N; i++ {
tree.Lookup("我来自中国", true)
}
}
func BenchmarkRuneMap_Lookup2_NOT_FOUND(b *testing.B) {
var tree = re.NewRuneTree([]string{"abc", "abd", "cde", "GHI"})
for i := 0; i < b.N; i++ {
tree.Lookup("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36", true)
}
}
func BenchmarkRune_Regexp_FOUND(b *testing.B) {
var reg = regexp.MustCompile("(?i)abc|abd|cde|GHI")
for i := 0; i < b.N; i++ {
reg.MatchString("HELLO WORLD ABC 123 456 abc HELLO WORLD HELLO WORLD ABC 123 456 abc HELLO WORLD HELLO WORLD ABC 123 456 abc HELLO WORLD")
}
}

View File

@@ -173,6 +173,39 @@ func ServerSuccess(serverId int64, tag string, description string, logType nodec
}
}
// ServerLog 打印服务相关日志信息
func ServerLog(serverId int64, tag string, description string, logType nodeconfigs.NodeLogType, params maps.Map) {
logs.Println("[" + tag + "]" + description)
// 参数
var paramsJSON []byte
if len(params) > 0 {
p, err := json.Marshal(params)
if err != nil {
logs.Println("[LOG]" + err.Error())
} else {
paramsJSON = p
}
}
select {
case logChan <- &pb.NodeLog{
Role: teaconst.Role,
Tag: tag,
Description: description,
Level: "info",
NodeId: teaconst.NodeId,
ServerId: serverId,
CreatedAt: time.Now().Unix(),
Type: logType,
ParamsJSON: paramsJSON,
}:
default:
}
}
// 上传日志
func uploadLogs() error {
logList := []*pb.NodeLog{}

View File

@@ -81,6 +81,14 @@ func (this *RPCClient) RegionProvinceRPC() pb.RegionProvinceServiceClient {
return pb.NewRegionProvinceServiceClient(this.pickConn())
}
func (this *RPCClient) RegionCityRPC() pb.RegionCityServiceClient {
return pb.NewRegionCityServiceClient(this.pickConn())
}
func (this *RPCClient) RegionProviderRPC() pb.RegionProviderServiceClient {
return pb.NewRegionProviderServiceClient(this.pickConn())
}
func (this *RPCClient) IPListRPC() pb.IPListServiceClient {
return pb.NewIPListServiceClient(this.pickConn())
}

View File

@@ -10,12 +10,12 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/trackers"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
"github.com/mssola/user_agent"
"strconv"
"strings"
"time"
@@ -64,25 +64,26 @@ func NewHTTPRequestStatManager() *HTTPRequestStatManager {
// Start 启动
func (this *HTTPRequestStatManager) Start() {
// 上传请求总数
var monitorTicker = time.NewTicker(1 * time.Minute)
events.OnKey(events.EventQuit, this, func() {
monitorTicker.Stop()
})
goman.New(func() {
ticker := time.NewTicker(1 * time.Minute)
goman.New(func() {
for range ticker.C {
if this.totalAttackRequests > 0 {
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemAttackRequests, maps.Map{"total": this.totalAttackRequests})
this.totalAttackRequests = 0
}
for range monitorTicker.C {
if this.totalAttackRequests > 0 {
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemAttackRequests, maps.Map{"total": this.totalAttackRequests})
this.totalAttackRequests = 0
}
})
}
})
loopTicker := time.NewTicker(1 * time.Second)
uploadTicker := time.NewTicker(30 * time.Minute)
var loopTicker = time.NewTicker(1 * time.Second)
var uploadTicker = time.NewTicker(30 * time.Minute)
if Tea.IsTesting() {
uploadTicker = time.NewTicker(10 * time.Second) // 在测试环境下缩短Ticker时间以方便我们调试
}
remotelogs.Println("HTTP_REQUEST_STAT_MANAGER", "start ...")
events.On(events.EventQuit, func() {
events.OnKey(events.EventQuit, this, func() {
remotelogs.Println("HTTP_REQUEST_STAT_MANAGER", "quit")
loopTicker.Stop()
uploadTicker.Stop()
@@ -177,7 +178,6 @@ func (this *HTTPRequestStatManager) AddFirewallRuleGroupId(serverId int64, firew
// Loop 单个循环
func (this *HTTPRequestStatManager) Loop() error {
timeout := time.NewTimer(10 * time.Minute) // 执行的最大时间
userAgentParser := &user_agent.UserAgent{}
Loop:
for {
select {
@@ -223,8 +223,8 @@ Loop:
serverId := userAgentString[:atIndex]
userAgent := userAgentString[atIndex+1:]
userAgentParser.Parse(userAgent)
osInfo := userAgentParser.OSInfo()
var result = SharedUserAgentParser.Parse(userAgent)
var osInfo = result.OS
if len(osInfo.Name) > 0 {
dotIndex := strings.Index(osInfo.Version, ".")
if dotIndex > -1 {
@@ -233,7 +233,7 @@ Loop:
this.systemMap[serverId+"@"+osInfo.Name+"@"+osInfo.Version]++
}
browser, browserVersion := userAgentParser.Browser()
var browser, browserVersion = result.BrowserName, result.BrowserVersion
if len(browser) > 0 {
dotIndex := strings.Index(browserVersion, ".")
if dotIndex > -1 {
@@ -320,6 +320,15 @@ func (this *HTTPRequestStatManager) Upload() error {
})
}
// 重置数据
// 这里需要放到上传数据之前,防止因上传失败而导致统计数据堆积
this.cityMap = map[string]*StatItem{}
this.providerMap = map[string]int64{}
this.systemMap = map[string]int64{}
this.browserMap = map[string]int64{}
this.dailyFirewallRuleGroupMap = map[string]int64{}
// 上传数据
_, err = rpcClient.ServerRPC().UploadServerHTTPRequestStat(rpcClient.Context(), &pb.UploadServerHTTPRequestStatRequest{
Month: timeutil.Format("Ym"),
Day: timeutil.Format("Ymd"),
@@ -330,14 +339,30 @@ func (this *HTTPRequestStatManager) Upload() error {
HttpFirewallRuleGroups: pbFirewallRuleGroups,
})
if err != nil {
return err
// 是否包含了invalid UTF-8
if strings.Contains(err.Error(), "string field contains invalid UTF-8") {
for _, system := range pbSystems {
system.Name = utils.ToValidUTF8string(system.Name)
}
for _, browser := range pbBrowsers {
browser.Name = utils.ToValidUTF8string(browser.Name)
}
// 再次尝试
_, err = rpcClient.ServerRPC().UploadServerHTTPRequestStat(rpcClient.Context(), &pb.UploadServerHTTPRequestStatRequest{
Month: timeutil.Format("Ym"),
Day: timeutil.Format("Ymd"),
RegionCities: pbCities,
RegionProviders: pbProviders,
Systems: pbSystems,
Browsers: pbBrowsers,
HttpFirewallRuleGroups: pbFirewallRuleGroups,
})
if err != nil {
return err
}
}
}
// 重置数据
this.cityMap = map[string]*StatItem{}
this.providerMap = map[string]int64{}
this.systemMap = map[string]int64{}
this.browserMap = map[string]int64{}
this.dailyFirewallRuleGroupMap = map[string]int64{}
return nil
}

View File

@@ -56,16 +56,17 @@ func (this *TrafficStatManager) Start(configFunc func() *nodeconfigs.NodeConfig)
this.configFunc = configFunc
// 上传请求总数
var monitorTicker = time.NewTicker(1 * time.Minute)
events.OnKey(events.EventQuit, this, func() {
monitorTicker.Stop()
})
goman.New(func() {
ticker := time.NewTicker(1 * time.Minute)
goman.New(func() {
for range ticker.C {
if this.totalRequests > 0 {
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemRequests, maps.Map{"total": this.totalRequests})
this.totalRequests = 0
}
for range monitorTicker.C {
if this.totalRequests > 0 {
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemRequests, maps.Map{"total": this.totalRequests})
this.totalRequests = 0
}
})
}
})
// 上传统计数据
@@ -74,8 +75,8 @@ func (this *TrafficStatManager) Start(configFunc func() *nodeconfigs.NodeConfig)
// 测试环境缩短上传时间,方便我们调试
duration = 30 * time.Second
}
ticker := time.NewTicker(duration)
events.On(events.EventQuit, func() {
var ticker = time.NewTicker(duration)
events.OnKey(events.EventQuit, this, func() {
remotelogs.Println("TRAFFIC_STAT_MANAGER", "quit")
ticker.Stop()
})
@@ -100,8 +101,7 @@ func (this *TrafficStatManager) Add(serverId int64, domain string, bytes int64,
this.totalRequests++
timestamp := utils.UnixTime() / 300 * 300
timestamp := utils.FloorUnixTime(300)
key := strconv.FormatInt(timestamp, 10) + strconv.FormatInt(serverId, 10)
this.locker.Lock()

View File

@@ -0,0 +1,92 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package stats
import (
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/mssola/user_agent"
"sync"
)
var SharedUserAgentParser = NewUserAgentParser()
// UserAgentParser UserAgent解析器
type UserAgentParser struct {
parser *user_agent.UserAgent
cacheMap1 map[string]UserAgentParserResult
cacheMap2 map[string]UserAgentParserResult
maxCacheItems int
cacheCursor int
locker sync.RWMutex
}
func NewUserAgentParser() *UserAgentParser {
var parser = &UserAgentParser{
parser: &user_agent.UserAgent{},
cacheMap1: map[string]UserAgentParserResult{},
cacheMap2: map[string]UserAgentParserResult{},
cacheCursor: 0,
}
parser.init()
return parser
}
func (this *UserAgentParser) init() {
var maxCacheItems = 10_000
var systemMemory = utils.SystemMemoryGB()
if systemMemory >= 16 {
maxCacheItems = 40_000
} else if systemMemory >= 8 {
maxCacheItems = 30_000
} else if systemMemory >= 4 {
maxCacheItems = 20_000
}
this.maxCacheItems = maxCacheItems
}
func (this *UserAgentParser) Parse(userAgent string) (result UserAgentParserResult) {
// 限制长度
if len(userAgent) == 0 || len(userAgent) > 256 {
return
}
this.locker.RLock()
cacheResult, ok := this.cacheMap1[userAgent]
if ok {
this.locker.RUnlock()
return cacheResult
}
cacheResult, ok = this.cacheMap2[userAgent]
if ok {
this.locker.RUnlock()
return cacheResult
}
this.locker.RUnlock()
this.locker.Lock()
this.parser.Parse(userAgent)
result.OS = this.parser.OSInfo()
result.BrowserName, result.BrowserVersion = this.parser.Browser()
result.IsMobile = this.parser.Mobile()
if this.cacheCursor == 0 {
this.cacheMap1[userAgent] = result
if len(this.cacheMap1) >= this.maxCacheItems {
this.cacheCursor = 1
this.cacheMap2 = map[string]UserAgentParserResult{}
}
} else {
this.cacheMap2[userAgent] = result
if len(this.cacheMap2) >= this.maxCacheItems {
this.cacheCursor = 0
this.cacheMap1 = map[string]UserAgentParserResult{}
}
}
this.locker.Unlock()
return
}

View File

@@ -0,0 +1,12 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package stats
import "github.com/mssola/user_agent"
type UserAgentParserResult struct {
OS user_agent.OSInfo
BrowserName string
BrowserVersion string
IsMobile bool
}

View File

@@ -0,0 +1,53 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package stats
import (
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
"runtime"
"runtime/debug"
"testing"
)
func TestUserAgentParser_Parse(t *testing.T) {
var parser = NewUserAgentParser()
for i := 0; i < 4; i ++ {
t.Log(parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/1"))
t.Log(parser.Parse("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36"))
}
}
func TestUserAgentParser_Parse_Unknown(t *testing.T) {
var parser = NewUserAgentParser()
t.Log(parser.Parse("Mozilla/5.0 (Wind 10.0; WOW64; rv:49.0) Apple/537.36 (KHTML, like Gecko) Chr/88.0.4324.96 Sa/537.36 Test/1"))
t.Log(parser.Parse(""))
}
func TestUserAgentParser_Memory(t *testing.T) {
var stat1 = &runtime.MemStats{}
runtime.ReadMemStats(stat1)
var parser = NewUserAgentParser()
for i := 0; i < 1_000_000; i++ {
parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/" + types.String(rands.Int(0, 100_000)))
}
runtime.GC()
debug.FreeOSMemory()
var stat2 = &runtime.MemStats{}
runtime.ReadMemStats(stat2)
t.Log("max cache items:", parser.maxCacheItems)
t.Log("cache1:", len(parser.cacheMap1), "cache2:", len(parser.cacheMap2), "cache3:", (stat2.HeapInuse-stat1.HeapInuse)/1024/1024, "MB")
}
func BenchmarkUserAgentParser_Parse(b *testing.B) {
var parser = NewUserAgentParser()
for i := 0; i < b.N; i++ {
parser.Parse("Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36 Test/" + types.String(rands.Int(0, 40000)))
}
b.Log(len(parser.cacheMap1), len(parser.cacheMap2))
}

View File

@@ -5,6 +5,8 @@ import (
"time"
)
var SharedCache = NewCache()
// Cache TTL缓存
// 最大的缓存时间为30 * 86400
// Piece数据结构

View File

@@ -63,7 +63,11 @@ func (this *List) Remove(itemId int64) {
func (this *List) GC(timestamp int64, callback func(itemId int64)) {
this.locker.Lock()
itemMap := this.gcItems(timestamp)
var itemMap = this.gcItems(timestamp)
if len(itemMap) == 0 {
this.locker.Unlock()
return
}
this.locker.Unlock()
if callback != nil {

View File

@@ -165,3 +165,21 @@ func Benchmark_Map_Uint64(b *testing.B) {
}
}
}
func BenchmarkList_GC(b *testing.B) {
runtime.GOMAXPROCS(1)
var lists = []*List{}
for i := 0; i < 100; i++ {
lists = append(lists, NewList())
}
var timestamp = time.Now().Unix()
for i := 0; i < b.N; i++ {
for _, list := range lists {
list.GC(timestamp, nil)
}
}
}

View File

@@ -20,10 +20,13 @@ func init() {
SharedFreeHoursManager.Start()
})
})
events.On(events.EventQuit, func() {
SharedFreeHoursManager.Stop()
})
}
// FreeHoursManager 计算节点空闲时间
// 以便于我们在空闲时间执行高强度的任务,如清理缓存等
// 以便于我们在空闲时间执行高强度的任务,如清理缓存等
type FreeHoursManager struct {
dayTrafficMap map[int][24]uint64 // day => [ traffic bytes ]
lastBytes uint64
@@ -32,6 +35,7 @@ type FreeHoursManager struct {
count int
locker sync.Mutex
ticker *time.Ticker
}
func NewFreeHoursManager() *FreeHoursManager {
@@ -39,8 +43,8 @@ func NewFreeHoursManager() *FreeHoursManager {
}
func (this *FreeHoursManager) Start() {
var ticker = time.NewTicker(30 * time.Minute)
for range ticker.C {
this.ticker = time.NewTicker(30 * time.Minute)
for range this.ticker.C {
this.Update(atomic.LoadUint64(&teaconst.InTrafficBytes))
}
}
@@ -113,6 +117,12 @@ func (this *FreeHoursManager) IsFreeHour() bool {
return false
}
func (this *FreeHoursManager) Stop() {
if this.ticker != nil {
this.ticker.Stop()
}
}
// 对数组进行排序,并返回权重
func (this *FreeHoursManager) sortUintArrayWeights(arr [24]uint64) [24]uint64 {
var l = []map[string]interface{}{}

View File

@@ -0,0 +1,14 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package linkedlist
type Item struct {
prev *Item
next *Item
Value interface{}
}
func NewItem(value interface{}) *Item {
return &Item{Value: value}
}

View File

@@ -0,0 +1,93 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package linkedlist
type List struct {
head *Item
end *Item
count int
}
func NewList() *List {
return &List{}
}
func (this *List) Head() *Item {
return this.head
}
func (this *List) End() *Item {
return this.end
}
func (this *List) Push(item *Item) {
if item == nil {
return
}
// 如果已经在末尾了则do nothing
if this.end == item {
return
}
if item.prev != nil || item.next != nil || this.head == item {
this.Remove(item)
}
this.add(item)
}
func (this *List) Remove(item *Item) {
if item == nil {
return
}
if item.prev != nil {
item.prev.next = item.next
}
if item.next != nil {
item.next.prev = item.prev
}
if item == this.head {
this.head = item.next
}
if item == this.end {
this.end = item.prev
}
item.prev = nil
item.next = nil
this.count--
}
func (this *List) Len() int {
return this.count
}
func (this *List) Range(f func(item *Item) (goNext bool)) {
for e := this.head; e != nil; e = e.next {
goNext := f(e)
if !goNext {
break
}
}
}
func (this *List) Reset() {
this.head = nil
this.end = nil
}
func (this *List) add(item *Item) {
if item == nil {
return
}
if this.end != nil {
this.end.next = item
item.prev = this.end
item.next = nil
}
this.end = item
if this.head == nil {
this.head = item
}
this.count++
}

View File

@@ -0,0 +1,82 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package linkedlist_test
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/linkedlist"
"runtime"
"testing"
)
func TestNewList_Memory(t *testing.T) {
var stat1 = &runtime.MemStats{}
runtime.ReadMemStats(stat1)
var list = linkedlist.NewList()
for i := 0; i < 1_000_000; i++ {
var item = &linkedlist.Item{}
list.Push(item)
}
var stat2 = &runtime.MemStats{}
runtime.ReadMemStats(stat2)
t.Log((stat2.HeapInuse-stat1.HeapInuse)/1024/1024, "MB")
t.Log(list.Len())
var count = 0
list.Range(func(item *linkedlist.Item) (goNext bool) {
count++
return true
})
t.Log(count)
}
func TestList_Push(t *testing.T) {
var list = linkedlist.NewList()
list.Push(linkedlist.NewItem(1))
list.Push(linkedlist.NewItem(2))
var item3 = linkedlist.NewItem(3)
list.Push(item3)
var item4 = linkedlist.NewItem(4)
list.Push(item4)
list.Range(func(item *linkedlist.Item) (goNext bool) {
t.Log(item.Value)
return true
})
t.Log("=== after push3 ===")
list.Push(item3)
list.Range(func(item *linkedlist.Item) (goNext bool) {
t.Log(item.Value)
return true
})
t.Log("=== after push4 ===")
list.Push(item4)
list.Push(item3)
list.Push(item3)
list.Push(item3)
list.Push(item4)
list.Push(item4)
list.Range(func(item *linkedlist.Item) (goNext bool) {
t.Log(item.Value)
return true
})
t.Log("=== after remove ===")
list.Remove(item3)
list.Range(func(item *linkedlist.Item) (goNext bool) {
t.Log(item.Value)
return true
})
}
func BenchmarkList_Add(b *testing.B) {
var list = linkedlist.NewList()
for i := 0; i < b.N; i++ {
var item = &linkedlist.Item{}
list.Push(item)
}
}

View File

@@ -0,0 +1,26 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package readers
import "io"
type BytesCounterReader struct {
rawReader io.Reader
count int64
}
func NewBytesCounterReader(rawReader io.Reader) *BytesCounterReader {
return &BytesCounterReader{
rawReader: rawReader,
}
}
func (this *BytesCounterReader) Read(p []byte) (n int, err error) {
n, err = this.rawReader.Read(p)
this.count += int64(n)
return
}
func (this *BytesCounterReader) TotalBytes() int64 {
return this.count
}

View File

@@ -0,0 +1,34 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package readers
import "io"
type FilterFunc = func(p []byte, err error) error
type FilterReader struct {
rawReader io.Reader
filters []FilterFunc
}
func NewFilterReader(rawReader io.Reader) *FilterReader {
return &FilterReader{
rawReader: rawReader,
}
}
func (this *FilterReader) Add(filter FilterFunc) {
this.filters = append(this.filters, filter)
}
func (this *FilterReader) Read(p []byte) (n int, err error) {
n, err = this.rawReader.Read(p)
for _, filter := range this.filters {
filterErr := filter(p[:n], err)
if filterErr != nil {
err = filterErr
return
}
}
return
}

View File

@@ -0,0 +1,41 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package readers_test
import (
"bytes"
"errors"
"github.com/TeaOSLab/EdgeNode/internal/utils/readers"
"testing"
)
func TestNewFilterReader(t *testing.T) {
var reader = readers.NewFilterReader(bytes.NewBufferString("0123456789"))
reader.Add(func(p []byte, err error) error {
t.Log("filter1:", string(p), err)
return nil
})
reader.Add(func(p []byte, err error) error {
t.Log("filter2:", string(p), err)
if string(p) == "345" {
return errors.New("end")
}
return nil
})
reader.Add(func(p []byte, err error) error {
t.Log("filter3:", string(p), err)
return nil
})
var buf = make([]byte, 3)
for {
n, err := reader.Read(buf)
if n > 0 {
t.Log(string(buf[:n]))
}
if err != nil {
t.Log(err)
break
}
}
}

View File

@@ -0,0 +1,52 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package readers
import (
"io"
)
type TeeReader struct {
r io.Reader
w io.Writer
onFail func(err error)
onEOF func()
}
func NewTeeReader(reader io.Reader, writer io.Writer) *TeeReader {
return &TeeReader{
r: reader,
w: writer,
}
}
func (this *TeeReader) Read(p []byte) (n int, err error) {
n, err = this.r.Read(p)
if n > 0 {
_, wErr := this.w.Write(p[:n])
if err == nil && wErr != nil {
err = wErr
}
}
if err != nil {
if err == io.EOF {
if this.onEOF != nil {
this.onEOF()
}
} else {
if this.onFail != nil {
this.onFail(err)
}
}
}
return
}
func (this *TeeReader) OnFail(onFail func(err error)) {
this.onFail = onFail
}
func (this *TeeReader) OnEOF(onEOF func()) {
this.onEOF = onEOF
}

View File

@@ -0,0 +1,58 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package readers
import "io"
type TeeReaderCloser struct {
r io.Reader
w io.Writer
onFail func(err error)
onEOF func()
}
func NewTeeReaderCloser(reader io.Reader, writer io.Writer) *TeeReaderCloser {
return &TeeReaderCloser{
r: reader,
w: writer,
}
}
func (this *TeeReaderCloser) Read(p []byte) (n int, err error) {
n, err = this.r.Read(p)
if n > 0 {
_, wErr := this.w.Write(p[:n])
if err == nil && wErr != nil {
err = wErr
}
}
if err != nil {
if err == io.EOF {
if this.onEOF != nil {
this.onEOF()
}
} else {
if this.onFail != nil {
this.onFail(err)
}
}
}
return
}
func (this *TeeReaderCloser) Close() error {
r, ok := this.r.(io.Closer)
if ok {
return r.Close()
}
return nil
}
func (this *TeeReaderCloser) OnFail(onFail func(err error)) {
this.onFail = onFail
}
func (this *TeeReaderCloser) OnEOF(onEOF func()) {
this.onEOF = onEOF
}

View File

@@ -5,17 +5,17 @@ import (
"unsafe"
)
// convert bytes to string
// UnsafeBytesToString convert bytes to string
func UnsafeBytesToString(bs []byte) string {
return *(*string)(unsafe.Pointer(&bs))
}
// convert string to bytes
// UnsafeStringToBytes convert string to bytes
func UnsafeStringToBytes(s string) []byte {
return *(*[]byte)(unsafe.Pointer(&s))
}
// format address
// FormatAddress format address
func FormatAddress(addr string) string {
if strings.HasSuffix(addr, "unix:") {
return addr
@@ -27,7 +27,7 @@ func FormatAddress(addr string) string {
return addr
}
// format address list
// FormatAddressList format address list
func FormatAddressList(addrList []string) []string {
result := []string{}
for _, addr := range addrList {
@@ -35,3 +35,7 @@ func FormatAddressList(addrList []string) []string {
}
return result
}
func ToValidUTF8string(v string) string {
return strings.ToValidUTF8(v, "")
}

View File

@@ -23,6 +23,21 @@ func UnixTime() int64 {
return unixTime
}
// FloorUnixTime 取整
func FloorUnixTime(seconds int) int64 {
return UnixTime() / int64(seconds) * int64(seconds)
}
// CeilUnixTime 取整并加1
func CeilUnixTime(seconds int) int64 {
return UnixTime()/int64(seconds)*int64(seconds) + int64(seconds)
}
// NextMinuteUnixTime 获取下一分钟开始的时间戳
func NextMinuteUnixTime() int64 {
return CeilUnixTime(60)
}
// UnixTimeMilli 获取时间戳,精确到毫秒
func UnixTimeMilli() int64 {
return unixTimeMilli

View File

@@ -1,6 +1,7 @@
package utils
import (
timeutil "github.com/iwind/TeaGo/utils/time"
"testing"
"time"
)
@@ -19,3 +20,11 @@ func TestGMTUnixTime(t *testing.T) {
func TestGMTTime(t *testing.T) {
t.Log(GMTTime(time.Now()))
}
func TestFloorUnixTime(t *testing.T) {
var timestamp = time.Now().Unix()
t.Log("floor 60:", timestamp, FloorUnixTime(60), timeutil.FormatTime("Y-m-d H:i:s", FloorUnixTime(60)))
t.Log("ceil 60:", timestamp, CeilUnixTime(60), timeutil.FormatTime("Y-m-d H:i:s", CeilUnixTime(60)))
t.Log("floor 300:", timestamp, FloorUnixTime(300), timeutil.FormatTime("Y-m-d H:i:s", FloorUnixTime(300)))
t.Log("next minute:", NextMinuteUnixTime())
}

Some files were not shown because too many files have changed in this diff Show More