Compare commits

..

107 Commits

Author SHA1 Message Date
刘祥超
d82c03db23 修复在HTTPS下无法连接Websocket的问题 2023-01-10 21:20:27 +08:00
刘祥超
230c5c3766 版本号修改为0.6.2 2023-01-10 21:18:53 +08:00
刘祥超
927425149e 优化代码 2023-01-10 09:47:56 +08:00
刘祥超
5ce1aab92c 修复域名跳转时没有携带参数的Bug 2023-01-09 20:06:09 +08:00
刘祥超
195742bb26 修复读超时时间(ReadDeadline)导致WAFGET302、POST307延时关闭连接的问题 2023-01-09 15:56:59 +08:00
刘祥超
006cc2912d 版本修改为0.6.1 2023-01-09 15:49:16 +08:00
刘祥超
2d4ba90c3b 改进在自动读超时模式下的Websocket连接 2023-01-09 12:36:33 +08:00
刘祥超
a2e6aaaa18 WAF增加“在IP列表内”操作符/优化部分操作符代号 2023-01-08 10:15:46 +08:00
刘祥超
8e68da7725 集群服务设置增加自动读超时选项 2023-01-07 20:04:05 +08:00
刘祥超
7abb84c880 优化网络连接关闭速度 2023-01-07 10:03:32 +08:00
刘祥超
a17878f5b2 WAF增加包含任一字符串、包含所有字符串操作符 2023-01-06 20:07:15 +08:00
刘祥超
8a8881ac47 IP范围支持多行 2023-01-06 19:14:09 +08:00
刘祥超
c567404b7a 优化连接相关代码 2023-01-05 11:13:35 +08:00
刘祥超
b220b0f48e 优化读取HTTP请求Header和握手超时时间 2023-01-05 00:40:49 +08:00
刘祥超
9609c90d75 边缘节点增加数据读超时,以改进客户端上传数据过慢的问题 2023-01-04 20:43:10 +08:00
刘祥超
2c3c32af5b 优化代码 2023-01-02 10:44:10 +08:00
刘祥超
b4a4b2e9b1 集群服务设置中增加性能设置 2023-01-01 19:27:38 +08:00
刘祥超
c42ff1e1e9 实现UA名单功能 2022-12-30 20:49:43 +08:00
刘祥超
9fed1141c2 默认情况下内容压缩不支持Partial Content 2022-12-30 11:44:07 +08:00
刘祥超
e87f031293 增加CORS自适应跨域 2022-12-29 17:16:42 +08:00
刘祥超
c4bac7f43c 优化代码 2022-12-27 18:58:29 +08:00
刘祥超
47818f972e 自动转换访问域名中的大写字母 2022-12-25 15:23:56 +08:00
刘祥超
218a0300c5 修复测试用例 2022-12-23 18:53:49 +08:00
刘祥超
63f6c4177f 修复测试用例 2022-12-23 18:17:32 +08:00
刘祥超
1830c22a31 增加自动Agent识别 2022-12-22 11:38:59 +08:00
刘祥超
18611e8a7c 写数据超时时断开同客户端连接 2022-12-21 16:11:55 +08:00
刘祥超
c45f7adf04 优化连接相关代码 2022-12-21 15:59:07 +08:00
刘祥超
1a200918a8 不支持CONNECT方法 2022-12-19 16:27:58 +08:00
刘祥超
b942bb776e 国家/地区封禁、省份封禁时支持IP变量 2022-12-18 16:04:12 +08:00
刘祥超
5cf84efccd 优化内容为空的缓存 2022-12-14 15:26:18 +08:00
刘祥超
ebb6ebd10c 修复WAF中反斜杠符号(\)有可能解析错误的Bug 2022-12-14 12:27:07 +08:00
刘祥超
42d0d63cf4 优化代码 2022-12-13 18:08:50 +08:00
刘祥超
96f8f7e925 增加edge-node ip.close IP命令 2022-12-12 19:23:58 +08:00
刘祥超
e7e7214d58 调整慢连接超时算法 2022-12-12 10:04:36 +08:00
刘祥超
ade979a725 向客户端写入数据超时时立即关闭连接 2022-12-10 19:51:05 +08:00
刘祥超
60a8de13e7 TCP单次向客户端写入数据时超过30秒即认为超时 2022-12-10 18:22:00 +08:00
刘祥超
9fa24bed0a 修复WAF记录IP动作时无法不超时的Bug 2022-12-06 11:01:34 +08:00
刘祥超
87bc1a7e03 优化OpenFileCache 2022-12-05 11:16:04 +08:00
刘祥超
1a05f56149 优化缓存相关代码 2022-12-05 10:46:44 +08:00
刘祥超
f88db576e1 优化代码 2022-12-05 09:57:01 +08:00
刘祥超
dc3f26ea1a 减少WAF预读尺寸 2022-12-02 21:08:03 +08:00
刘祥超
6fc30144f7 在edge-node conns命令中显示连接时长 2022-12-02 17:03:16 +08:00
刘祥超
25b0b98bd4 增加默认的源站连接数 2022-12-02 10:39:07 +08:00
刘祥超
27b5817d5e 优化请求限制逻辑,连接关闭时自动终止内容发送 2022-11-29 19:14:46 +08:00
刘祥超
dcb61dfd33 版本号更改为0.6.0 2022-11-29 15:42:21 +08:00
刘祥超
bbcfdbbf5e 优化代码 2022-11-29 15:33:12 +08:00
刘祥超
b2a1bef08f 修复服务WAF配置无法更新的Bug 2022-11-28 18:13:08 +08:00
刘祥超
2b18b5c2ca 修改版本号为0.5.9 2022-11-28 18:08:19 +08:00
刘祥超
6ff030dbd8 编译时加入configs/cluster.template.yaml文件 2022-11-27 14:52:48 +08:00
刘祥超
0ddeef6986 支持使用域名中含有通配符清除缓存数据 2022-11-26 11:05:46 +08:00
刘祥超
976bd3600b 优化OpenFileCache功能 2022-11-25 14:52:04 +08:00
刘祥超
a64047a934 优化配置重载程序 2022-11-25 10:50:57 +08:00
刘祥超
e82f207935 统计API调用时低于一半的采样率返回总统计 2022-11-23 20:23:46 +08:00
刘祥超
61b5316a1f 优化代码 2022-11-23 20:13:34 +08:00
刘祥超
82329aa8b0 修复一处编译错误 2022-11-22 18:40:03 +08:00
刘祥超
7dabd9c19c 在监控系统运行时上报API连接状况 2022-11-22 11:23:39 +08:00
刘祥超
9437acd18c 优化代码 2022-11-21 21:08:47 +08:00
刘祥超
9da7a34edf 节点可以单独设置所使用的API节点地址 2022-11-21 19:55:28 +08:00
刘祥超
b6a5491dcc 优化Partial Content兼容性 2022-11-20 18:07:46 +08:00
刘祥超
bcee658567 优化Partial Content配置编码速度 2022-11-19 23:11:05 +08:00
刘祥超
afc8f7b703 优化Partial Content缓存 2022-11-19 21:20:53 +08:00
刘祥超
7a4b89d2fb 缓存Header中忽略Set-Cookie 2022-11-19 17:35:23 +08:00
刘祥超
c6299a2fb0 减少文件缓存写入次数 2022-11-19 17:23:45 +08:00
刘祥超
8b5d74af9b 进一步提升文件缓存写入速度 2022-11-19 15:55:05 +08:00
刘祥超
a194360a56 Update go.mod 2022-11-18 17:39:49 +08:00
刘祥超
b12f7f69ba 优化代码 2022-11-17 20:28:55 +08:00
刘祥超
06ec4d3fba 优化代码 2022-11-17 10:38:20 +08:00
刘祥超
c209ab912f 优化代码 2022-11-17 10:35:43 +08:00
刘祥超
32720d772d 优化代码 2022-11-17 10:32:26 +08:00
刘祥超
a89c02fd10 请求变量增加${cname},WAF checkpoint增加cname和isCNAME 2022-11-16 15:01:10 +08:00
刘祥超
37ef86b92f 接收HTTP请求时去除域名后面的点符号 2022-11-16 11:25:11 +08:00
刘祥超
4c19c37f49 写入缓存时减少对缓存目录的检查频率 2022-11-15 22:25:49 +08:00
刘祥超
1bb818b5b0 边缘节点支持设置多个缓存目录 2022-11-15 20:42:25 +08:00
刘祥超
825e46458f 优化代码 2022-11-15 10:06:57 +08:00
刘祥超
a42737bd28 缩短节点运行日志队列长度 2022-11-14 16:42:50 +08:00
刘祥超
5f76be2cfd 优化代码 2022-11-13 10:32:12 +08:00
刘祥超
dbddf8a91a 优化代码 2022-11-08 21:37:20 +08:00
刘祥超
6c457f41f6 优化代码 2022-11-08 20:58:17 +08:00
刘祥超
e4b2a650f0 优化代码 2022-11-08 20:19:51 +08:00
刘祥超
913ba95801 优化缓存相关代码 2022-11-08 11:03:37 +08:00
刘祥超
a9f8e39703 修复节点设置的“缓存磁盘容量”不起作用的问题 2022-11-07 21:32:20 +08:00
刘祥超
534f013f59 使用版本号来读取节点任务,提升任务同步稳定性 2022-11-06 12:07:26 +08:00
刘祥超
258380f75c 修复无法回报任务执行失败的问题 2022-11-05 14:56:57 +08:00
刘祥超
8c0e51ec46 域名跳转增加是否忽略端口选项 2022-11-05 14:30:29 +08:00
刘祥超
4c37c7ab84 时钟同步增加是否检查chrony选项 2022-11-03 14:59:26 +08:00
刘祥超
f005da1d5f 防盗链提示增加缓存时间,以提升性能 2022-11-02 15:24:30 +08:00
刘祥超
e99acc4694 版本修改为0.5.8 2022-11-02 15:11:55 +08:00
刘祥超
408357dfcf 优化DDoS防护相关错误提示信息 2022-11-01 17:37:40 +08:00
刘祥超
0109a27c06 上传访问日志发生网络错误时不提交 2022-11-01 14:55:06 +08:00
刘祥超
e6e2dccc42 版本号修改为0.5.7 2022-10-31 19:14:03 +08:00
刘祥超
09dcf0d712 集群全局服务配置中增加多个访问日志相关选项 2022-10-26 17:51:16 +08:00
刘祥超
60aebd9306 URL跳转中增加域名跳转、端口跳转 2022-10-26 16:14:37 +08:00
刘祥超
04191d04d3 节点设置中增加“通过IP名单”选项 2022-10-26 10:42:16 +08:00
刘祥超
b80a5c525f 节点缓存目录所在磁盘空间不足时(<5G),暂停缓存写入,同时启动LFU清理 2022-10-25 15:14:28 +08:00
刘祥超
265c1e5312 WAF参数定义增加优先级,可以让“轻”任务优先执行 2022-10-24 17:57:07 +08:00
刘祥超
2723f705b6 修复在iptables中加入ipv6的错误 2022-10-24 16:37:54 +08:00
刘祥超
b4cddd6341 集群服务设置--访问日志中可以设置是否只记录通用Header 2022-10-24 14:39:18 +08:00
刘祥超
5636a81d48 防盗链功能增加禁止的来源域名 2022-10-24 10:21:23 +08:00
刘祥超
d8059960de 文件缓存索引表取消UNIQUE索引,尽可能避免 sqlite malformed 错误 2022-10-23 20:45:41 +08:00
刘祥超
17af4064af 带宽和流量提交失败时,将在一定时间内重试 2022-10-23 19:41:21 +08:00
刘祥超
15f37d2c93 优化用户服务整体启用和禁用 2022-10-23 16:21:11 +08:00
刘祥超
6dc3aa8cb7 单请求写入时间从1个小时增加到2个小时 2022-10-23 09:52:50 +08:00
刘祥超
900cccf2f1 修复源站Websocket源站读取失败导致的异常错误 2022-10-18 19:43:53 +08:00
刘祥超
1fec88dfc6 优化代码 2022-10-14 15:00:05 +08:00
刘祥超
7da9363336 上传带宽信息时附带区域ID信息 2022-10-11 18:57:35 +08:00
刘祥超
d82e633bba 时钟同步程序每天只提示一次警告信息 2022-10-11 11:31:00 +08:00
刘祥超
b363bbaafd 版本修改为0.5.6 2022-10-01 08:50:12 +08:00
124 changed files with 4582 additions and 893 deletions

View File

@@ -50,6 +50,7 @@ function build() {
fi
cp "$ROOT"/configs/api.template.yaml "$DIST"/configs
cp "$ROOT"/configs/cluster.template.yaml "$DIST"/configs
cp -R "$ROOT"/www "$DIST"/
cp -R "$ROOT"/pages "$DIST"/

View File

@@ -25,7 +25,7 @@ func main() {
Product(teaconst.ProductName).
Usage(teaconst.ProcessName + " [-v|start|stop|restart|status|quit|test|reload|service|daemon|pprof|accesslog]").
Usage(teaconst.ProcessName + " [trackers|goman|conns|gc]").
Usage(teaconst.ProcessName + " [ip.drop|ip.reject|ip.remove] IP")
Usage(teaconst.ProcessName + " [ip.drop|ip.reject|ip.remove|ip.close] IP")
app.On("test", func() {
err := nodes.NewNode().Test()
@@ -241,6 +241,38 @@ func main() {
}
}
})
app.On("ip.close", func() {
var args = os.Args[2:]
if len(args) == 0 {
fmt.Println("Usage: edge-node ip.close IP")
return
}
var ip = args[0]
if len(net.ParseIP(ip)) == 0 {
fmt.Println("IP '" + ip + "' is invalid")
return
}
fmt.Println("close ip '" + ip)
var sock = gosock.NewTmpSock(teaconst.ProcessName)
reply, err := sock.Send(&gosock.Command{
Code: "closeIP",
Params: map[string]any{
"ip": ip,
},
})
if err != nil {
fmt.Println("[ERROR]" + err.Error())
} else {
var errString = maps.NewMap(reply.Params).GetString("error")
if len(errString) > 0 {
fmt.Println("[ERROR]" + errString)
} else {
fmt.Println("ok")
}
}
})
app.On("ip.remove", func() {
var args = os.Args[2:]
if len(args) == 0 {

View File

@@ -3,6 +3,7 @@
package caches
const (
SuffixAll = "@GOEDGE_" // 通用后缀
SuffixWebP = "@GOEDGE_WEBP" // WebP后缀
SuffixCompression = "@GOEDGE_" // 压缩后缀 SuffixCompression + Encoding
SuffixMethod = "@GOEDGE_" // 请求方法后缀 SuffixMethod + RequestMethod

View File

@@ -0,0 +1,11 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package caches
import "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
type FileDir struct {
Path string
Capacity *shared.SizeCapacity
IsFull bool
}

View File

@@ -2,6 +2,7 @@ package caches
import (
"github.com/TeaOSLab/EdgeNode/internal/utils"
"strings"
"time"
)
@@ -59,3 +60,17 @@ func (this *Item) IncreaseHit(week int32) {
this.Week = week
}
}
func (this *Item) RequestURI() string {
var schemeIndex = strings.Index(this.Key, "://")
if schemeIndex <= 0 {
return ""
}
var firstSlashIndex = strings.Index(this.Key[schemeIndex+3:], "/")
if firstSlashIndex <= 0 {
return ""
}
return this.Key[schemeIndex+3+firstSlashIndex:]
}

View File

@@ -81,3 +81,14 @@ func TestItems_Memory2(t *testing.T) {
t.Log(w, len(i))
}
}
func TestItem_RequestURI(t *testing.T) {
for _, u := range []string{
"https://goedge.cn/hello/world",
"https://goedge.cn:8080/hello/world",
"https://goedge.cn/hello/world?v=1&t=123",
} {
var item = &Item{Key: u}
t.Log(u, "=>", item.RequestURI())
}
}

View File

@@ -160,6 +160,7 @@ func (this *FileList) CleanPrefix(prefix string) error {
}
defer func() {
// TODO 需要优化
this.memoryCache.Clean()
}()
@@ -172,6 +173,46 @@ func (this *FileList) CleanPrefix(prefix string) error {
return nil
}
// CleanMatchKey 清理通配符匹配的缓存数据,类似于 https://*.example.com/hello
func (this *FileList) CleanMatchKey(key string) error {
if len(key) == 0 {
return nil
}
defer func() {
// TODO 需要优化
this.memoryCache.Clean()
}()
for _, db := range this.dbList {
err := db.CleanMatchKey(key)
if err != nil {
return err
}
}
return nil
}
// CleanMatchPrefix 清理通配符匹配的缓存数据,类似于 https://*.example.com/prefix/
func (this *FileList) CleanMatchPrefix(prefix string) error {
if len(prefix) == 0 {
return nil
}
defer func() {
// TODO 需要优化
this.memoryCache.Clean()
}()
for _, db := range this.dbList {
err := db.CleanMatchPrefix(prefix)
if err != nil {
return err
}
}
return nil
}
func (this *FileList) Remove(hash string) error {
_, err := this.remove(hash)
return err

View File

@@ -13,7 +13,10 @@ import (
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
"net"
"net/url"
"os"
"path/filepath"
"runtime"
"strings"
"time"
@@ -108,7 +111,7 @@ func (this *FileListDB) Open(dbPath string) error {
this.writeBatch = dbs.NewBatch(writeDB, 4)
this.writeBatch.OnFail(func(err error) {
remotelogs.Warn("LIST_FILE_DB", "run batch failed: "+err.Error())
remotelogs.Warn("LIST_FILE_DB", "run batch failed: "+err.Error()+" ("+filepath.Base(this.dbPath)+")")
})
goman.New(func() {
@@ -177,6 +180,9 @@ func (this *FileListDB) Init() error {
}
this.selectHashListStmt, err = this.readDB.Prepare(`SELECT "id", "hash" FROM "` + this.itemsTableName + `" WHERE id>:id ORDER BY id ASC LIMIT 2000`)
if err != nil {
return err
}
this.deleteByHashSQL = `DELETE FROM "` + this.itemsTableName + `" WHERE "hash"=?`
this.deleteByHashStmt, err = this.writeDB.Prepare(this.deleteByHashSQL)
@@ -388,6 +394,85 @@ func (this *FileListDB) CleanPrefix(prefix string) error {
}
}
func (this *FileListDB) CleanMatchKey(key string) error {
if !this.isReady {
return nil
}
// 忽略 @GOEDGE_
if strings.Contains(key, SuffixAll) {
return nil
}
u, err := url.Parse(key)
if err != nil {
return nil
}
var host = u.Host
hostPart, _, err := net.SplitHostPort(host)
if err == nil && len(hostPart) > 0 {
host = hostPart
}
if len(host) == 0 {
return nil
}
// 转义
var queryKey = strings.ReplaceAll(key, "%", "\\%")
queryKey = strings.ReplaceAll(queryKey, "_", "\\_")
queryKey = strings.Replace(queryKey, "*", "%", 1)
// TODO 检查大批量数据下的操作性能
var staleLife = 600 // TODO 需要可以设置
var unixTime = utils.UnixTime() // 只删除当前的,不删除新的
_, err = this.writeDB.Exec(`UPDATE "`+this.itemsTableName+`" SET "expiredAt"=0, "staleAt"=? WHERE "host" GLOB ? AND "host" NOT GLOB ? AND "key" LIKE ? ESCAPE '\'`, unixTime+int64(staleLife), host, "*."+host, queryKey)
if err != nil {
return err
}
_, err = this.writeDB.Exec(`UPDATE "`+this.itemsTableName+`" SET "expiredAt"=0, "staleAt"=? WHERE "host" GLOB ? AND "host" NOT GLOB ? AND "key" LIKE ? ESCAPE '\'`, unixTime+int64(staleLife), host, "*."+host, queryKey+SuffixAll+"%")
if err != nil {
return err
}
return nil
}
func (this *FileListDB) CleanMatchPrefix(prefix string) error {
if !this.isReady {
return nil
}
u, err := url.Parse(prefix)
if err != nil {
return nil
}
var host = u.Host
hostPart, _, err := net.SplitHostPort(host)
if err == nil && len(hostPart) > 0 {
host = hostPart
}
if len(host) == 0 {
return nil
}
// 转义
var queryPrefix = strings.ReplaceAll(prefix, "%", "\\%")
queryPrefix = strings.ReplaceAll(queryPrefix, "_", "\\_")
queryPrefix = strings.Replace(queryPrefix, "*", "%", 1)
queryPrefix += "%"
// TODO 检查大批量数据下的操作性能
var staleLife = 600 // TODO 需要可以设置
var unixTime = utils.UnixTime() // 只删除当前的,不删除新的
_, err = this.writeDB.Exec(`UPDATE "`+this.itemsTableName+`" SET "expiredAt"=0, "staleAt"=? WHERE "host" GLOB ? AND "host" NOT GLOB ? AND "key" LIKE ? ESCAPE '\'`, unixTime+int64(staleLife), host, "*."+host, queryPrefix)
return err
}
func (this *FileListDB) CleanAll() error {
if !this.isReady {
return nil
@@ -473,6 +558,7 @@ func (this *FileListDB) initTables(times int) error {
{
// expiredAt - 过期时间,用来判断有无过期
// staleAt - 过时缓存最大时间,用来清理缓存
// 不对 hash 增加 unique 参数,是尽可能避免产生 malformed 错误
_, err := this.writeDB.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemsTableName + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"hash" varchar(32),
@@ -498,7 +584,7 @@ ON "` + this.itemsTableName + `" (
"staleAt" ASC
);
CREATE UNIQUE INDEX IF NOT EXISTS "hash"
CREATE INDEX IF NOT EXISTS "hash"
ON "` + this.itemsTableName + `" (
"hash" ASC
);

View File

@@ -47,3 +47,41 @@ func TestFileListDB_IncreaseHitAsync(t *testing.T) {
// wait transaction
time.Sleep(1 * time.Second)
}
func TestFileListDB_CleanMatchKey(t *testing.T) {
var db = caches.NewFileListDB()
err := db.Open(Tea.Root + "/data/cache-db-large.db")
if err != nil {
t.Fatal(err)
}
err = db.Init()
err = db.CleanMatchKey("https://*.goedge.cn/large-text")
if err != nil {
t.Fatal(err)
}
err = db.CleanMatchKey("https://*.goedge.cn:1234/large-text?%2B____")
if err != nil {
t.Fatal(err)
}
}
func TestFileListDB_CleanMatchPrefix(t *testing.T) {
var db = caches.NewFileListDB()
err := db.Open(Tea.Root + "/data/cache-db-large.db")
if err != nil {
t.Fatal(err)
}
err = db.Init()
err = db.CleanMatchPrefix("https://*.goedge.cn/large-text")
if err != nil {
t.Fatal(err)
}
err = db.CleanMatchPrefix("https://*.goedge.cn:1234/large-text?%2B____")
if err != nil {
t.Fatal(err)
}
}

View File

@@ -18,6 +18,12 @@ type ListInterface interface {
// CleanPrefix 清除某个前缀的缓存
CleanPrefix(prefix string) error
// CleanMatchKey 清除通配符匹配的Key
CleanMatchKey(key string) error
// CleanMatchPrefix 清除通配符匹配的前缀
CleanMatchPrefix(prefix string) error
// Remove 删除内容
Remove(hash string) error

View File

@@ -1,8 +1,11 @@
package caches
import (
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeNode/internal/zero"
"github.com/iwind/TeaGo/logs"
"net"
"net/url"
"strconv"
"strings"
"sync"
@@ -146,6 +149,82 @@ func (this *MemoryList) CleanPrefix(prefix string) error {
return nil
}
// CleanMatchKey 清理通配符匹配的缓存数据,类似于 https://*.example.com/hello
func (this *MemoryList) CleanMatchKey(key string) error {
if strings.Contains(key, SuffixAll) {
return nil
}
u, err := url.Parse(key)
if err != nil {
return nil
}
var host = u.Host
hostPart, _, err := net.SplitHostPort(host)
if err == nil && len(hostPart) > 0 {
host = hostPart
}
if len(host) == 0 {
return nil
}
var requestURI = u.RequestURI()
this.locker.RLock()
defer this.locker.RUnlock()
// TODO 需要优化性能支持千万级数据低于1s的处理速度
for _, itemMap := range this.itemMaps {
for _, item := range itemMap {
if configutils.MatchDomain(host, item.Host) {
var itemRequestURI = item.RequestURI()
if itemRequestURI == requestURI || strings.HasPrefix(itemRequestURI, requestURI+SuffixAll) {
item.ExpiredAt = 0
}
}
}
}
return nil
}
// CleanMatchPrefix 清理通配符匹配的缓存数据,类似于 https://*.example.com/prefix/
func (this *MemoryList) CleanMatchPrefix(prefix string) error {
u, err := url.Parse(prefix)
if err != nil {
return nil
}
var host = u.Host
hostPart, _, err := net.SplitHostPort(host)
if err == nil && len(hostPart) > 0 {
host = hostPart
}
if len(host) == 0 {
return nil
}
var requestURI = u.RequestURI()
var isRootPath = requestURI == "/"
this.locker.RLock()
defer this.locker.RUnlock()
// TODO 需要优化性能支持千万级数据低于1s的处理速度
for _, itemMap := range this.itemMaps {
for _, item := range itemMap {
if configutils.MatchDomain(host, item.Host) {
var itemRequestURI = item.RequestURI()
if isRootPath || strings.HasPrefix(itemRequestURI, requestURI) {
item.ExpiredAt = 0
}
}
}
}
return nil
}
func (this *MemoryList) Remove(hash string) error {
this.locker.Lock()

View File

@@ -24,7 +24,8 @@ func init() {
type Manager struct {
// 全局配置
MaxDiskCapacity *shared.SizeCapacity
DiskDir string
MainDiskDir string
SubDiskDirs []*serverconfigs.CacheDir
MaxMemoryCapacity *shared.SizeCapacity
policyMap map[int64]*serverconfigs.HTTPCachePolicy // policyId => []*Policy
@@ -47,12 +48,10 @@ func (this *Manager) UpdatePolicies(newPolicies []*serverconfigs.HTTPCachePolicy
this.locker.Lock()
defer this.locker.Unlock()
newPolicyIds := []int64{}
var newPolicyIds = []int64{}
for _, policy := range newPolicies {
// 使用节点单独的缓存目录
if len(this.DiskDir) > 0 {
policy.UpdateDiskDir(this.DiskDir)
}
policy.UpdateDiskDir(this.MainDiskDir, this.SubDiskDirs)
newPolicyIds = append(newPolicyIds, policy.Id)
}

View File

@@ -19,7 +19,7 @@ type OpenFileCache struct {
poolList *linkedlist.List
watcher *fsnotify.Watcher
locker sync.Mutex
locker sync.RWMutex
maxSize int
count int
@@ -54,13 +54,18 @@ func NewOpenFileCache(maxSize int) (*OpenFileCache, error) {
}
func (this *OpenFileCache) Get(filename string) *OpenFile {
this.locker.Lock()
defer this.locker.Unlock()
this.locker.RLock()
pool, ok := this.poolMap[filename]
this.locker.RUnlock()
if ok {
file, consumed := pool.Get()
if consumed {
this.locker.Lock()
this.count--
// pool如果为空也不需要从列表中删除避免put时需要重新创建
this.locker.Unlock()
}
return file
}
@@ -124,6 +129,9 @@ func (this *OpenFileCache) Close(filename string) {
pool, ok := this.poolMap[filename]
if ok {
// 设置关闭状态
pool.SetClosing()
delete(this.poolMap, filename)
this.poolList.Remove(pool.linkItem)
_ = this.watcher.Remove(filename)
@@ -146,6 +154,7 @@ func (this *OpenFileCache) CloseAll() {
this.poolMap = map[string]*OpenFilePool{}
this.poolList.Reset()
_ = this.watcher.Close()
this.count = 0
this.locker.Unlock()
}

View File

@@ -0,0 +1,43 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package caches_test
import (
"github.com/TeaOSLab/EdgeNode/internal/caches"
"testing"
"time"
)
func TestNewOpenFileCache_Close(t *testing.T) {
cache, err := caches.NewOpenFileCache(1024)
if err != nil {
t.Fatal(err)
}
cache.Debug()
cache.Put("a.txt", caches.NewOpenFile(nil, nil, nil, 0))
cache.Put("b.txt", caches.NewOpenFile(nil, nil, nil, 0))
cache.Put("b.txt", caches.NewOpenFile(nil, nil, nil, 0))
cache.Put("b.txt", caches.NewOpenFile(nil, nil, nil, 0))
cache.Put("c.txt", caches.NewOpenFile(nil, nil, nil, 0))
cache.Get("b.txt")
cache.Get("d.txt")
cache.Close("a.txt")
time.Sleep(100 * time.Second)
}
func TestNewOpenFileCache_CloseAll(t *testing.T) {
cache, err := caches.NewOpenFileCache(1024)
if err != nil {
t.Fatal(err)
}
cache.Debug()
cache.Put("a.txt", caches.NewOpenFile(nil, nil, nil, 0))
cache.Put("b.txt", caches.NewOpenFile(nil, nil, nil, 0))
cache.Put("c.txt", caches.NewOpenFile(nil, nil, nil, 0))
cache.Get("b.txt")
cache.Get("d.txt")
cache.CloseAll()
time.Sleep(6 * time.Second)
}

View File

@@ -12,6 +12,7 @@ type OpenFilePool struct {
linkItem *linkedlist.Item
filename string
version int64
isClosed bool
}
func NewOpenFilePool(filename string) *OpenFilePool {
@@ -29,26 +30,43 @@ func (this *OpenFilePool) Filename() string {
}
func (this *OpenFilePool) Get() (*OpenFile, bool) {
// 如果已经关闭,直接返回
if this.isClosed {
return nil, false
}
select {
case file := <-this.c:
err := file.SeekStart()
if err != nil {
_ = file.Close()
return nil, true
}
file.version = this.version
if file != nil {
err := file.SeekStart()
if err != nil {
_ = file.Close()
return nil, true
}
file.version = this.version
return file, true
return file, true
}
return nil, false
default:
return nil, false
}
}
func (this *OpenFilePool) Put(file *OpenFile) bool {
// 如果已关闭,则不接受新的文件
if this.isClosed {
_ = file.Close()
return false
}
// 检查文件版本号
if this.version > 0 && file.version > 0 && file.version != this.version {
_ = file.Close()
return false
}
// 加入Pool
select {
case this.c <- file:
return true
@@ -63,14 +81,18 @@ func (this *OpenFilePool) Len() int {
return len(this.c)
}
func (this *OpenFilePool) SetClosing() {
this.isClosed = true
}
func (this *OpenFilePool) Close() {
Loop:
this.isClosed = true
for {
select {
case file := <-this.c:
_ = file.Close()
default:
break Loop
return
}
}
}

View File

@@ -4,6 +4,8 @@ package caches_test
import (
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/iwind/TeaGo/rands"
"sync"
"testing"
)
@@ -15,3 +17,30 @@ func TestOpenFilePool_Get(t *testing.T) {
t.Log(pool.Get())
t.Log(pool.Get())
}
func TestOpenFilePool_Close(t *testing.T) {
var pool = caches.NewOpenFilePool("a")
pool.Put(caches.NewOpenFile(nil, nil, nil, 0))
pool.Put(caches.NewOpenFile(nil, nil, nil, 0))
pool.Close()
}
func TestOpenFilePool_Concurrent(t *testing.T) {
var pool = caches.NewOpenFilePool("a")
var concurrent = 1000
var wg = &sync.WaitGroup{}
wg.Add(concurrent)
for i := 0; i < concurrent; i++ {
go func() {
defer wg.Done()
if rands.Int(0, 1) == 1 {
pool.Put(caches.NewOpenFile(nil, nil, nil, 0))
}
if rands.Int(0, 1) == 0 {
pool.Get()
}
}()
}
wg.Wait()
}

View File

@@ -3,38 +3,88 @@
package caches
import (
"bytes"
"encoding/json"
"errors"
"github.com/iwind/TeaGo/types"
"os"
"strconv"
)
// PartialRanges 内容分区范围定义
type PartialRanges struct {
Ranges [][2]int64 `json:"ranges"`
Version int `json:"version"` // 版本号
Ranges [][2]int64 `json:"ranges"` // 范围
BodySize int64 `json:"bodySize"` // 总长度
}
// NewPartialRanges 获取新对象
func NewPartialRanges() *PartialRanges {
return &PartialRanges{Ranges: [][2]int64{}}
func NewPartialRanges(expiresAt int64) *PartialRanges {
return &PartialRanges{
Ranges: [][2]int64{},
Version: 1,
}
}
// NewPartialRangesFromData 从数据中解析范围
func NewPartialRangesFromData(data []byte) (*PartialRanges, error) {
var rs = NewPartialRanges(0)
for {
var index = bytes.IndexRune(data, '\n')
if index < 0 {
break
}
var line = data[:index]
var colonIndex = bytes.IndexRune(line, ':')
if colonIndex > 0 {
switch string(line[:colonIndex]) {
case "v": // 版本号
rs.Version = types.Int(line[colonIndex+1:])
case "b": // 总长度
rs.BodySize = types.Int64(line[colonIndex+1:])
case "r": // 范围信息
var commaIndex = bytes.IndexRune(line, ',')
if commaIndex > 0 {
rs.Ranges = append(rs.Ranges, [2]int64{types.Int64(line[colonIndex+1 : commaIndex]), types.Int64(line[commaIndex+1:])})
}
}
}
data = data[index+1:]
if len(data) == 0 {
break
}
}
return rs, nil
}
// NewPartialRangesFromJSON 从JSON中解析范围
func NewPartialRangesFromJSON(data []byte) (*PartialRanges, error) {
var rs = NewPartialRanges()
var rs = NewPartialRanges(0)
err := json.Unmarshal(data, &rs)
if err != nil {
return nil, err
}
rs.Version = 0
return rs, nil
}
// NewPartialRangesFromFile 从文件中加载范围信息
func NewPartialRangesFromFile(path string) (*PartialRanges, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
return NewPartialRangesFromJSON(data)
if len(data) == 0 {
return NewPartialRanges(0), nil
}
// 兼容老的JSON格式
if data[0] == '{' {
return NewPartialRangesFromJSON(data)
}
// 新的格式
return NewPartialRangesFromData(data)
}
// Add 添加新范围
@@ -105,29 +155,27 @@ func (this *PartialRanges) Nearest(begin int64, end int64) (r [2]int64, ok bool)
return
}
// AsJSON 转换为JSON
func (this *PartialRanges) AsJSON() ([]byte, error) {
return json.Marshal(this)
// 转换为字符串
func (this *PartialRanges) String() string {
var s = "v:" + strconv.Itoa(this.Version) + "\n" + // version
"b:" + this.formatInt64(this.BodySize) + "\n" // bodySize
for _, r := range this.Ranges {
s += "r:" + this.formatInt64(r[0]) + "," + this.formatInt64(r[1]) + "\n" // range
}
return s
}
// Bytes 将内容转换为字节
func (this *PartialRanges) Bytes() []byte {
return []byte(this.String())
}
// WriteToFile 写入到文件中
func (this *PartialRanges) WriteToFile(path string) error {
data, err := this.AsJSON()
if err != nil {
return errors.New("convert to json failed: " + err.Error())
}
return os.WriteFile(path, data, 0666)
}
// ReadFromFile 从文件中读取
func (this *PartialRanges) ReadFromFile(path string) (*PartialRanges, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
return NewPartialRangesFromJSON(data)
return os.WriteFile(path, this.Bytes(), 0666)
}
// Max 获取最大位置
func (this *PartialRanges) Max() int64 {
if len(this.Ranges) > 0 {
return this.Ranges[len(this.Ranges)-1][1]
@@ -135,6 +183,11 @@ func (this *PartialRanges) Max() int64 {
return 0
}
// Reset 重置范围信息
func (this *PartialRanges) Reset() {
this.Ranges = [][2]int64{}
}
func (this *PartialRanges) merge(index int) {
// forward
var lastIndex = index
@@ -187,3 +240,7 @@ func (this *PartialRanges) max(n1 int64, n2 int64) int64 {
}
return n2
}
func (this *PartialRanges) formatInt64(i int64) string {
return strconv.FormatInt(i, 10)
}

View File

@@ -3,14 +3,16 @@
package caches_test
import (
"encoding/json"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/logs"
"testing"
"time"
)
func TestNewPartialRanges(t *testing.T) {
var r = caches.NewPartialRanges()
var r = caches.NewPartialRanges(0)
r.Add(1, 100)
r.Add(50, 300)
@@ -28,7 +30,7 @@ func TestNewPartialRanges(t *testing.T) {
func TestNewPartialRanges1(t *testing.T) {
var a = assert.NewAssertion(t)
var r = caches.NewPartialRanges()
var r = caches.NewPartialRanges(0)
r.Add(1, 100)
r.Add(1, 101)
r.Add(1, 102)
@@ -47,7 +49,7 @@ func TestNewPartialRanges1(t *testing.T) {
func TestNewPartialRanges2(t *testing.T) {
// low -> high
var r = caches.NewPartialRanges()
var r = caches.NewPartialRanges(0)
r.Add(1, 100)
r.Add(1, 101)
r.Add(1, 102)
@@ -63,7 +65,7 @@ func TestNewPartialRanges2(t *testing.T) {
func TestNewPartialRanges3(t *testing.T) {
// high -> low
var r = caches.NewPartialRanges()
var r = caches.NewPartialRanges(0)
r.Add(301, 302)
r.Add(303, 304)
r.Add(200, 300)
@@ -75,7 +77,7 @@ func TestNewPartialRanges3(t *testing.T) {
func TestNewPartialRanges4(t *testing.T) {
// nearby
var r = caches.NewPartialRanges()
var r = caches.NewPartialRanges(0)
r.Add(301, 302)
r.Add(303, 304)
r.Add(305, 306)
@@ -90,7 +92,7 @@ func TestNewPartialRanges4(t *testing.T) {
}
func TestNewPartialRanges5(t *testing.T) {
var r = caches.NewPartialRanges()
var r = caches.NewPartialRanges(0)
for j := 0; j < 1000; j++ {
r.Add(int64(j), int64(j+100))
}
@@ -100,7 +102,7 @@ func TestNewPartialRanges5(t *testing.T) {
func TestNewPartialRanges_Nearest(t *testing.T) {
{
// nearby
var r = caches.NewPartialRanges()
var r = caches.NewPartialRanges(0)
r.Add(301, 400)
r.Add(401, 500)
r.Add(501, 600)
@@ -112,7 +114,7 @@ func TestNewPartialRanges_Nearest(t *testing.T) {
{
// nearby
var r = caches.NewPartialRanges()
var r = caches.NewPartialRanges(0)
r.Add(301, 400)
r.Add(450, 500)
r.Add(550, 600)
@@ -131,45 +133,100 @@ func TestNewPartialRanges_Large_Range(t *testing.T) {
var largeSize int64 = 10000000000000
t.Log(largeSize/1024/1024/1024, "G")
var r = caches.NewPartialRanges()
var r = caches.NewPartialRanges(0)
r.Add(1, largeSize)
jsonData, err := r.AsJSON()
if err != nil {
t.Fatal(err)
}
t.Log(string(jsonData))
var s = r.String()
t.Log(s)
r2, err := caches.NewPartialRangesFromJSON(jsonData)
r2, err := caches.NewPartialRangesFromData([]byte(s))
if err != nil {
t.Fatal(err)
}
a.IsTrue(largeSize == r2.Ranges[0][1])
logs.PrintAsJSON(r, t)
}
func TestNewPartialRanges_AsJSON(t *testing.T) {
var r = caches.NewPartialRanges()
for j := 0; j < 1000; j++ {
r.Add(int64(j), int64(j+100))
func TestPartialRanges_Encode_JSON(t *testing.T) {
var r = caches.NewPartialRanges(0)
for i := 0; i < 10; i++ {
r.Ranges = append(r.Ranges, [2]int64{int64(i * 100), int64(i*100 + 100)})
}
data, err := r.AsJSON()
var before = time.Now()
data, err := json.Marshal(r)
if err != nil {
t.Fatal(err)
}
t.Log(string(data))
t.Log(time.Since(before).Seconds()*1000, "ms")
t.Log(len(data))
}
r2, err := caches.NewPartialRangesFromJSON(data)
func TestPartialRanges_Encode_String(t *testing.T) {
var r = caches.NewPartialRanges(0)
r.BodySize = 1024
for i := 0; i < 10; i++ {
r.Ranges = append(r.Ranges, [2]int64{int64(i * 100), int64(i*100 + 100)})
}
var before = time.Now()
var data = r.String()
t.Log(time.Since(before).Seconds()*1000, "ms")
t.Log(len(data))
r2, err := caches.NewPartialRangesFromData([]byte(data))
if err != nil {
t.Fatal(err)
}
t.Log(r2.Ranges)
logs.PrintAsJSON(r2, t)
}
func TestPartialRanges_Version(t *testing.T) {
{
ranges, err := caches.NewPartialRangesFromData([]byte(`e:1668928495
r:0,1048576
r:1140260864,1140295164`))
if err != nil {
t.Fatal(err)
}
t.Log("version:", ranges.Version)
}
{
ranges, err := caches.NewPartialRangesFromData([]byte(`e:1668928495
r:0,1048576
r:1140260864,1140295164
v:0
`))
if err != nil {
t.Fatal(err)
}
t.Log("version:", ranges.Version)
}
{
ranges, err := caches.NewPartialRangesFromJSON([]byte(`{}`))
if err != nil {
t.Fatal(err)
}
t.Log("version:", ranges.Version)
}
}
func BenchmarkNewPartialRanges(b *testing.B) {
for i := 0; i < b.N; i++ {
var r = caches.NewPartialRanges()
var r = caches.NewPartialRanges(0)
for j := 0; j < 1000; j++ {
r.Add(int64(j), int64(j+100))
}
}
}
func BenchmarkPartialRanges_String(b *testing.B) {
var r = caches.NewPartialRanges(0)
r.BodySize = 1024
for i := 0; i < 10; i++ {
r.Ranges = append(r.Ranges, [2]int64{int64(i * 100), int64(i*100 + 100)})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = r.String()
}
}

View File

@@ -42,7 +42,7 @@ func (this *FileReader) InitAutoDiscard(autoDiscard bool) error {
this.header = this.openFile.header
}
isOk := false
var isOk = false
if autoDiscard {
defer func() {
@@ -67,17 +67,17 @@ func (this *FileReader) InitAutoDiscard(autoDiscard bool) error {
this.expiresAt = int64(binary.BigEndian.Uint32(buf[:SizeExpiresAt]))
status := types.Int(string(buf[OffsetStatus : OffsetStatus+SizeStatus]))
var status = types.Int(string(buf[OffsetStatus : OffsetStatus+SizeStatus]))
if status < 100 || status > 999 {
return errors.New("invalid status")
}
this.status = status
// URL
urlLength := binary.BigEndian.Uint32(buf[OffsetURLLength : OffsetURLLength+SizeURLLength])
var urlLength = binary.BigEndian.Uint32(buf[OffsetURLLength : OffsetURLLength+SizeURLLength])
// header
headerSize := int(binary.BigEndian.Uint32(buf[OffsetHeaderLength : OffsetHeaderLength+SizeHeaderLength]))
var headerSize = int(binary.BigEndian.Uint32(buf[OffsetHeaderLength : OffsetHeaderLength+SizeHeaderLength]))
if headerSize == 0 {
return nil
}
@@ -86,7 +86,7 @@ func (this *FileReader) InitAutoDiscard(autoDiscard bool) error {
// body
this.bodyOffset = this.headerOffset + int64(headerSize)
bodySize := int(binary.BigEndian.Uint64(buf[OffsetBodyLength : OffsetBodyLength+SizeBodyLength]))
var bodySize = int(binary.BigEndian.Uint64(buf[OffsetBodyLength : OffsetBodyLength+SizeBodyLength]))
if bodySize == 0 {
isOk = true
return nil
@@ -158,7 +158,7 @@ func (this *FileReader) ReadHeader(buf []byte, callback ReaderFunc) error {
return nil
}
isOk := false
var isOk = false
defer func() {
if !isOk {
@@ -171,7 +171,7 @@ func (this *FileReader) ReadHeader(buf []byte, callback ReaderFunc) error {
return err
}
headerSize := this.headerSize
var headerSize = this.headerSize
for {
n, err := this.fp.Read(buf)
@@ -215,7 +215,11 @@ func (this *FileReader) ReadHeader(buf []byte, callback ReaderFunc) error {
}
func (this *FileReader) ReadBody(buf []byte, callback ReaderFunc) error {
isOk := false
if this.bodySize == 0 {
return nil
}
var isOk = false
defer func() {
if !isOk {
@@ -257,15 +261,22 @@ func (this *FileReader) ReadBody(buf []byte, callback ReaderFunc) error {
}
func (this *FileReader) Read(buf []byte) (n int, err error) {
if this.bodySize == 0 {
n = 0
err = io.EOF
return
}
n, err = this.fp.Read(buf)
if err != nil && err != io.EOF {
_ = this.discard()
}
return
}
func (this *FileReader) ReadBodyRange(buf []byte, start int64, end int64, callback ReaderFunc) error {
isOk := false
var isOk = false
defer func() {
if !isOk {
@@ -273,7 +284,7 @@ func (this *FileReader) ReadBodyRange(buf []byte, start int64, end int64, callba
}
}()
offset := start
var offset = start
if start < 0 {
offset = this.bodyOffset + this.bodySize + end
end = this.bodyOffset + this.bodySize - 1
@@ -296,7 +307,7 @@ func (this *FileReader) ReadBodyRange(buf []byte, start int64, end int64, callba
for {
n, err := this.fp.Read(buf)
if n > 0 {
n2 := int(end-offset) + 1
var n2 = int(end-offset) + 1
if n2 <= n {
_, e := callback(n2)
if e != nil {
@@ -344,12 +355,12 @@ func (this *FileReader) FP() *os.File {
}
func (this *FileReader) Close() error {
if this.openFileCache != nil {
if this.isClosed {
return nil
}
this.isClosed = true
if this.isClosed {
return nil
}
this.isClosed = true
if this.openFileCache != nil {
if this.openFile != nil {
this.openFileCache.Put(this.fp.Name(), this.openFile)
} else {
@@ -359,6 +370,7 @@ func (this *FileReader) Close() error {
}
return nil
}
return this.fp.Close()
}

View File

@@ -19,7 +19,7 @@ func TestFileReader(t *testing.T) {
if err != nil {
t.Fatal(err)
}
_, path := storage.keyPath("my-key")
_, path, _ := storage.keyPath("my-key")
fp, err := os.Open(path)
if err != nil {
@@ -105,7 +105,7 @@ func TestFileReader_Range(t *testing.T) {
}
_ = writer.Close()**/
_, path := storage.keyPath("my-number")
_, path, _ := storage.keyPath("my-number")
fp, err := os.Open(path)
if err != nil {

View File

@@ -117,13 +117,10 @@ func (this *PartialFileReader) ContainsRange(r rangeutils.Range) (r2 rangeutils.
r2, ok = this.ranges.Nearest(r.Start(), r.End())
if ok && this.bodySize > 0 {
// 考虑可配置
var span int64 = 512 * 1024
if this.bodySize > 1<<30 {
span = 1 << 20
}
const minSpan = 128 << 10
// 这里限制返回的最小缓存,防止因为返回的内容过小而导致请求过多
if r2.Length() < r.Length() && r2.Length() < span {
if r2.Length() < r.Length() && r2.Length() < minSpan {
ok = false
}
}
@@ -138,6 +135,10 @@ func (this *PartialFileReader) MaxLength() int64 {
return this.ranges.Max() + 1
}
func (this *PartialFileReader) Ranges() *PartialRanges {
return this.ranges
}
func (this *PartialFileReader) discard() error {
_ = os.Remove(this.rangePath)
return this.FileReader.discard()

View File

@@ -21,6 +21,7 @@ import (
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
"golang.org/x/sys/unix"
"golang.org/x/text/language"
"golang.org/x/text/message"
"math"
@@ -48,8 +49,7 @@ const (
SizeBodyLength = 8
OffsetBodyLength = OffsetHeaderLength + SizeHeaderLength
SizeMeta = SizeExpiresAt + SizeStatus + SizeURLLength + SizeHeaderLength + SizeBodyLength
OffsetKey = SizeMeta
SizeMeta = SizeExpiresAt + SizeStatus + SizeURLLength + SizeHeaderLength + SizeBodyLength
)
const (
@@ -58,6 +58,7 @@ const (
HotItemLifeSeconds int64 = 3600 // 热点数据生命周期
FileToMemoryMaxSize = 32 * sizes.M // 可以从文件写入到内存的最大文件尺寸
FileTmpSuffix = ".tmp"
MinDiskSpace = 5 << 30 // 当前磁盘最小剩余空间
)
var sharedWritingFileKeyMap = map[string]zero.Zero{} // key => bool
@@ -90,6 +91,11 @@ type FileStorage struct {
ignoreKeys *setutils.FixedSet
openFileCache *OpenFileCache
mainDir string
mainDiskIsFull bool
subDirs []*FileDir
}
func NewFileStorage(policy *serverconfigs.HTTPCachePolicy) *FileStorage {
@@ -153,6 +159,16 @@ func (this *FileStorage) UpdatePolicy(newPolicy *serverconfigs.HTTPCachePolicy)
return
}
var subDirs = []*FileDir{}
for _, subDir := range newOptions.SubDirs {
subDirs = append(subDirs, &FileDir{
Path: subDir.Path,
Capacity: subDir.Capacity,
IsFull: false,
})
}
this.checkDiskSpace()
err = newOptions.Init()
if err != nil {
remotelogs.Error("CACHE", "update policy '"+types.String(this.policy.Id)+"' failed: init options failed: "+err.Error())
@@ -215,6 +231,19 @@ func (this *FileStorage) Init() error {
this.options.Dir = filepath.Clean(this.options.Dir)
var dir = this.options.Dir
var subDirs = []*FileDir{}
for _, subDir := range this.options.SubDirs {
subDirs = append(subDirs, &FileDir{
Path: subDir.Path,
Capacity: subDir.Capacity,
IsFull: false,
})
}
this.subDirs = subDirs
if len(subDirs) > 0 {
this.checkDiskSpace()
}
if len(dir) == 0 {
return errors.New("[CACHE]cache storage dir can not be empty")
}
@@ -287,6 +316,9 @@ func (this *FileStorage) Init() error {
// open file cache
this.initOpenFileCache()
// 检查磁盘空间
this.checkDiskSpace()
return nil
}
@@ -314,7 +346,7 @@ func (this *FileStorage) openReader(key string, allowMemory bool, useStale bool,
}
}
hash, path := this.keyPath(key)
hash, path, _ := this.keyPath(key)
// 检查文件记录是否已过期
if !useStale {
@@ -382,16 +414,16 @@ func (this *FileStorage) openReader(key string, allowMemory bool, useStale bool,
}
// OpenWriter 打开缓存文件等待写入
func (this *FileStorage) OpenWriter(key string, expiresAt int64, status int, size int64, maxSize int64, isPartial bool) (Writer, error) {
return this.openWriter(key, expiresAt, status, size, maxSize, isPartial, false)
func (this *FileStorage) OpenWriter(key string, expiresAt int64, status int, headerSize int, bodySize int64, maxSize int64, isPartial bool) (Writer, error) {
return this.openWriter(key, expiresAt, status, headerSize, bodySize, maxSize, isPartial, false)
}
// OpenFlushWriter 打开从其他媒介直接刷入的写入器
func (this *FileStorage) OpenFlushWriter(key string, expiresAt int64, status int) (Writer, error) {
return this.openWriter(key, expiresAt, status, -1, -1, false, true)
func (this *FileStorage) OpenFlushWriter(key string, expiresAt int64, status int, headerSize int, bodySize int64) (Writer, error) {
return this.openWriter(key, expiresAt, status, headerSize, bodySize, -1, false, true)
}
func (this *FileStorage) openWriter(key string, expiredAt int64, status int, size int64, maxSize int64, isPartial bool, isFlushing bool) (Writer, error) {
func (this *FileStorage) openWriter(key string, expiredAt int64, status int, headerSize int, bodySize int64, maxSize int64, isPartial bool, isFlushing bool) (Writer, error) {
// 是否正在退出
if teaconst.IsQuiting {
return nil, ErrWritingUnavailable
@@ -409,8 +441,8 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, siz
maxMemorySize = maxSize
}
var memoryStorage = this.memoryStorage
if !isFlushing && !isPartial && memoryStorage != nil && ((size > 0 && size < maxMemorySize) || size < 0) {
writer, err := memoryStorage.OpenWriter(key, expiredAt, status, size, maxMemorySize, false)
if !isFlushing && !isPartial && memoryStorage != nil && ((bodySize > 0 && bodySize < maxMemorySize) || bodySize < 0) {
writer, err := memoryStorage.OpenWriter(key, expiredAt, status, headerSize, bodySize, maxMemorySize, false)
if err == nil {
return writer, nil
}
@@ -463,17 +495,9 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, siz
var hash = stringutil.Md5(key)
// TODO 可以只stat一次
var dir = this.options.Dir + "/p" + strconv.FormatInt(this.policy.Id, 10) + "/" + hash[:2] + "/" + hash[2:4]
_, err = os.Stat(dir)
if err != nil {
if !os.IsNotExist(err) {
return nil, err
}
err = os.MkdirAll(dir, 0777)
if err != nil {
return nil, err
}
dir, diskIsFull := this.subDir(hash)
if diskIsFull {
return nil, NewCapacityError("the disk is full")
}
// 检查缓存是否已经生成
@@ -520,19 +544,38 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, siz
// 从已经存储的内容中读取信息
var isNewCreated = true
var partialBodyOffset int64
var partialRanges *PartialRanges
if isPartial {
readerFp, err := os.OpenFile(tmpPath, os.O_RDONLY, 0444)
if err == nil {
var partialReader = NewPartialFileReader(readerFp)
err = partialReader.Init()
_ = partialReader.Close()
if err == nil && partialReader.bodyOffset > 0 {
isNewCreated = false
partialBodyOffset = partialReader.bodyOffset
} else {
_ = this.removeCacheFile(tmpPath)
// 数据库中是否存在
existsCacheItem, _ := this.list.Exist(hash)
if existsCacheItem {
readerFp, err := os.OpenFile(tmpPath, os.O_RDONLY, 0444)
if err == nil {
var partialReader = NewPartialFileReader(readerFp)
err = partialReader.Init()
_ = partialReader.Close()
if err == nil && partialReader.bodyOffset > 0 {
partialRanges = partialReader.Ranges()
if bodySize > 0 && partialRanges != nil && partialRanges.BodySize > 0 && bodySize != partialRanges.BodySize {
_ = this.removeCacheFile(tmpPath)
} else {
isNewCreated = false
partialBodyOffset = partialReader.bodyOffset
}
} else {
_ = this.removeCacheFile(tmpPath)
}
}
}
if isNewCreated {
err = this.list.Remove(hash)
if err != nil {
return nil, err
}
}
if partialRanges == nil {
partialRanges = NewPartialRanges(expiredAt)
}
}
var flags = os.O_CREATE | os.O_WRONLY
@@ -542,7 +585,16 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, siz
var before = time.Now()
writer, err := os.OpenFile(tmpPath, flags, 0666)
if err != nil {
return nil, err
// TODO 检查在各个系统中的稳定性
if os.IsNotExist(err) {
_ = os.MkdirAll(dir, 0777)
// open file again
writer, err = os.OpenFile(tmpPath, flags, 0666)
}
if err != nil {
return nil, err
}
}
if !isFlushing {
if time.Since(before) >= maxOpenFilesSlowCost {
@@ -574,9 +626,12 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, siz
return nil, ErrFileIsWriting
}
var metaBodySize int64 = -1
var metaHeaderSize = -1
if isNewCreated {
// 写入过期时间
var metaBytes = make([]byte, SizeMeta+len(key))
// 写入meta
// 从v0.5.8开始不再在meta中写入Key
var metaBytes = make([]byte, SizeMeta)
binary.BigEndian.PutUint32(metaBytes[OffsetExpiresAt:], uint32(expiredAt))
// 写入状态码
@@ -585,17 +640,17 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, siz
}
copy(metaBytes[OffsetStatus:], strconv.Itoa(status))
// 写入URL长度
binary.BigEndian.PutUint32(metaBytes[OffsetURLLength:], uint32(len(key)))
// 写入Header Length
binary.BigEndian.PutUint32(metaBytes[OffsetHeaderLength:], uint32(0))
if headerSize > 0 {
binary.BigEndian.PutUint32(metaBytes[OffsetHeaderLength:], uint32(headerSize))
metaHeaderSize = headerSize
}
// 写入Body Length
binary.BigEndian.PutUint64(metaBytes[OffsetBodyLength:], uint64(0))
// 写入URL
copy(metaBytes[OffsetKey:], key)
if bodySize > 0 {
binary.BigEndian.PutUint64(metaBytes[OffsetBodyLength:], uint64(bodySize))
metaBodySize = bodySize
}
_, err = writer.Write(metaBytes)
if err != nil {
@@ -605,12 +660,7 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, siz
isOk = true
if isPartial {
ranges, err := NewPartialRangesFromFile(cachePathName + "@ranges.cache")
if err != nil {
ranges = NewPartialRanges()
}
return NewPartialFileWriter(writer, key, expiredAt, isNewCreated, isPartial, partialBodyOffset, ranges, func() {
return NewPartialFileWriter(writer, key, expiredAt, metaHeaderSize, metaBodySize, isNewCreated, isPartial, partialBodyOffset, partialRanges, func() {
sharedWritingFileKeyLocker.Lock()
delete(sharedWritingFileKeyMap, key)
if len(sharedWritingFileKeyMap) == 0 {
@@ -619,7 +669,7 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, siz
sharedWritingFileKeyLocker.Unlock()
}), nil
} else {
return NewFileWriter(this, writer, key, expiredAt, -1, func() {
return NewFileWriter(this, writer, key, expiredAt, metaHeaderSize, metaBodySize, -1, func() {
sharedWritingFileKeyLocker.Lock()
delete(sharedWritingFileKeyMap, key)
if len(sharedWritingFileKeyMap) == 0 {
@@ -646,7 +696,7 @@ func (this *FileStorage) AddToList(item *Item) {
}
item.MetaSize = SizeMeta + 128
hash := stringutil.Md5(item.Key)
var hash = stringutil.Md5(item.Key)
err := this.list.Add(hash, item)
if err != nil && !strings.Contains(err.Error(), "UNIQUE constraint failed") {
remotelogs.Error("CACHE", "add to list failed: "+err.Error())
@@ -660,15 +710,12 @@ func (this *FileStorage) Delete(key string) error {
return nil
}
this.locker.Lock()
defer this.locker.Unlock()
// 先尝试内存缓存
this.runMemoryStorageSafety(func(memoryStorage *MemoryStorage) {
_ = memoryStorage.Delete(key)
})
hash, path := this.keyPath(key)
hash, path, _ := this.keyPath(key)
err := this.list.Remove(hash)
if err != nil {
return err
@@ -683,9 +730,6 @@ func (this *FileStorage) Delete(key string) error {
// Stat 统计
func (this *FileStorage) Stat() (*Stat, error) {
this.locker.RLock()
defer this.locker.RUnlock()
return this.list.Stat(func(hash string) bool {
return true
})
@@ -708,57 +752,72 @@ func (this *FileStorage) CleanAll() error {
// 删除缓存和目录
// 不能直接删除子目录,比较危险
dir := this.dir()
fp, err := os.Open(dir)
if err != nil {
return err
}
defer func() {
_ = fp.Close()
}()
stat, err := fp.Stat()
if err != nil {
return err
}
if !stat.IsDir() {
return nil
}
// 改成待删除
subDirs, err := fp.Readdir(-1)
if err != nil {
return err
}
for _, info := range subDirs {
subDir := info.Name()
// 检查目录名
ok, err := regexp.MatchString(`^[0-9a-f]{2}$`, subDir)
if err != nil {
return err
}
if !ok {
continue
var rootDirs = []string{this.options.Dir}
var subDirs = this.subDirs // copy slice
if len(subDirs) > 0 {
for _, subDir := range subDirs {
rootDirs = append(rootDirs, subDir.Path)
}
}
// 修改目录名
tmpDir := dir + "/" + subDir + "-deleted"
err = os.Rename(dir+"/"+subDir, tmpDir)
var dirNameReg = regexp.MustCompile(`^[0-9a-f]{2}$`)
for _, rootDir := range rootDirs {
var dir = rootDir + "/p" + types.String(this.policy.Id)
err = func(dir string) error {
fp, err := os.Open(dir)
if err != nil {
return err
}
defer func() {
_ = fp.Close()
}()
stat, err := fp.Stat()
if err != nil {
return err
}
if !stat.IsDir() {
return nil
}
// 改成待删除
subDirs, err := fp.Readdir(-1)
if err != nil {
return err
}
for _, info := range subDirs {
subDir := info.Name()
// 检查目录名
if !dirNameReg.MatchString(subDir) {
continue
}
// 修改目录名
tmpDir := dir + "/" + subDir + "-deleted"
err = os.Rename(dir+"/"+subDir, tmpDir)
if err != nil {
return err
}
}
// 重新遍历待删除
goman.New(func() {
err = this.cleanDeletedDirs(dir)
if err != nil {
remotelogs.Warn("CACHE", "delete '*-deleted' dirs failed: "+err.Error())
}
})
return nil
}(dir)
if err != nil {
return err
}
}
// 重新遍历待删除
goman.New(func() {
err = this.cleanDeletedDirs(dir)
if err != nil {
remotelogs.Warn("CACHE", "delete '*-deleted' dirs failed: "+err.Error())
}
})
return nil
}
@@ -769,9 +828,6 @@ func (this *FileStorage) Purge(keys []string, urlType string) error {
return nil
}
this.locker.Lock()
defer this.locker.Unlock()
// 先尝试内存缓存
this.runMemoryStorageSafety(func(memoryStorage *MemoryStorage) {
_ = memoryStorage.Purge(keys, urlType)
@@ -780,6 +836,19 @@ func (this *FileStorage) Purge(keys []string, urlType string) error {
// 目录
if urlType == "dir" {
for _, key := range keys {
// 检查是否有通配符 http(s)://*.example.com
var schemeIndex = strings.Index(key, "://")
if schemeIndex > 0 {
var keyRight = key[schemeIndex+3:]
if strings.HasPrefix(keyRight, "*.") {
err := this.list.CleanMatchPrefix(key)
if err != nil {
return err
}
continue
}
}
err := this.list.CleanPrefix(key)
if err != nil {
return err
@@ -790,7 +859,21 @@ func (this *FileStorage) Purge(keys []string, urlType string) error {
// URL
for _, key := range keys {
hash, path := this.keyPath(key)
// 检查是否有通配符 http(s)://*.example.com
var schemeIndex = strings.Index(key, "://")
if schemeIndex > 0 {
var keyRight = key[schemeIndex+3:]
if strings.HasPrefix(keyRight, "*.") {
err := this.list.CleanMatchKey(key)
if err != nil {
return err
}
continue
}
}
// 普通的Key
hash, path, _ := this.keyPath(key)
err := this.removeCacheFile(path)
if err != nil && !os.IsNotExist(err) {
return err
@@ -861,25 +944,22 @@ func (this *FileStorage) CanSendfile() bool {
return this.options.EnableSendfile
}
// 绝对路径
func (this *FileStorage) dir() string {
return this.options.Dir + "/p" + strconv.FormatInt(this.policy.Id, 10) + "/"
}
// 获取Key对应的文件路径
func (this *FileStorage) keyPath(key string) (hash string, path string) {
func (this *FileStorage) keyPath(key string) (hash string, path string, diskIsFull bool) {
hash = stringutil.Md5(key)
dir := this.options.Dir + "/p" + strconv.FormatInt(this.policy.Id, 10) + "/" + hash[:2] + "/" + hash[2:4]
var dir string
dir, diskIsFull = this.subDir(hash)
path = dir + "/" + hash + ".cache"
return
}
// 获取Hash对应的文件路径
func (this *FileStorage) hashPath(hash string) (path string) {
func (this *FileStorage) hashPath(hash string) (path string, diskIsFull bool) {
if len(hash) != 32 {
return ""
return "", false
}
dir := this.options.Dir + "/p" + strconv.FormatInt(this.policy.Id, 10) + "/" + hash[:2] + "/" + hash[2:4]
var dir string
dir, diskIsFull = this.subDir(hash)
path = dir + "/" + hash + ".cache"
return
}
@@ -937,19 +1017,39 @@ func (this *FileStorage) initList() error {
}
// 清理任务
// TODO purge每个分区
func (this *FileStorage) purgeLoop() {
// 检查磁盘剩余空间
this.checkDiskSpace()
// 计算是否应该开启LFU清理
var capacityBytes = this.policy.CapacityBytes()
var capacityBytes = this.diskCapacityBytes()
var startLFU = false
var usedPercent = float32(this.TotalDiskSize()*100) / float32(capacityBytes)
var lfuFreePercent = this.policy.PersistenceLFUFreePercent
if lfuFreePercent <= 0 {
lfuFreePercent = 5
}
if capacityBytes > 0 {
if lfuFreePercent < 100 {
if usedPercent >= 100-lfuFreePercent {
startLFU = true
var hasFullDisk = this.mainDiskIsFull
if !hasFullDisk {
var subDirs = this.subDirs // copy slice
for _, subDir := range subDirs {
if subDir.IsFull {
hasFullDisk = true
break
}
}
}
if hasFullDisk {
startLFU = true
} else {
var usedPercent = float32(this.TotalDiskSize()*100) / float32(capacityBytes)
if capacityBytes > 0 {
if lfuFreePercent < 100 {
if usedPercent >= 100-lfuFreePercent {
startLFU = true
}
}
}
}
@@ -974,7 +1074,7 @@ func (this *FileStorage) purgeLoop() {
}
for i := 0; i < times; i++ {
countFound, err := this.list.Purge(purgeCount, func(hash string) error {
path := this.hashPath(hash)
path, _ := this.hashPath(hash)
err := this.removeCacheFile(path)
if err != nil && !os.IsNotExist(err) {
remotelogs.Error("CACHE", "purge '"+path+"' error: "+err.Error())
@@ -1008,7 +1108,7 @@ func (this *FileStorage) purgeLoop() {
remotelogs.Println("CACHE", "LFU purge policy '"+this.policy.Name+"' id: "+types.String(this.policy.Id)+", count: "+types.String(count))
err := this.list.PurgeLFU(count, func(hash string) error {
path := this.hashPath(hash)
path, _ := this.hashPath(hash)
err := this.removeCacheFile(path)
if err != nil && !os.IsNotExist(err) {
remotelogs.Error("CACHE", "purge '"+path+"' error: "+err.Error())
@@ -1089,7 +1189,7 @@ func (this *FileStorage) hotLoop() {
expiresAt = bestExpiresAt
}
writer, err := memoryStorage.openWriter(item.Key, expiresAt, reader.Status(), reader.BodySize(), -1, false)
writer, err := memoryStorage.openWriter(item.Key, expiresAt, reader.Status(), types.Int(reader.HeaderSize()), reader.BodySize(), -1, false)
if err != nil {
if !CanIgnoreErr(err) {
remotelogs.Error("CACHE", "transfer hot item failed: "+err.Error())
@@ -1113,9 +1213,12 @@ func (this *FileStorage) hotLoop() {
}
err = reader.ReadBody(buf, func(n int) (goNext bool, err error) {
_, err = writer.Write(buf[:n])
if err == nil {
goNext = true
goNext = true
if n > 0 {
_, err = writer.Write(buf[:n])
if err != nil {
goNext = false
}
}
return
})
@@ -1128,6 +1231,7 @@ func (this *FileStorage) hotLoop() {
memoryStorage.AddToList(&Item{
Type: writer.ItemType(),
Key: item.Key,
Host: ParseHost(item.Key),
ExpiredAt: expiresAt,
HeaderSize: writer.HeaderSize(),
BodySize: writer.BodySize(),
@@ -1327,3 +1431,63 @@ func (this *FileStorage) runMemoryStorageSafety(f func(memoryStorage *MemoryStor
f(memoryStorage)
}
}
// 检查磁盘剩余空间
func (this *FileStorage) checkDiskSpace() {
if this.options != nil && len(this.options.Dir) > 0 {
var stat unix.Statfs_t
err := unix.Statfs(this.options.Dir, &stat)
if err == nil {
var availableBytes = stat.Bavail * uint64(stat.Bsize)
this.mainDiskIsFull = availableBytes < MinDiskSpace
}
}
var subDirs = this.subDirs // copy slice
for _, subDir := range subDirs {
var stat unix.Statfs_t
err := unix.Statfs(subDir.Path, &stat)
if err == nil {
var availableBytes = stat.Bavail * uint64(stat.Bsize)
subDir.IsFull = availableBytes < MinDiskSpace
}
}
}
// 获取目录
func (this *FileStorage) subDir(hash string) (dirPath string, dirIsFull bool) {
var suffix = "/p" + types.String(this.policy.Id) + "/" + hash[:2] + "/" + hash[2:4]
if len(hash) < 4 {
return this.options.Dir + suffix, this.mainDiskIsFull
}
var subDirs = this.subDirs // copy slice
var countSubDirs = len(subDirs)
if countSubDirs == 0 {
return this.options.Dir + suffix, this.mainDiskIsFull
}
countSubDirs++ // add main dir
// 最多只支持16个目录
if countSubDirs > 16 {
countSubDirs = 16
}
var dirIndex = this.charCode(hash[0]) % uint8(countSubDirs)
if dirIndex == 0 {
return this.options.Dir + suffix, this.mainDiskIsFull
}
var subDir = subDirs[dirIndex-1]
return subDir.Path + suffix, subDir.IsFull
}
func (this *FileStorage) charCode(r byte) uint8 {
if r >= '0' && r <= '9' {
return r - '0'
}
if r >= 'a' && r <= 'z' {
return r - 'a' + 10
}
return 0
}

View File

@@ -62,7 +62,7 @@ func TestFileStorage_OpenWriter(t *testing.T) {
header := []byte("Header")
body := []byte("This is Body")
writer, err := storage.OpenWriter("my-key", time.Now().Unix()+86400, 200, -1, -1, false)
writer, err := storage.OpenWriter("my-key", time.Now().Unix()+86400, 200, -1, -1, -1, false)
if err != nil {
t.Fatal(err)
}
@@ -100,7 +100,7 @@ func TestFileStorage_OpenWriter_Partial(t *testing.T) {
t.Fatal(err)
}
writer, err := storage.OpenWriter("my-key", time.Now().Unix()+86400, 200, -1, -1, true)
writer, err := storage.OpenWriter("my-key", time.Now().Unix()+86400, 200, -1, -1, -1, true)
if err != nil {
t.Fatal(err)
}
@@ -139,7 +139,7 @@ func TestFileStorage_OpenWriter_HTTP(t *testing.T) {
t.Log(time.Since(now).Seconds()*1000, "ms")
}()
writer, err := storage.OpenWriter("my-http-response", time.Now().Unix()+86400, 200, -1, -1, false)
writer, err := storage.OpenWriter("my-http-response", time.Now().Unix()+86400, 200, -1, -1, -1, false)
if err != nil {
t.Fatal(err)
}
@@ -212,7 +212,7 @@ func TestFileStorage_Concurrent_Open_DifferentFile(t *testing.T) {
go func(i int) {
defer wg.Done()
writer, err := storage.OpenWriter("abc"+strconv.Itoa(i), time.Now().Unix()+3600, 200, -1, -1, false)
writer, err := storage.OpenWriter("abc"+strconv.Itoa(i), time.Now().Unix()+3600, 200, -1, -1, -1, false)
if err != nil {
if err != ErrFileIsWriting {
t.Error(err)
@@ -267,7 +267,7 @@ func TestFileStorage_Concurrent_Open_SameFile(t *testing.T) {
go func(i int) {
defer wg.Done()
writer, err := storage.OpenWriter("abc"+strconv.Itoa(0), time.Now().Unix()+3600, 200, -1, -1, false)
writer, err := storage.OpenWriter("abc"+strconv.Itoa(0), time.Now().Unix()+3600, 200, -1, -1, -1, false)
if err != nil {
if err != ErrFileIsWriting {
t.Error(err)
@@ -522,7 +522,7 @@ func TestFileStorage_DecodeFile(t *testing.T) {
if err != nil {
t.Fatal(err)
}
_, path := storage.keyPath("my-key")
_, path, _ := storage.keyPath("my-key")
t.Log(path)
}
@@ -569,6 +569,6 @@ func BenchmarkFileStorage_KeyPath(b *testing.B) {
}
for i := 0; i < b.N; i++ {
_, _ = storage.keyPath(strconv.Itoa(i))
_, _, _ = storage.keyPath(strconv.Itoa(i))
}
}

View File

@@ -14,10 +14,10 @@ type StorageInterface interface {
// OpenWriter 打开缓存写入器等待写入
// size 和 maxSize 可能为-1
OpenWriter(key string, expiresAt int64, status int, size int64, maxSize int64, isPartial bool) (Writer, error)
OpenWriter(key string, expiresAt int64, status int, headerSize int, bodySize int64, maxSize int64, isPartial bool) (Writer, error)
// OpenFlushWriter 打开从其他媒介直接刷入的写入器
OpenFlushWriter(key string, expiresAt int64, status int) (Writer, error)
OpenFlushWriter(key string, expiresAt int64, status int, headerSize int, bodySize int64) (Writer, error)
// Delete 删除某个键值对应的缓存
Delete(key string) error

View File

@@ -1,7 +1,6 @@
package caches
import (
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
@@ -17,6 +16,7 @@ import (
"math"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
@@ -149,7 +149,7 @@ func (this *MemoryStorage) OpenReader(key string, useStale bool, isPartial bool)
}
// OpenWriter 打开缓存写入器等待写入
func (this *MemoryStorage) OpenWriter(key string, expiredAt int64, status int, size int64, maxSize int64, isPartial bool) (Writer, error) {
func (this *MemoryStorage) OpenWriter(key string, expiredAt int64, status int, headerSize int, bodySize int64, maxSize int64, isPartial bool) (Writer, error) {
if this.ignoreKeys.Has(key) {
return nil, ErrEntityTooLarge
}
@@ -158,15 +158,15 @@ func (this *MemoryStorage) OpenWriter(key string, expiredAt int64, status int, s
if isPartial {
return nil, ErrFileIsWriting
}
return this.openWriter(key, expiredAt, status, size, maxSize, true)
return this.openWriter(key, expiredAt, status, headerSize, bodySize, maxSize, true)
}
// OpenFlushWriter 打开从其他媒介直接刷入的写入器
func (this *MemoryStorage) OpenFlushWriter(key string, expiresAt int64, status int) (Writer, error) {
return this.openWriter(key, expiresAt, status, -1, -1, true)
func (this *MemoryStorage) OpenFlushWriter(key string, expiresAt int64, status int, headerSize int, bodySize int64) (Writer, error) {
return this.openWriter(key, expiresAt, status, headerSize, bodySize, -1, true)
}
func (this *MemoryStorage) openWriter(key string, expiresAt int64, status int, size int64, maxSize int64, isDirty bool) (Writer, error) {
func (this *MemoryStorage) openWriter(key string, expiresAt int64, status int, headerSize int, bodySize int64, maxSize int64, isDirty bool) (Writer, error) {
// 待写入队列是否已满
if isDirty &&
this.parentStorage != nil &&
@@ -207,10 +207,10 @@ func (this *MemoryStorage) openWriter(key string, expiresAt int64, status int, s
return nil, NewCapacityError("write memory cache failed: too many keys in cache storage")
}
capacityBytes := this.memoryCapacityBytes()
if size < 0 {
size = 0
if bodySize < 0 {
bodySize = 0
}
if capacityBytes > 0 && capacityBytes <= this.totalSize+size {
if capacityBytes > 0 && capacityBytes <= this.totalSize+bodySize {
return nil, NewCapacityError("write memory cache failed: over memory size: " + strconv.FormatInt(capacityBytes, 10) + ", current size: " + strconv.FormatInt(this.totalSize, 10) + " bytes")
}
@@ -230,10 +230,10 @@ func (this *MemoryStorage) openWriter(key string, expiresAt int64, status int, s
// Delete 删除某个键值对应的缓存
func (this *MemoryStorage) Delete(key string) error {
hash := this.hash(key)
var hash = this.hash(key)
this.locker.Lock()
delete(this.valuesMap, hash)
_ = this.list.Remove(fmt.Sprintf("%d", hash))
_ = this.list.Remove(types.String(hash))
this.locker.Unlock()
return nil
}
@@ -263,6 +263,19 @@ func (this *MemoryStorage) Purge(keys []string, urlType string) error {
// 目录
if urlType == "dir" {
for _, key := range keys {
// 检查是否有通配符 http(s)://*.example.com
var schemeIndex = strings.Index(key, "://")
if schemeIndex > 0 {
var keyRight = key[schemeIndex+3:]
if strings.HasPrefix(keyRight, "*.") {
err := this.list.CleanMatchPrefix(key)
if err != nil {
return err
}
continue
}
}
err := this.list.CleanPrefix(key)
if err != nil {
return err
@@ -273,6 +286,19 @@ func (this *MemoryStorage) Purge(keys []string, urlType string) error {
// URL
for _, key := range keys {
// 检查是否有通配符 http(s)://*.example.com
var schemeIndex = strings.Index(key, "://")
if schemeIndex > 0 {
var keyRight = key[schemeIndex+3:]
if strings.HasPrefix(keyRight, "*.") {
err := this.list.CleanMatchKey(key)
if err != nil {
return err
}
continue
}
}
err := this.Delete(key)
if err != nil {
return err
@@ -336,7 +362,12 @@ func (this *MemoryStorage) CanUpdatePolicy(newPolicy *serverconfigs.HTTPCachePol
// AddToList 将缓存添加到列表
func (this *MemoryStorage) AddToList(item *Item) {
item.MetaSize = int64(len(item.Key)) + 128 /** 128是我们评估的数据结构的长度 **/
hash := fmt.Sprintf("%d", this.hash(item.Key))
var hash = types.String(this.hash(item.Key))
if len(item.Host) == 0 {
item.Host = ParseHost(item.Key)
}
_ = this.list.Add(hash, item)
}
@@ -433,7 +464,7 @@ func (this *MemoryStorage) startFlush() {
var statCount = 0
var writeDelayMS float64 = 0
for hash := range this.dirtyChan {
for key := range this.dirtyChan {
statCount++
if statCount == 100 {
@@ -455,7 +486,7 @@ func (this *MemoryStorage) startFlush() {
}
}
this.flushItem(hash)
this.flushItem(key)
if writeDelayMS > 0 {
time.Sleep(time.Duration(writeDelayMS) * time.Millisecond)
@@ -481,7 +512,7 @@ func (this *MemoryStorage) flushItem(key string) {
return
}
writer, err := this.parentStorage.OpenFlushWriter(key, item.ExpiresAt, item.Status)
writer, err := this.parentStorage.OpenFlushWriter(key, item.ExpiresAt, item.Status, len(item.HeaderValue), int64(len(item.BodyValue)))
if err != nil {
if !CanIgnoreErr(err) {
remotelogs.Error("CACHE", "flush items failed: open writer failed: "+err.Error())
@@ -513,6 +544,7 @@ func (this *MemoryStorage) flushItem(key string) {
this.parentStorage.AddToList(&Item{
Type: writer.ItemType(),
Key: key,
Host: ParseHost(key),
ExpiredAt: item.ExpiresAt,
HeaderSize: writer.HeaderSize(),
BodySize: writer.BodySize(),
@@ -542,7 +574,7 @@ func (this *MemoryStorage) memoryCapacityBytes() int64 {
func (this *MemoryStorage) deleteWithoutLocker(key string) error {
hash := this.hash(key)
delete(this.valuesMap, hash)
_ = this.list.Remove(fmt.Sprintf("%d", hash))
_ = this.list.Remove(types.String(hash))
return nil
}

View File

@@ -14,15 +14,22 @@ import (
)
func TestMemoryStorage_OpenWriter(t *testing.T) {
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
var storage = NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
writer, err := storage.OpenWriter("abc", time.Now().Unix()+60, 200, -1, -1, false)
writer, err := storage.OpenWriter("abc", time.Now().Unix()+60, 200, -1, -1, -1, false)
if err != nil {
t.Fatal(err)
}
if err != nil {
t.Fatal(err)
}
_, _ = writer.WriteHeader([]byte("Header"))
_, _ = writer.Write([]byte("Hello"))
_, _ = writer.Write([]byte(", World"))
err = writer.Close()
if err != nil {
t.Fatal(err)
}
t.Log(storage.valuesMap)
{
@@ -30,6 +37,7 @@ func TestMemoryStorage_OpenWriter(t *testing.T) {
if err != nil {
if err == ErrNotFound {
t.Log("not found: abc")
return
} else {
t.Fatal(err)
}
@@ -63,7 +71,7 @@ func TestMemoryStorage_OpenWriter(t *testing.T) {
}
}
writer, err = storage.OpenWriter("abc", time.Now().Unix()+60, 200, -1, -1, false)
writer, err = storage.OpenWriter("abc", time.Now().Unix()+60, 200, -1, -1, -1, false)
if err != nil {
t.Fatal(err)
}
@@ -102,21 +110,29 @@ func TestMemoryStorage_OpenReaderLock(t *testing.T) {
}
func TestMemoryStorage_Delete(t *testing.T) {
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
var storage = NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
{
writer, err := storage.OpenWriter("abc", time.Now().Unix()+60, 200, -1, -1, false)
writer, err := storage.OpenWriter("abc", time.Now().Unix()+60, 200, -1, -1, -1, false)
if err != nil {
t.Fatal(err)
}
_, _ = writer.Write([]byte("Hello"))
err = writer.Close()
if err != nil {
t.Fatal(err)
}
t.Log(len(storage.valuesMap))
}
{
writer, err := storage.OpenWriter("abc1", time.Now().Unix()+60, 200, -1, -1, false)
writer, err := storage.OpenWriter("abc1", time.Now().Unix()+60, 200, -1, -1, -1, false)
if err != nil {
t.Fatal(err)
}
_, _ = writer.Write([]byte("Hello"))
err = writer.Close()
if err != nil {
t.Fatal(err)
}
t.Log(len(storage.valuesMap))
}
_ = storage.Delete("abc1")
@@ -124,14 +140,18 @@ func TestMemoryStorage_Delete(t *testing.T) {
}
func TestMemoryStorage_Stat(t *testing.T) {
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
var storage = NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
expiredAt := time.Now().Unix() + 60
{
writer, err := storage.OpenWriter("abc", expiredAt, 200, -1, -1, false)
writer, err := storage.OpenWriter("abc", expiredAt, 200, -1, -1, -1, false)
if err != nil {
t.Fatal(err)
}
_, _ = writer.Write([]byte("Hello"))
err = writer.Close()
if err != nil {
t.Fatal(err)
}
t.Log(len(storage.valuesMap))
storage.AddToList(&Item{
Key: "abc",
@@ -140,11 +160,15 @@ func TestMemoryStorage_Stat(t *testing.T) {
})
}
{
writer, err := storage.OpenWriter("abc1", expiredAt, 200, -1, -1, false)
writer, err := storage.OpenWriter("abc1", expiredAt, 200, -1, -1, -1, false)
if err != nil {
t.Fatal(err)
}
_, _ = writer.Write([]byte("Hello"))
err = writer.Close()
if err != nil {
t.Fatal(err)
}
t.Log(len(storage.valuesMap))
storage.AddToList(&Item{
Key: "abc1",
@@ -161,14 +185,18 @@ func TestMemoryStorage_Stat(t *testing.T) {
}
func TestMemoryStorage_CleanAll(t *testing.T) {
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
expiredAt := time.Now().Unix() + 60
var storage = NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
var expiredAt = time.Now().Unix() + 60
{
writer, err := storage.OpenWriter("abc", expiredAt, 200, -1, -1, false)
writer, err := storage.OpenWriter("abc", expiredAt, 200, -1, -1, -1, false)
if err != nil {
t.Fatal(err)
}
_, _ = writer.Write([]byte("Hello"))
err = writer.Close()
if err != nil {
t.Fatal(err)
}
storage.AddToList(&Item{
Key: "abc",
BodySize: 5,
@@ -176,11 +204,15 @@ func TestMemoryStorage_CleanAll(t *testing.T) {
})
}
{
writer, err := storage.OpenWriter("abc1", expiredAt, 200, -1, -1, false)
writer, err := storage.OpenWriter("abc1", expiredAt, 200, -1, -1, -1, false)
if err != nil {
t.Fatal(err)
}
_, _ = writer.Write([]byte("Hello"))
err = writer.Close()
if err != nil {
t.Fatal(err)
}
storage.AddToList(&Item{
Key: "abc1",
BodySize: 5,
@@ -199,11 +231,15 @@ func TestMemoryStorage_Purge(t *testing.T) {
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
expiredAt := time.Now().Unix() + 60
{
writer, err := storage.OpenWriter("abc", expiredAt, 200, -1, -1, false)
writer, err := storage.OpenWriter("abc", expiredAt, 200, -1, -1, -1, false)
if err != nil {
t.Fatal(err)
}
_, _ = writer.Write([]byte("Hello"))
err = writer.Close()
if err != nil {
t.Fatal(err)
}
storage.AddToList(&Item{
Key: "abc",
BodySize: 5,
@@ -211,11 +247,15 @@ func TestMemoryStorage_Purge(t *testing.T) {
})
}
{
writer, err := storage.OpenWriter("abc1", expiredAt, 200, -1, -1, false)
writer, err := storage.OpenWriter("abc1", expiredAt, 200, -1, -1, -1, false)
if err != nil {
t.Fatal(err)
}
_, _ = writer.Write([]byte("Hello"))
err = writer.Close()
if err != nil {
t.Fatal(err)
}
storage.AddToList(&Item{
Key: "abc1",
BodySize: 5,
@@ -231,7 +271,7 @@ func TestMemoryStorage_Purge(t *testing.T) {
}
func TestMemoryStorage_Expire(t *testing.T) {
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{
var storage = NewMemoryStorage(&serverconfigs.HTTPCachePolicy{
MemoryAutoPurgeInterval: 5,
}, nil)
err := storage.Init()
@@ -242,11 +282,15 @@ func TestMemoryStorage_Expire(t *testing.T) {
for i := 0; i < 1000; i++ {
expiredAt := time.Now().Unix() + int64(rands.Int(0, 60))
key := "abc" + strconv.Itoa(i)
writer, err := storage.OpenWriter(key, expiredAt, 200, -1, -1, false)
writer, err := storage.OpenWriter(key, expiredAt, 200, -1, -1, -1, false)
if err != nil {
t.Fatal(err)
}
_, _ = writer.Write([]byte("Hello"))
err = writer.Close()
if err != nil {
t.Fatal(err)
}
storage.AddToList(&Item{
Key: key,
BodySize: 5,
@@ -257,7 +301,7 @@ func TestMemoryStorage_Expire(t *testing.T) {
}
func TestMemoryStorage_Locker(t *testing.T) {
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
var storage = NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
err := storage.Init()
if err != nil {
t.Fatal(err)

30
internal/caches/utils.go Normal file
View File

@@ -0,0 +1,30 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package caches
import (
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"net"
"strings"
)
func ParseHost(key string) string {
var schemeIndex = strings.Index(key, "://")
if schemeIndex <= 0 {
return ""
}
var firstSlashIndex = strings.Index(key[schemeIndex+3:], "/")
if firstSlashIndex <= 0 {
return ""
}
var host = key[schemeIndex+3 : schemeIndex+3+firstSlashIndex]
hostPart, _, err := net.SplitHostPort(host)
if err == nil && len(hostPart) > 0 {
host = configutils.QuoteIP(hostPart)
}
return host
}

View File

@@ -0,0 +1,51 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package caches_test
import (
"fmt"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/cespare/xxhash"
"github.com/iwind/TeaGo/types"
"strconv"
"testing"
)
func TestParseHost(t *testing.T) {
for _, u := range []string{
"https://goedge.cn/hello/world",
"https://goedge.cn:8080/hello/world",
"https://goedge.cn/hello/world?v=1&t=123",
"https://[::1]:1234/hello/world?v=1&t=123",
"https://[::1]/hello/world?v=1&t=123",
"https://127.0.0.1/hello/world?v=1&t=123",
"https:/hello/world?v=1&t=123",
"123456",
} {
t.Log(u, "=>", caches.ParseHost(u))
}
}
func TestUintString(t *testing.T) {
t.Log(strconv.FormatUint(xxhash.Sum64String("https://goedge.cn/"), 10))
t.Log(strconv.FormatUint(123456789, 10))
t.Log(fmt.Sprintf("%d", 1234567890123))
}
func BenchmarkUint_String(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = strconv.FormatUint(1234567890123, 10)
}
}
func BenchmarkUint_String2(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = types.String(1234567890123)
}
}
func BenchmarkUint_String3(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = fmt.Sprintf("%d", 1234567890123)
}
}

View File

@@ -11,25 +11,32 @@ import (
)
type FileWriter struct {
storage StorageInterface
rawWriter *os.File
key string
headerSize int64
bodySize int64
expiredAt int64
maxSize int64
endFunc func()
once sync.Once
storage StorageInterface
rawWriter *os.File
key string
metaHeaderSize int
headerSize int64
metaBodySize int64 // 写入前的内容长度
bodySize int64
expiredAt int64
maxSize int64
endFunc func()
once sync.Once
}
func NewFileWriter(storage StorageInterface, rawWriter *os.File, key string, expiredAt int64, maxSize int64, endFunc func()) *FileWriter {
func NewFileWriter(storage StorageInterface, rawWriter *os.File, key string, expiredAt int64, metaHeaderSize int, metaBodySize int64, maxSize int64, endFunc func()) *FileWriter {
return &FileWriter{
storage: storage,
key: key,
rawWriter: rawWriter,
expiredAt: expiredAt,
maxSize: maxSize,
endFunc: endFunc,
storage: storage,
key: key,
rawWriter: rawWriter,
expiredAt: expiredAt,
maxSize: maxSize,
endFunc: endFunc,
metaHeaderSize: metaHeaderSize,
metaBodySize: metaBodySize,
}
}
@@ -45,7 +52,10 @@ func (this *FileWriter) WriteHeader(data []byte) (n int, err error) {
// WriteHeaderLength 写入Header长度数据
func (this *FileWriter) WriteHeaderLength(headerLength int) error {
bytes4 := make([]byte, 4)
if this.metaHeaderSize > 0 && this.metaHeaderSize == headerLength {
return nil
}
var bytes4 = make([]byte, 4)
binary.BigEndian.PutUint32(bytes4, uint32(headerLength))
_, err := this.rawWriter.Seek(SizeExpiresAt+SizeStatus+SizeURLLength, io.SeekStart)
if err != nil {
@@ -88,7 +98,10 @@ func (this *FileWriter) WriteAt(offset int64, data []byte) error {
// WriteBodyLength 写入Body长度数据
func (this *FileWriter) WriteBodyLength(bodyLength int64) error {
bytes8 := make([]byte, 8)
if this.metaBodySize >= 0 && bodyLength == this.metaBodySize {
return nil
}
var bytes8 = make([]byte, 8)
binary.BigEndian.PutUint64(bytes8, uint64(bodyLength))
_, err := this.rawWriter.Seek(SizeExpiresAt+SizeStatus+SizeURLLength+SizeHeaderLength, io.SeekStart)
if err != nil {
@@ -109,7 +122,7 @@ func (this *FileWriter) Close() error {
this.endFunc()
})
path := this.rawWriter.Name()
var path = this.rawWriter.Name()
err := this.WriteHeaderLength(types.Int(this.headerSize))
if err != nil {

View File

@@ -11,13 +11,18 @@ import (
)
type PartialFileWriter struct {
rawWriter *os.File
key string
headerSize int64
bodySize int64
expiredAt int64
endFunc func()
once sync.Once
rawWriter *os.File
key string
metaHeaderSize int
headerSize int64
metaBodySize int64
bodySize int64
expiredAt int64
endFunc func()
once sync.Once
isNew bool
isPartial bool
@@ -27,17 +32,19 @@ type PartialFileWriter struct {
rangePath string
}
func NewPartialFileWriter(rawWriter *os.File, key string, expiredAt int64, isNew bool, isPartial bool, bodyOffset int64, ranges *PartialRanges, endFunc func()) *PartialFileWriter {
func NewPartialFileWriter(rawWriter *os.File, key string, expiredAt int64, metaHeaderSize int, metaBodySize int64, isNew bool, isPartial bool, bodyOffset int64, ranges *PartialRanges, endFunc func()) *PartialFileWriter {
return &PartialFileWriter{
key: key,
rawWriter: rawWriter,
expiredAt: expiredAt,
endFunc: endFunc,
isNew: isNew,
isPartial: isPartial,
bodyOffset: bodyOffset,
ranges: ranges,
rangePath: partialRangesFilePath(rawWriter.Name()),
key: key,
rawWriter: rawWriter,
expiredAt: expiredAt,
endFunc: endFunc,
isNew: isNew,
isPartial: isPartial,
bodyOffset: bodyOffset,
ranges: ranges,
rangePath: partialRangesFilePath(rawWriter.Name()),
metaHeaderSize: metaHeaderSize,
metaBodySize: metaBodySize,
}
}
@@ -71,7 +78,11 @@ func (this *PartialFileWriter) AppendHeader(data []byte) error {
// WriteHeaderLength 写入Header长度数据
func (this *PartialFileWriter) WriteHeaderLength(headerLength int) error {
bytes4 := make([]byte, 4)
if this.metaHeaderSize > 0 && this.metaHeaderSize == headerLength {
return nil
}
var bytes4 = make([]byte, 4)
binary.BigEndian.PutUint32(bytes4, uint32(headerLength))
_, err := this.rawWriter.Seek(SizeExpiresAt+SizeStatus+SizeURLLength, io.SeekStart)
if err != nil {
@@ -110,8 +121,13 @@ func (this *PartialFileWriter) WriteAt(offset int64, data []byte) error {
}
if this.bodyOffset == 0 {
this.bodyOffset = SizeMeta + int64(len(this.key)) + this.headerSize
var keyLength = 0
if this.ranges.Version == 0 { // 以往的版本包含有Key
keyLength = len(this.key)
}
this.bodyOffset = SizeMeta + int64(keyLength) + this.headerSize
}
_, err := this.rawWriter.WriteAt(data, this.bodyOffset+offset)
if err != nil {
return err
@@ -129,7 +145,10 @@ func (this *PartialFileWriter) SetBodyLength(bodyLength int64) {
// WriteBodyLength 写入Body长度数据
func (this *PartialFileWriter) WriteBodyLength(bodyLength int64) error {
bytes8 := make([]byte, 8)
if this.metaBodySize > 0 && this.metaBodySize == bodyLength {
return nil
}
var bytes8 = make([]byte, 8)
binary.BigEndian.PutUint64(bytes8, uint64(bodyLength))
_, err := this.rawWriter.Seek(SizeExpiresAt+SizeStatus+SizeURLLength+SizeHeaderLength, io.SeekStart)
if err != nil {
@@ -150,8 +169,11 @@ func (this *PartialFileWriter) Close() error {
this.endFunc()
})
this.ranges.BodySize = this.bodySize
err := this.ranges.WriteToFile(this.rangePath)
if err != nil {
_ = this.rawWriter.Close()
this.remove()
return err
}

View File

@@ -26,8 +26,8 @@ func TestPartialFileWriter_Write(t *testing.T) {
if err != nil {
t.Fatal(err)
}
var ranges = caches.NewPartialRanges()
var writer = caches.NewPartialFileWriter(fp, "test", time.Now().Unix()+86500, true, true, 0, ranges, func() {
var ranges = caches.NewPartialRanges(0)
var writer = caches.NewPartialFileWriter(fp, "test", time.Now().Unix()+86500, -1, -1, true, true, 0, ranges, func() {
t.Log("end")
})
_, err = writer.WriteHeader([]byte("header"))

View File

@@ -9,11 +9,15 @@ import (
// APIConfig 节点API配置
type APIConfig struct {
RPC struct {
Endpoints []string `yaml:"endpoints"`
DisableUpdate bool `yaml:"disableUpdate"`
} `yaml:"rpc"`
NodeId string `yaml:"nodeId"`
Secret string `yaml:"secret"`
Endpoints []string `yaml:"endpoints" json:"endpoints"`
DisableUpdate bool `yaml:"disableUpdate" json:"disableUpdate"`
} `yaml:"rpc" json:"rpc"`
NodeId string `yaml:"nodeId" json:"nodeId"`
Secret string `yaml:"secret" json:"secret"`
}
func NewAPIConfig() *APIConfig {
return &APIConfig{}
}
func LoadAPIConfig() (*APIConfig, error) {

View File

@@ -3,9 +3,9 @@ package configs
// ClusterConfig 集群配置
type ClusterConfig struct {
RPC struct {
Endpoints []string `yaml:"endpoints"`
DisableUpdate bool `yaml:"disableUpdate"`
} `yaml:"rpc"`
ClusterId string `yaml:"clusterId"`
Secret string `yaml:"secret"`
Endpoints []string `yaml:"endpoints" json:"endpoints"`
DisableUpdate bool `yaml:"disableUpdate" json:"disableUpdate"`
} `yaml:"rpc" json:"rpc"`
ClusterId string `yaml:"clusterId" json:"clusterId"`
Secret string `yaml:"secret" json:"secret"`
}

7
internal/conns/linger.go Normal file
View File

@@ -0,0 +1,7 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package conns
type LingerConn interface {
SetLinger(sec int) error
}

View File

@@ -10,7 +10,7 @@ import (
var SharedMap = NewMap()
type Map struct {
m map[string]map[int]net.Conn // ip => { port => Conn }
m map[string]map[int]net.Conn // ip => { port => Conn }
locker sync.RWMutex
}
@@ -37,9 +37,7 @@ func (this *Map) Add(conn net.Conn) {
defer this.locker.Unlock()
connMap, ok := this.m[ip]
if !ok {
this.m[ip] = map[int]net.Conn{
port: conn,
}
this.m[ip] = map[int]net.Conn{port: conn}
} else {
connMap[port] = conn
}
@@ -96,6 +94,13 @@ func (this *Map) CloseIPConns(ip string) {
if ok {
for _, conn := range conns {
// 设置Linger
lingerConn, isLingerConn := conn.(LingerConn)
if isLingerConn {
_ = lingerConn.SetLinger(0)
}
// 关闭
_ = conn.Close()
}
@@ -109,9 +114,10 @@ func (this *Map) AllConns() []net.Conn {
var result = []net.Conn{}
for _, m := range this.m {
for _, conn := range m {
result = append(result, conn)
for _, connInfo := range m {
result = append(result, connInfo)
}
}
return result
}

View File

@@ -1,7 +1,7 @@
package teaconst
const (
Version = "0.5.5"
Version = "0.6.2"
ProductName = "Edge Node"
ProcessName = "edge-node"

View File

@@ -70,7 +70,7 @@ func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) e
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig != nil {
var allowIPList = nodeConfig.AllowedIPs
if !utils.ContainsSameStrings(allowIPList, this.lastAllowIPList) {
if !utils.EqualStrings(allowIPList, this.lastAllowIPList) {
allowIPListChanged = true
this.lastAllowIPList = allowIPList
}
@@ -91,6 +91,9 @@ func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) e
}
if nftablesInstance == nil {
if config == nil || !config.IsOn() {
return nil
}
return errors.New("nftables instance should not be nil")
}

View File

@@ -4,6 +4,7 @@ import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"os/exec"
"runtime"
@@ -74,10 +75,16 @@ func (this *IPTablesAction) runAction(action string, listType IPListType, item *
}
func (this *IPTablesAction) runActionSingleIP(action string, listType IPListType, item *pb.IPItem) error {
// 暂时不支持ipv6
// TODO 将来支持ipv6
if utils.IsIPv6(item.IpFrom) {
return nil
}
if item.Type == "all" {
return nil
}
path := this.config.Path
var path = this.config.Path
var err error
if len(path) == 0 {
path, err = exec.LookPath("iptables")
@@ -88,6 +95,7 @@ func (this *IPTablesAction) runActionSingleIP(action string, listType IPListType
this.iptablesNotFound = true
return err
}
this.config.Path = path
}
iptablesAction := ""
switch action {

View File

@@ -1,6 +1,7 @@
package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/events"
@@ -141,6 +142,12 @@ func (this *IPListManager) init() {
}
func (this *IPListManager) loop() error {
// 是否同步IP名单
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig != nil && !nodeConfig.EnableIPLists {
return nil
}
for {
hasNext, err := this.fetch()
if err != nil {

View File

@@ -178,7 +178,7 @@ func (this *APIStream) handleWriteCache(message *pb.NodeStreamMessage) error {
}
expiredAt := time.Now().Unix() + msg.LifeSeconds
writer, err := storage.OpenWriter(msg.Key, expiredAt, 200, int64(len(msg.Value)), -1, false)
writer, err := storage.OpenWriter(msg.Key, expiredAt, 200, -1, int64(len(msg.Value)), -1, false)
if err != nil {
this.replyFail(message.RequestId, "prepare writing failed: "+err.Error())
return err
@@ -407,7 +407,7 @@ func (this *APIStream) handleCheckLocalFirewall(message *pb.NodeStreamMessage) e
var protectionConfig = sharedNodeConfig.DDoSProtection
err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig)
if err != nil {
this.replyFail(message.RequestId, dataMessage.Name+"was installed, but apply DDoS protection config failed: "+err.Error())
this.replyFail(message.RequestId, dataMessage.Name+" was installed, but apply DDoS protection config failed: "+err.Error())
} else {
this.replyOk(message.RequestId, string(result.AsJSON()))
}

View File

@@ -3,6 +3,7 @@
package nodes
import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/conns"
@@ -25,17 +26,30 @@ import (
type ClientConn struct {
BaseClientConn
isTLS bool
hasDeadline bool
hasRead bool
createdAt int64
isTLS bool
isHTTP bool
hasRead bool
isLO bool // 是否为环路
isInAllowList bool
hasResetSYNFlood bool
lastReadAt int64
lastWriteAt int64
lastErr error
readDeadlineTime int64
isShortReading bool // reading header or tls handshake
isDebugging bool
autoReadTimeout bool
autoWriteTimeout bool
}
func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList bool) net.Conn {
func NewClientConn(rawConn net.Conn, isHTTP bool, isTLS bool, isInAllowList bool) net.Conn {
// 是否为环路
var remoteAddr = rawConn.RemoteAddr().String()
var isLO = strings.HasPrefix(remoteAddr, "127.0.0.1:") || strings.HasPrefix(remoteAddr, "[::1]:")
@@ -43,11 +57,21 @@ func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList
var conn = &ClientConn{
BaseClientConn: BaseClientConn{rawConn: rawConn},
isTLS: isTLS,
isHTTP: isHTTP,
isLO: isLO,
isInAllowList: isInAllowList,
createdAt: time.Now().Unix(),
}
if quickClose {
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
if globalServerConfig != nil {
var performanceConfig = globalServerConfig.Performance
conn.isDebugging = performanceConfig.Debug
conn.autoReadTimeout = performanceConfig.AutoReadTimeout
conn.autoWriteTimeout = performanceConfig.AutoWriteTimeout
}
if isHTTP {
// TODO 可以在配置中设置此值
_ = conn.SetLinger(nodeconfigs.DefaultTCPLinger)
}
@@ -59,6 +83,16 @@ func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList
}
func (this *ClientConn) Read(b []byte) (n int, err error) {
if this.isDebugging {
this.lastReadAt = time.Now().Unix()
defer func() {
if err != nil {
this.lastErr = errors.New("read error: " + err.Error())
}
}()
}
// 环路直接读取
if this.isLO {
n, err = this.rawConn.Read(b)
@@ -68,34 +102,29 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
return
}
// TLS
if this.isTLS {
if !this.hasDeadline {
_ = this.rawConn.SetReadDeadline(time.Now().Add(time.Duration(nodeconfigs.DefaultTLSHandshakeTimeout) * time.Second)) // TODO 握手超时时间可以设置
this.hasDeadline = true
defer func() {
_ = this.rawConn.SetReadDeadline(time.Time{})
}()
}
// 设置读超时时间
if this.isHTTP && !this.isWebsocket && !this.isShortReading && this.autoReadTimeout {
this.setHTTPReadTimeout()
}
// 开始读取
n, err = this.rawConn.Read(b)
if n > 0 {
atomic.AddUint64(&teaconst.InTrafficBytes, uint64(n))
if !this.hasRead {
this.hasRead = true
}
this.hasRead = true
}
// 检测是否为握手错误
var isHandshakeError = err != nil && os.IsTimeout(err) && !this.hasRead
if isHandshakeError {
// 检测是否为超时错误
var isTimeout = err != nil && os.IsTimeout(err)
var isHandshakeError = isTimeout && !this.hasRead
if isTimeout {
_ = this.SetLinger(0)
} else {
_ = this.SetLinger(nodeconfigs.DefaultTCPLinger)
}
// 忽略白名单和局域网
if !this.isInAllowList && !utils.IsLocalIP(this.RawIP()) {
if this.isHTTP && !this.isInAllowList && !utils.IsLocalIP(this.RawIP()) {
// SYN Flood检测
if this.serverId == 0 || !this.hasResetSYNFlood {
var synFloodConfig = sharedNodeConfig.SYNFloodConfig()
@@ -114,6 +143,32 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
}
func (this *ClientConn) Write(b []byte) (n int, err error) {
if this.isDebugging {
this.lastWriteAt = time.Now().Unix()
defer func() {
if err != nil {
this.lastErr = errors.New("write error: " + err.Error())
}
}()
}
// 设置写超时时间
if this.autoWriteTimeout {
// TODO L2 -> L1 写入时不限制时间
var timeoutSeconds = len(b) / 1024
if timeoutSeconds < 3 {
timeoutSeconds = 3
}
_ = this.rawConn.SetWriteDeadline(time.Now().Add(time.Duration(timeoutSeconds) * time.Second)) // TODO 时间可以设置
}
// 延长读超时时间
if this.isHTTP && !this.isWebsocket && this.autoReadTimeout {
this.setHTTPReadTimeout()
}
// 开始写入
n, err = this.rawConn.Write(b)
if n > 0 {
// 统计当前服务带宽
@@ -125,6 +180,17 @@ func (this *ClientConn) Write(b []byte) (n int, err error) {
}
}
// 如果是写入超时,则立即关闭连接
if err != nil && os.IsTimeout(err) {
// TODO 考虑对多次慢连接的IP做出惩罚
conn, ok := this.rawConn.(LingerConn)
if ok {
_ = conn.SetLinger(0)
}
_ = this.Close()
}
return
}
@@ -156,6 +222,26 @@ func (this *ClientConn) SetDeadline(t time.Time) error {
}
func (this *ClientConn) SetReadDeadline(t time.Time) error {
// 如果开启了HTTP自动读超时选项则自动控制超时时间
if this.isHTTP && !this.isWebsocket && this.autoReadTimeout {
this.isShortReading = false
var unixTime = t.Unix()
if unixTime < 10 {
return nil
}
if unixTime == this.readDeadlineTime {
return nil
}
this.readDeadlineTime = unixTime
var seconds = -time.Since(t)
if seconds <= 0 || seconds > HTTPIdleTimeout {
return nil
}
if seconds < HTTPIdleTimeout-1*time.Second {
this.isShortReading = true
}
}
return this.rawConn.SetReadDeadline(t)
}
@@ -163,6 +249,22 @@ func (this *ClientConn) SetWriteDeadline(t time.Time) error {
return this.rawConn.SetWriteDeadline(t)
}
func (this *ClientConn) CreatedAt() int64 {
return this.createdAt
}
func (this *ClientConn) LastReadAt() int64 {
return this.lastReadAt
}
func (this *ClientConn) LastWriteAt() int64 {
return this.lastWriteAt
}
func (this *ClientConn) LastErr() error {
return this.lastErr
}
func (this *ClientConn) resetSYNFlood() {
ttlcache.SharedCache.Delete("SYN_FLOOD:" + this.RawIP())
}
@@ -194,3 +296,8 @@ func (this *ClientConn) increaseSYNFlood(synFloodConfig *firewallconfigs.SYNFloo
}
}
}
// 设置读超时时间
func (this *ClientConn) setHTTPReadTimeout() {
_ = this.SetReadDeadline(time.Now().Add(HTTPIdleTimeout))
}

View File

@@ -16,6 +16,8 @@ type BaseClientConn struct {
remoteAddr string
hasLimit bool
isWebsocket bool
isClosed bool
rawIP string
@@ -122,3 +124,7 @@ func (this *BaseClientConn) SetLinger(seconds int) error {
}
return nil
}
func (this *BaseClientConn) SetIsWebsocket(isWebsocket bool) {
this.isWebsocket = isWebsocket
}

View File

@@ -23,4 +23,7 @@ type ClientConnInterface interface {
// UserId 获取当前连接所属服务的用户ID
UserId() int64
// SetIsWebsocket 设置是否为Websocket
SetIsWebsocket(isWebsocket bool)
}

View File

@@ -14,14 +14,14 @@ import (
// ClientListener 客户端网络监听
type ClientListener struct {
rawListener net.Listener
isHTTP bool
isTLS bool
quickClose bool
}
func NewClientListener(listener net.Listener, quickClose bool) *ClientListener {
func NewClientListener(listener net.Listener, isHTTP bool) *ClientListener {
return &ClientListener{
rawListener: listener,
quickClose: quickClose,
isHTTP: isHTTP,
}
}
@@ -78,7 +78,7 @@ func (this *ClientListener) Accept() (net.Conn, error) {
}
}
return NewClientConn(conn, this.isTLS, this.quickClose, isInAllowList), nil
return NewClientConn(conn, this.isHTTP, this.isTLS, isInAllowList), nil
}
func (this *ClientListener) Close() error {

View File

@@ -55,3 +55,16 @@ func (this *ClientTLSConn) SetReadDeadline(t time.Time) error {
func (this *ClientTLSConn) SetWriteDeadline(t time.Time) error {
return this.rawConn.SetWriteDeadline(t)
}
func (this *ClientTLSConn) SetIsWebsocket(isWebsocket bool) {
tlsConn, ok := this.rawConn.(*tls.Conn)
if ok {
var rawConn = tlsConn.NetConn()
if rawConn != nil {
clientConn, ok := rawConn.(*ClientConn)
if ok {
clientConn.SetIsWebsocket(isWebsocket)
}
}
}
}

View File

@@ -43,7 +43,11 @@ func (this *HTTPAccessLogQueue) Start() {
for range ticker.C {
err := this.loop()
if err != nil {
remotelogs.Error("ACCESS_LOG_QUEUE", err.Error())
if rpc.IsConnError(err) {
remotelogs.Debug("ACCESS_LOG_QUEUE", err.Error())
} else {
remotelogs.Error("ACCESS_LOG_QUEUE", err.Error())
}
}
}
}

View File

@@ -95,11 +95,11 @@ func (this *HTTPClientPool) Client(req *HTTPRequest,
numberCPU = 8
}
if maxConnections <= 0 {
maxConnections = numberCPU * 32
maxConnections = numberCPU * 64
}
if idleConns <= 0 {
idleConns = numberCPU * 8
idleConns = numberCPU * 16
}
// 可以判断为Ln节点请求

View File

@@ -237,6 +237,14 @@ func (this *HTTPRequest) Do() {
}
}
// UA名单
if !this.isSubRequest && this.web.UserAgent != nil && this.web.UserAgent.IsOn {
if this.doCheckUserAgent() {
this.doEnd()
return
}
}
// 访问控制
if !this.isSubRequest && this.web.Auth != nil && this.web.Auth.IsOn {
if this.doAuth() {
@@ -526,6 +534,11 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
this.web.Referers = web.Referers
}
// user agent
if web.UserAgent != nil && (web.UserAgent.IsPrior || isTop) {
this.web.UserAgent = web.UserAgent
}
// request limit
if web.RequestLimit != nil && (web.RequestLimit.IsPrior || isTop) {
this.web.RequestLimit = web.RequestLimit
@@ -758,6 +771,8 @@ func (this *HTTPRequest) Format(source string) string {
return strconv.FormatInt(this.requestFromTime.Unix(), 10)
case "host":
return this.ReqHost
case "cname":
return this.ReqServer.CNameDomain
case "referer":
return this.RawReq.Referer()
case "referer.host":
@@ -1131,6 +1146,8 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string {
// 获取请求的客户端地址列表
func (this *HTTPRequest) requestRemoteAddrs() (result []string) {
result = append(result, this.requestRemoteAddr(true))
// X-Forwarded-For
var forwardedFor = this.RawReq.Header.Get("X-Forwarded-For")
if len(forwardedFor) > 0 {
@@ -1552,7 +1569,7 @@ func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) {
}
// 是否已删除
if this.web.ResponseHeaderPolicy.ContainsDeletedHeader(header.Name) {
if this.web.RequestHeaderPolicy.ContainsDeletedHeader(header.Name) {
continue
}
@@ -1603,9 +1620,25 @@ func (this *HTTPRequest) fixRequestHeader(header http.Header) {
header.Del(k)
k = strings.ReplaceAll(k, "-Websocket-", "-WebSocket-")
header[k] = v
} else if k == "Www-Authenticate" {
} else if strings.HasPrefix(k, "Sec-Ch") {
header.Del(k)
header["WWW-Authenticate"] = v
k = strings.ReplaceAll(k, "Sec-Ch-Ua", "Sec-CH-UA")
header[k] = v
} else {
switch k {
case "Www-Authenticate":
header.Del(k)
header["WWW-Authenticate"] = v
case "A-Im":
header.Del(k)
header["A-IM"] = v
case "Content-Md5":
header.Del(k)
header["Content-MD5"] = v
case "Sec-Gpc":
header.Del(k)
header["Content-GPC"] = v
}
}
}
}
@@ -1674,6 +1707,36 @@ func (this *HTTPRequest) processResponseHeaders(responseHeader http.Header, stat
responseHeader[header.Name] = []string{headerValue}
}
}
// CORS
if this.web.ResponseHeaderPolicy.CORS != nil && this.web.ResponseHeaderPolicy.CORS.IsOn {
var corsConfig = this.web.ResponseHeaderPolicy.CORS
// Allow-Origin
if len(corsConfig.AllowOrigin) == 0 {
var origin = this.RawReq.Header.Get("Origin")
if len(origin) > 0 {
responseHeader.Set("Access-Control-Allow-Origin", origin)
}
} else {
responseHeader.Set("Access-Control-Allow-Origin", corsConfig.AllowOrigin)
}
// Allow-Methods
if len(corsConfig.AllowMethods) == 0 {
responseHeader.Set("Access-Control-Allow-Methods", "PUT, GET, POST, DELETE, HEAD, OPTIONS")
} else {
responseHeader.Set("Access-Control-Allow-Methods", strings.Join(corsConfig.AllowMethods, ", "))
}
// Max-Age
if corsConfig.MaxAge > 0 {
responseHeader.Set("Access-Control-Max-Age", types.String(corsConfig.MaxAge))
}
// Allow-Credentials
responseHeader.Set("Access-Control-Allow-Credentials", "true")
}
}
// HSTS
@@ -1705,10 +1768,10 @@ func (this *HTTPRequest) bytePool(contentLength int64) *utils.BytePool {
return utils.BytePool1k
}
if contentLength < 32768 { // 32K
return utils.BytePool4k
return utils.BytePool16k
}
if contentLength < 131072 { // 128K
return utils.BytePool16k
return utils.BytePool32k
}
return utils.BytePool32k
}

View File

@@ -295,34 +295,31 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
}
}
var pool = this.bytePool(fileSize)
var buf = pool.Get()
defer func() {
pool.Put(buf)
}()
// 读取Header
var headerBuf = []byte{}
var headerData = []byte{}
this.writer.SetSentHeaderBytes(reader.HeaderSize())
err = reader.ReadHeader(buf, func(n int) (goNext bool, err error) {
headerBuf = append(headerBuf, buf[:n]...)
var headerPool = this.bytePool(reader.HeaderSize())
var headerBuf = headerPool.Get()
err = reader.ReadHeader(headerBuf, func(n int) (goNext bool, err error) {
headerData = append(headerData, headerBuf[:n]...)
for {
nIndex := bytes.Index(headerBuf, []byte{'\n'})
nIndex := bytes.Index(headerData, []byte{'\n'})
if nIndex >= 0 {
row := headerBuf[:nIndex]
row := headerData[:nIndex]
spaceIndex := bytes.Index(row, []byte{':'})
if spaceIndex <= 0 {
return false, errors.New("invalid header '" + string(row) + "'")
}
this.writer.Header().Set(string(row[:spaceIndex]), string(row[spaceIndex+1:]))
headerBuf = headerBuf[nIndex+1:]
headerData = headerData[nIndex+1:]
} else {
break
}
}
return true, nil
})
headerPool.Put(headerBuf)
if err != nil {
if !this.canIgnore(err) {
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: read header failed: "+err.Error())
@@ -460,13 +457,16 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
respHeader.Set("Content-Length", strconv.FormatInt(ranges[0].Length(), 10))
this.writer.WriteHeader(http.StatusPartialContent)
err = reader.ReadBodyRange(buf, ranges[0].Start(), ranges[0].End(), func(n int) (goNext bool, err error) {
_, err = this.writer.Write(buf[:n])
var pool = this.bytePool(fileSize)
var bodyBuf = pool.Get()
err = reader.ReadBodyRange(bodyBuf, ranges[0].Start(), ranges[0].End(), func(n int) (goNext bool, err error) {
_, err = this.writer.Write(bodyBuf[:n])
if err != nil {
return false, errWritingToClient
}
return true, nil
})
pool.Put(bodyBuf)
if err != nil {
this.varMapping["cache.status"] = "MISS"
@@ -513,13 +513,16 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
}
}
err := reader.ReadBodyRange(buf, r.Start(), r.End(), func(n int) (goNext bool, err error) {
_, err = this.writer.Write(buf[:n])
var pool = this.bytePool(fileSize)
var bodyBuf = pool.Get()
err := reader.ReadBodyRange(bodyBuf, r.Start(), r.End(), func(n int) (goNext bool, err error) {
_, err = this.writer.Write(bodyBuf[:n])
if err != nil {
return false, errWritingToClient
}
return true, nil
})
pool.Put(bodyBuf)
if err != nil {
if !this.canIgnore(err) {
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: "+err.Error())
@@ -543,15 +546,18 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
this.writer.Prepare(resp, fileSize, reader.Status(), false)
this.writer.WriteHeader(reader.Status())
var pool = this.bytePool(fileSize)
var bodyBuf = pool.Get()
if storage.CanSendfile() {
if fp, canSendFile := this.writer.canSendfile(); canSendFile {
this.writer.sentBodyBytes, err = io.CopyBuffer(this.writer.rawWriter, fp, buf)
this.writer.sentBodyBytes, err = io.CopyBuffer(this.writer.rawWriter, fp, bodyBuf)
} else {
_, err = io.CopyBuffer(this.writer, resp.Body, buf)
_, err = io.CopyBuffer(this.writer, resp.Body, bodyBuf)
}
} else {
_, err = io.CopyBuffer(this.writer, resp.Body, buf)
_, err = io.CopyBuffer(this.writer, resp.Body, bodyBuf)
}
pool.Put(bodyBuf)
if err == io.EOF {
err = nil
}
@@ -622,7 +628,14 @@ func (this *HTTPRequest) tryPartialReader(storage caches.StorageInterface, key s
}()
// 检查范围
//const maxFirstSpan = 16 << 20 // TODO 可以在缓存策略中设置此值
for index, r := range ranges {
// 没有指定结束位置时,自动指定一个
/**if r.Start() >= 0 && r.End() == -1 {
if partialReader.MaxLength() > r.Start()+maxFirstSpan {
r[1] = r.Start() + maxFirstSpan
}
}**/
r1, ok := r.Convert(partialReader.MaxLength())
if !ok {
return nil, nil

View File

@@ -1,7 +1,11 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/types"
"net"
"net/http"
"strconv"
"strings"
@@ -13,7 +17,7 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
if this.web.MergeSlashes {
urlPath = utils.CleanPath(urlPath)
}
fullURL := this.requestScheme() + "://" + this.ReqHost + urlPath
var fullURL = this.requestScheme() + "://" + this.ReqHost + urlPath
for _, u := range this.web.HostRedirects {
if !u.IsOn {
continue
@@ -21,11 +25,50 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
if !u.MatchRequest(this.Format) {
continue
}
if u.MatchPrefix { // 匹配前缀
if strings.HasPrefix(fullURL, u.BeforeURL) {
afterURL := u.AfterURL
if u.KeepRequestURI {
afterURL += this.RawReq.URL.RequestURI()
if len(u.Type) == 0 || u.Type == serverconfigs.HTTPHostRedirectTypeURL {
if u.MatchPrefix { // 匹配前缀
if strings.HasPrefix(fullURL, u.BeforeURL) {
afterURL := u.AfterURL
if u.KeepRequestURI {
afterURL += this.RawReq.URL.RequestURI()
}
// 前后是否一致
if fullURL == afterURL {
return false
}
if u.Status <= 0 {
u.Status = http.StatusTemporaryRedirect
}
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
return true
}
} else if u.MatchRegexp { // 正则匹配
var reg = u.BeforeURLRegexp()
if reg == nil {
continue
}
var matches = reg.FindStringSubmatch(fullURL)
if len(matches) == 0 {
continue
}
var afterURL = u.AfterURL
for i, match := range matches {
afterURL = strings.ReplaceAll(afterURL, "${"+strconv.Itoa(i)+"}", match)
}
var subNames = reg.SubexpNames()
if len(subNames) > 0 {
for _, subName := range subNames {
if len(subName) > 0 {
index := reg.SubexpIndex(subName)
if index > -1 {
afterURL = strings.ReplaceAll(afterURL, "${"+subName+"}", matches[index])
}
}
}
}
// 前后是否一致
@@ -33,69 +76,6 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
return false
}
if u.Status <= 0 {
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
} else {
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
}
return true
}
} else if u.MatchRegexp { // 正则匹配
reg := u.BeforeURLRegexp()
if reg == nil {
continue
}
matches := reg.FindStringSubmatch(fullURL)
if len(matches) == 0 {
continue
}
afterURL := u.AfterURL
for i, match := range matches {
afterURL = strings.ReplaceAll(afterURL, "${"+strconv.Itoa(i)+"}", match)
}
subNames := reg.SubexpNames()
if len(subNames) > 0 {
for _, subName := range subNames {
if len(subName) > 0 {
index := reg.SubexpIndex(subName)
if index > -1 {
afterURL = strings.ReplaceAll(afterURL, "${"+subName+"}", matches[index])
}
}
}
}
// 前后是否一致
if fullURL == afterURL {
return false
}
if u.KeepArgs {
var qIndex = strings.Index(this.uri, "?")
if qIndex >= 0 {
afterURL += this.uri[qIndex:]
}
}
if u.Status <= 0 {
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
} else {
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
}
return true
} else { // 精准匹配
if fullURL == u.RealBeforeURL() {
// 前后是否一致
if fullURL == u.AfterURL {
return false
}
var afterURL = u.AfterURL
if u.KeepArgs {
var qIndex = strings.Index(this.uri, "?")
if qIndex >= 0 {
@@ -104,12 +84,127 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
}
if u.Status <= 0 {
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
} else {
u.Status = http.StatusTemporaryRedirect
}
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
return true
} else { // 精准匹配
if fullURL == u.RealBeforeURL() {
// 前后是否一致
if fullURL == u.AfterURL {
return false
}
var afterURL = u.AfterURL
if u.KeepArgs {
var qIndex = strings.Index(this.uri, "?")
if qIndex >= 0 {
afterURL += this.uri[qIndex:]
}
}
if u.Status <= 0 {
u.Status = http.StatusTemporaryRedirect
}
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
return true
}
}
} else if u.Type == serverconfigs.HTTPHostRedirectTypeDomain {
if len(u.DomainAfter) == 0 {
continue
}
var reqHost = this.ReqHost
// 忽略跳转前端口
if u.DomainBeforeIgnorePorts {
h, _, err := net.SplitHostPort(reqHost)
if err == nil && len(h) > 0 {
reqHost = h
}
}
// 如果跳转前后域名一致,则终止
if u.DomainAfter == reqHost {
return false
}
var scheme = u.DomainAfterScheme
if len(scheme) == 0 {
scheme = this.requestScheme()
}
if u.DomainsAll || configutils.MatchDomains(u.DomainsBefore, reqHost) {
var afterURL = scheme + "://" + u.DomainAfter + urlPath
if fullURL == afterURL {
// 终止匹配
return false
}
if u.Status <= 0 {
u.Status = http.StatusTemporaryRedirect
}
this.processResponseHeaders(this.writer.Header(), u.Status)
// 参数
var qIndex = strings.Index(this.uri, "?")
if qIndex >= 0 {
afterURL += this.uri[qIndex:]
}
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
return true
}
} else if u.Type == serverconfigs.HTTPHostRedirectTypePort {
if u.PortAfter <= 0 {
continue
}
var scheme = u.PortAfterScheme
if len(scheme) == 0 {
scheme = this.requestScheme()
}
reqHost, reqPort, _ := net.SplitHostPort(this.ReqHost)
if len(reqHost) == 0 {
reqHost = this.ReqHost
}
if len(reqPort) == 0 {
switch this.requestScheme() {
case "http":
reqPort = "80"
case "https":
reqPort = "443"
}
}
// 如果跳转前后端口一致,则终止
if reqPort == types.String(u.PortAfter) {
return false
}
var containsPort = false
if u.PortsAll {
containsPort = true
} else {
containsPort = u.ContainsPort(types.Int(reqPort))
}
if containsPort {
var newReqHost = reqHost
if !((scheme == "http" && u.PortAfter == 80) || (scheme == "https" && u.PortAfter == 443)) {
newReqHost += ":" + types.String(u.PortAfter)
}
var afterURL = scheme + "://" + newReqHost + urlPath
if fullURL == afterURL {
// 终止匹配
return false
}
if u.Status <= 0 {
u.Status = http.StatusTemporaryRedirect
}
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
return true
}
}

View File

@@ -51,27 +51,46 @@ func (this *HTTPRequest) log() {
addr = addr[:index]
}
var serverGlobalConfig = this.nodeConfig.GlobalServerConfig
// 请求Cookie
var cookies = map[string]string{}
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldCookie) {
for _, cookie := range this.RawReq.Cookies() {
cookies[cookie.Name] = cookie.Value
var enableCookies = false
if serverGlobalConfig == nil || serverGlobalConfig.HTTPAccessLog.EnableCookies {
enableCookies = true
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldCookie) {
for _, cookie := range this.RawReq.Cookies() {
cookies[cookie.Name] = cookie.Value
}
}
}
// 请求Header
var pbReqHeader = map[string]*pb.Strings{}
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldHeader) {
for k, v := range this.RawReq.Header {
pbReqHeader[k] = &pb.Strings{Values: v}
if serverGlobalConfig == nil || serverGlobalConfig.HTTPAccessLog.EnableRequestHeaders {
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldHeader) {
// 是否只记录通用Header
var commonHeadersOnly = serverGlobalConfig != nil && serverGlobalConfig.HTTPAccessLog.CommonRequestHeadersOnly
for k, v := range this.RawReq.Header {
if commonHeadersOnly && !serverconfigs.IsCommonRequestHeader(k) {
continue
}
if !enableCookies && k == "Cookie" {
continue
}
pbReqHeader[k] = &pb.Strings{Values: v}
}
}
}
// 响应Header
var pbResHeader = map[string]*pb.Strings{}
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldSentHeader) {
for k, v := range this.writer.Header() {
pbResHeader[k] = &pb.Strings{Values: v}
if serverGlobalConfig == nil || serverGlobalConfig.HTTPAccessLog.EnableResponseHeaders {
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldSentHeader) {
for k, v := range this.writer.Header() {
pbResHeader[k] = &pb.Strings{Values: v}
}
}
}

View File

@@ -12,6 +12,8 @@ func (this *HTTPRequest) doCheckReferers() (shouldStop bool) {
return
}
const cacheSeconds = "3600" // 时间不能过长,防止修改设置后长期无法生效
var refererURL = this.RawReq.Header.Get("Referer")
if len(refererURL) == 0 {
if this.web.Referers.MatchDomain(this.ReqHost, "") {
@@ -19,6 +21,7 @@ func (this *HTTPRequest) doCheckReferers() (shouldStop bool) {
}
this.tags = append(this.tags, "refererCheck")
this.writer.Header().Set("Cache-Control", "max-age="+cacheSeconds)
this.writeCode(http.StatusForbidden, "The referer has been blocked.", "当前访问已被防盗链系统拦截。")
return true
@@ -31,6 +34,7 @@ func (this *HTTPRequest) doCheckReferers() (shouldStop bool) {
}
this.tags = append(this.tags, "refererCheck")
this.writer.Header().Set("Cache-Control", "max-age="+cacheSeconds)
this.writeCode(http.StatusForbidden, "The referer has been blocked.", "当前访问已被防盗链系统拦截。")
return true
@@ -38,6 +42,7 @@ func (this *HTTPRequest) doCheckReferers() (shouldStop bool) {
if !this.web.Referers.MatchDomain(this.ReqHost, u.Host) {
this.tags = append(this.tags, "refererCheck")
this.writer.Header().Set("Cache-Control", "max-age="+cacheSeconds)
this.writeCode(http.StatusForbidden, "The referer has been blocked.", "当前访问已被防盗链系统拦截。")
return true
}

View File

@@ -12,5 +12,5 @@ func (this *HTTPRequest) doStat() {
// 内置的统计
stats.SharedHTTPRequestStatManager.AddRemoteAddr(this.ReqServer.Id, this.requestRemoteAddr(true), this.writer.SentBodyBytes(), this.isAttack)
stats.SharedHTTPRequestStatManager.AddUserAgent(this.ReqServer.Id, this.requestHeader("User-Agent"))
stats.SharedHTTPRequestStatManager.AddUserAgent(this.ReqServer.Id, this.requestHeader("User-Agent"), this.remoteAddr)
}

View File

@@ -0,0 +1,24 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"net/http"
)
func (this *HTTPRequest) doCheckUserAgent() (shouldStop bool) {
if this.web.UserAgent == nil {
return
}
const cacheSeconds = "3600" // 时间不能过长,防止修改设置后长期无法生效
if !this.web.UserAgent.AllowRequest(this.RawReq) {
this.tags = append(this.tags, "userAgentCheck")
this.writer.Header().Set("Cache-Control", "max-age="+cacheSeconds)
this.writeCode(http.StatusForbidden, "The User-Agent has been blocked.", "当前访问已被UA名单拦截。")
return true
}
return
}

View File

@@ -70,6 +70,13 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
this.RawReq.Header.Set("Origin", newRequestOrigin)
}
// 获取当前连接
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
if requestConn == nil {
return
}
// 连接源站
// TODO 增加N次错误重试重试的时候需要尝试不同的源站
originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost)
if err != nil {
@@ -102,6 +109,11 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
return
}
requestClientConn, ok := requestConn.(ClientConnInterface)
if ok {
requestClientConn.SetIsWebsocket(true)
}
clientConn, _, err := this.writer.Hijack()
if err != nil || clientConn == nil {
this.write50x(err, http.StatusInternalServerError, "Failed to get origin site connection", "获取源站连接失败", false)
@@ -115,8 +127,8 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
// 读取第一个响应
var respReader = NewWebsocketResponseReader(originConn)
resp, err := http.ReadResponse(bufio.NewReader(respReader), this.RawReq)
if err != nil {
if resp.Body != nil {
if err != nil || resp == nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}

View File

@@ -132,7 +132,7 @@ func (this *HTTPWriter) Prepare(resp *http.Response, size int64, status int, ena
this.req.web.RequestLimit != nil &&
this.req.web.RequestLimit.IsOn &&
this.req.web.RequestLimit.OutBandwidthPerConnBytes() > 0 {
this.writer = writers.NewRateLimitWriter(this.writer, this.req.web.RequestLimit.OutBandwidthPerConnBytes())
this.writer = writers.NewRateLimitWriter(this.req.RawReq.Context(), this.writer, this.req.web.RequestLimit.OutBandwidthPerConnBytes())
}
return
@@ -303,7 +303,19 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
if this.isPartial {
cacheKey += caches.SuffixPartial
}
cacheWriter, err := storage.OpenWriter(cacheKey, expiresAt, this.StatusCode(), size, cacheRef.MaxSizeBytes(), this.isPartial)
// 待写入尺寸
var totalSize = size
if totalSize < 0 && this.isPartial {
var contentRange = resp.Header.Get("Content-Range")
if len(contentRange) > 0 {
_, partialTotalSize := httpRequestParseContentRangeHeader(contentRange)
if partialTotalSize > 0 {
totalSize = partialTotalSize
}
}
}
cacheWriter, err := storage.OpenWriter(cacheKey, expiresAt, this.StatusCode(), this.calculateHeaderLength(), totalSize, cacheRef.MaxSizeBytes(), this.isPartial)
if err != nil {
if err == caches.ErrEntityTooLarge && addStatusHeader {
this.Header().Set("X-Cache", "BYPASS, entity too large")
@@ -324,13 +336,19 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
}
// 写入Header
var headerBuf = utils.SharedBufferPool.Get()
for k, v := range this.Header() {
if k == "Set-Cookie" || (this.isPartial && k == "Content-Range") {
continue
}
for _, v1 := range v {
if this.isPartial && k == "Content-Type" && strings.Contains(v1, "multipart/byteranges") {
continue
}
_, err = cacheWriter.WriteHeader([]byte(k + ":" + v1 + "\n"))
_, err = headerBuf.Write([]byte(k + ":" + v1 + "\n"))
if err != nil {
utils.SharedBufferPool.Put(headerBuf)
remotelogs.Error("HTTP_WRITER", "write cache failed: "+err.Error())
_ = this.cacheWriter.Discard()
this.cacheWriter = nil
@@ -338,6 +356,14 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
}
}
}
_, err = cacheWriter.WriteHeader(headerBuf.Bytes())
utils.SharedBufferPool.Put(headerBuf)
if err != nil {
remotelogs.Error("HTTP_WRITER", "write cache failed: "+err.Error())
_ = this.cacheWriter.Discard()
this.cacheWriter = nil
return
}
if this.isPartial {
// content-range
@@ -558,6 +584,11 @@ func (this *HTTPWriter) PrepareCompression(resp *http.Response, size int64) {
return
}
// 分区内容不压缩,防止读取失败
if !this.compressionConfig.EnablePartialContent && this.StatusCode() == http.StatusPartialContent {
return
}
if this.compressionConfig.Level <= 0 {
return
}
@@ -627,16 +658,21 @@ func (this *HTTPWriter) PrepareCompression(resp *http.Response, size int64) {
cacheKey += this.cacheReaderSuffix
}
compressionCacheWriter, err := this.cacheStorage.OpenWriter(cacheKey+caches.SuffixCompression+compressionEncoding, expiredAt, this.StatusCode(), -1, cacheRef.MaxSizeBytes(), false)
compressionCacheWriter, err := this.cacheStorage.OpenWriter(cacheKey+caches.SuffixCompression+compressionEncoding, expiredAt, this.StatusCode(), this.calculateHeaderLength(), -1, cacheRef.MaxSizeBytes(), false)
if err != nil {
return
}
// 写入Header
var headerBuffer = utils.SharedBufferPool.Get()
for k, v := range this.Header() {
if k == "Set-Cookie" || (this.isPartial && k == "Content-Range") {
continue
}
for _, v1 := range v {
_, err = compressionCacheWriter.WriteHeader([]byte(k + ":" + v1 + "\n"))
_, err = headerBuffer.Write([]byte(k + ":" + v1 + "\n"))
if err != nil {
utils.SharedBufferPool.Put(headerBuffer)
remotelogs.Error("HTTP_WRITER", "write compression cache failed: "+err.Error())
_ = compressionCacheWriter.Discard()
compressionCacheWriter = nil
@@ -645,6 +681,15 @@ func (this *HTTPWriter) PrepareCompression(resp *http.Response, size int64) {
}
}
_, err = compressionCacheWriter.WriteHeader(headerBuffer.Bytes())
utils.SharedBufferPool.Put(headerBuffer)
if err != nil {
remotelogs.Error("HTTP_WRITER", "write compression cache failed: "+err.Error())
_ = compressionCacheWriter.Discard()
compressionCacheWriter = nil
return
}
if compressionCacheWriter != nil {
this.compressionCacheWriter = compressionCacheWriter
var teeWriter = writers.NewTeeWriterCloser(this.writer, compressionCacheWriter)
@@ -942,10 +987,14 @@ func (this *HTTPWriter) finishWebP() {
expiredAt = this.cacheWriter.ExpiredAt()
}
webpCacheWriter, _ = this.cacheStorage.OpenWriter(cacheKey, expiredAt, this.StatusCode(), -1, -1, false)
webpCacheWriter, _ = this.cacheStorage.OpenWriter(cacheKey, expiredAt, this.StatusCode(), -1, -1, -1, false)
if webpCacheWriter != nil {
// 写入Header
for k, v := range this.Header() {
if k == "Set-Cookie" {
continue
}
// 这里是原始的数据,不需要内容编码
if k == "Content-Encoding" || k == "Transfer-Encoding" {
continue
@@ -1157,3 +1206,16 @@ func (this *HTTPWriter) finishRequest() {
_ = this.rawReader.Close()
}
}
// 计算Header长度
func (this *HTTPWriter) calculateHeaderLength() (result int) {
for k, v := range this.Header() {
if k == "Set-Cookie" || (this.isPartial && k == "Content-Range") {
continue
}
for _, v1 := range v {
result += len(k) + 1 /**:**/ + len(v1) + 1 /**\n**/
}
}
return
}

View File

@@ -36,7 +36,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
return &tls.Config{
Certificates: nil,
GetConfigForClient: func(clientInfo *tls.ClientHelloInfo) (config *tls.Config, e error) {
tlsPolicy, _, err := this.matchSSL(clientInfo.ServerName)
tlsPolicy, _, err := this.matchSSL(this.helloServerName(clientInfo))
if err != nil {
return nil, err
}
@@ -50,7 +50,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
return tlsPolicy.TLSConfig(), nil
},
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
tlsPolicy, cert, err := this.matchSSL(clientInfo.ServerName)
tlsPolicy, cert, err := this.matchSSL(this.helloServerName(clientInfo))
if err != nil {
return nil, err
}
@@ -160,7 +160,7 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf
// 严格查找域名
func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *serverconfigs.ServerConfig, serverName string) {
group := this.Group
var group = this.Group
if group == nil {
return nil, ""
}
@@ -182,3 +182,18 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser
return nil, name
}
// 从Hello信息中获取服务名称
func (this *BaseListener) helloServerName(clientInfo *tls.ClientHelloInfo) string {
var serverName = clientInfo.ServerName
if len(serverName) == 0 {
var localAddr = clientInfo.Conn.LocalAddr()
if localAddr != nil {
tcpAddr, ok := localAddr.(*net.TCPAddr)
if ok {
serverName = tcpAddr.IP.String()
}
}
}
return serverName
}

View File

@@ -14,7 +14,7 @@ func TestBaseListener_FindServer(t *testing.T) {
sharedNodeConfig = &nodeconfigs.NodeConfig{}
var listener = &BaseListener{}
listener.Group = &serverconfigs.ServerAddressGroup{}
listener.Group = serverconfigs.NewServerAddressGroup("https://*:443")
for i := 0; i < 1_000_000; i++ {
var server = &serverconfigs.ServerConfig{
IsOn: true,

View File

@@ -18,6 +18,8 @@ import (
var httpErrorLogger = log.New(io.Discard, "", 0)
const HTTPIdleTimeout = 75 * time.Second
type contextKey struct {
key string
}
@@ -43,16 +45,12 @@ func (this *HTTPListener) Serve() error {
this.httpServer = &http.Server{
Addr: this.addr,
Handler: this,
ReadTimeout: 1 * time.Hour, // TODO 改成可以配置
ReadHeaderTimeout: 3 * time.Second, // TODO 改成可以配置
WriteTimeout: 1 * time.Hour, // TODO 改成可以配置
IdleTimeout: 75 * time.Second, // TODO 改成可以配置
ReadHeaderTimeout: 3 * time.Second, // TODO 改成可以配置
IdleTimeout: HTTPIdleTimeout, // TODO 改成可以配置
ConnState: func(conn net.Conn, state http.ConnState) {
switch state {
case http.StateNew:
atomic.AddInt64(&this.countActiveConnections, 1)
case http.StateActive, http.StateIdle, http.StateHijacked:
// Nothing to do
case http.StateClosed:
atomic.AddInt64(&this.countActiveConnections, -1)
}
@@ -116,8 +114,14 @@ func (this *HTTPListener) Reload(group *serverconfigs.ServerAddressGroup) {
// ServerHTTP 处理HTTP请求
func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.Request) {
// 不支持Connect
if rawReq.Method == http.MethodConnect {
http.Error(rawWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
return
}
// 域名
var reqHost = rawReq.Host
var reqHost = strings.ToLower(strings.TrimRight(rawReq.Host, "."))
// TLS域名
if this.isIP(reqHost) {
@@ -175,6 +179,15 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
}
}
// 检查用户
if server != nil && server.UserId > 0 {
if !SharedUserManager.CheckUserServersIsEnabled(server.UserId) {
rawWriter.WriteHeader(http.StatusNotFound)
_, _ = rawWriter.Write([]byte("The site owner is unavailable."))
return
}
}
// 包装新请求对象
var req = &HTTPRequest{
RawReq: rawReq,

View File

@@ -92,7 +92,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
}
// 是否已达到流量限制
if this.reachedTrafficLimit() {
if this.reachedTrafficLimit() || (server.UserId > 0 && !SharedUserManager.CheckUserServersIsEnabled(server.UserId)) {
// 关闭连接
tcpConn, ok := conn.(LingerConn)
if ok {

View File

@@ -170,6 +170,11 @@ func (this *UDPListener) servePacketListener(listener UDPPacketListener) error {
return nil
}
// 检查用户状态
if firstServer.UserId > 0 && !SharedUserManager.CheckUserServersIsEnabled(firstServer.UserId) {
return nil
}
n, cm, clientAddr, err := listener.ReadFrom(buffer)
if err != nil {
if this.isClosed {

View File

@@ -2,14 +2,15 @@ package nodes
import (
"bytes"
"context"
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
iplib "github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/configs"
"github.com/TeaOSLab/EdgeNode/internal/conns"
@@ -24,12 +25,13 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/trackers"
"github.com/TeaOSLab/EdgeNode/internal/utils"
_ "github.com/TeaOSLab/EdgeNode/internal/utils/clock" // 触发时钟更新
_ "github.com/TeaOSLab/EdgeNode/internal/utils/agents" // 引入Agent管理器
_ "github.com/TeaOSLab/EdgeNode/internal/utils/clock" // 触发时钟更新
"github.com/TeaOSLab/EdgeNode/internal/utils/jsonutils"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/andybalholm/brotli"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"github.com/iwind/gosock/pkg/gosock"
@@ -38,6 +40,7 @@ import (
"os"
"os/exec"
"os/signal"
"path/filepath"
"runtime"
"runtime/debug"
"sort"
@@ -59,18 +62,27 @@ type Node struct {
sock *gosock.Sock
locker sync.Mutex
maxCPU int32
maxThreads int
timezone string
oldMaxCPU int32
oldMaxThreads int
oldTimezone string
oldHTTPCachePolicies []*serverconfigs.HTTPCachePolicy
oldHTTPFirewallPolicies []*firewallconfigs.HTTPFirewallPolicy
oldFirewallActions []*firewallconfigs.FirewallActionConfig
oldMetricItems []*serverconfigs.MetricItemConfig
updatingServerMap map[int64]*serverconfigs.ServerConfig
lastAPINodeVersion int64
lastAPINodeAddrs []string // 以前的API节点地址
lastTaskVersion int64
}
func NewNode() *Node {
return &Node{
sock: gosock.NewTmpSock(teaconst.ProcessName),
maxThreads: -1,
maxCPU: -1,
oldMaxThreads: -1,
oldMaxCPU: -1,
updatingServerMap: map[int64]*serverconfigs.ServerConfig{},
}
}
@@ -190,7 +202,7 @@ func (this *Node) Start() {
}
}
sharedNodeConfig = nodeConfig
this.onReload(nodeConfig)
this.onReload(nodeConfig, true)
// 发送事件
events.Notify(events.EventLoaded)
@@ -205,9 +217,7 @@ func (this *Node) Start() {
// 统计
goman.New(func() {
stats.SharedTrafficStatManager.Start(func() *nodeconfigs.NodeConfig {
return sharedNodeConfig
})
stats.SharedTrafficStatManager.Start()
})
goman.New(func() {
stats.SharedHTTPRequestStatManager.Start()
@@ -313,8 +323,9 @@ func (this *Node) loop() error {
return errors.New("create rpc client failed: " + err.Error())
}
var nodeCtx = rpcClient.Context()
tasksResp, err := rpcClient.NodeTaskRPC.FindNodeTasks(nodeCtx, &pb.FindNodeTasksRequest{})
tasksResp, err := rpcClient.NodeTaskRPC.FindNodeTasks(rpcClient.Context(), &pb.FindNodeTasksRequest{
Version: this.lastTaskVersion,
})
if err != nil {
if rpc.IsConnError(err) && !Tea.IsTesting() {
return nil
@@ -322,15 +333,18 @@ func (this *Node) loop() error {
return errors.New("read node tasks failed: " + err.Error())
}
for _, task := range tasksResp.NodeTasks {
err := this.execTask(rpcClient, nodeCtx, task)
this.finishTask(task.Id, err)
err := this.execTask(rpcClient, task)
if !this.finishTask(task.Id, task.Version, err) {
// 防止失败的任务无法重试
break
}
}
return nil
}
// 执行任务
func (this *Node) execTask(rpcClient *rpc.RPCClient, nodeCtx context.Context, task *pb.NodeTask) error {
func (this *Node) execTask(rpcClient *rpc.RPCClient, task *pb.NodeTask) error {
switch task.Type {
case "ipItemChanged":
// 防止阻塞
@@ -360,7 +374,7 @@ func (this *Node) execTask(rpcClient *rpc.RPCClient, nodeCtx context.Context, ta
return errors.New("reload common scripts failed: " + err.Error())
}
case "nodeLevelChanged":
levelInfoResp, err := rpcClient.NodeRPC.FindNodeLevelInfo(nodeCtx, &pb.FindNodeLevelInfoRequest{})
levelInfoResp, err := rpcClient.NodeRPC.FindNodeLevelInfo(rpcClient.Context(), &pb.FindNodeLevelInfoRequest{})
if err != nil {
return err
}
@@ -381,7 +395,7 @@ func (this *Node) execTask(rpcClient *rpc.RPCClient, nodeCtx context.Context, ta
sharedNodeConfig.ParentNodes = parentNodes
}
case "ddosProtectionChanged":
resp, err := rpcClient.NodeRPC.FindNodeDDoSProtection(nodeCtx, &pb.FindNodeDDoSProtectionRequest{})
resp, err := rpcClient.NodeRPC.FindNodeDDoSProtection(rpcClient.Context(), &pb.FindNodeDDoSProtectionRequest{})
if err != nil {
return err
}
@@ -409,7 +423,7 @@ func (this *Node) execTask(rpcClient *rpc.RPCClient, nodeCtx context.Context, ta
return nil
}
case "globalServerConfigChanged":
resp, err := rpcClient.NodeRPC.FindNodeGlobalServerConfig(nodeCtx, &pb.FindNodeGlobalServerConfigRequest{})
resp, err := rpcClient.NodeRPC.FindNodeGlobalServerConfig(rpcClient.Context(), &pb.FindNodeGlobalServerConfigRequest{})
if err != nil {
return err
}
@@ -430,6 +444,22 @@ func (this *Node) execTask(rpcClient *rpc.RPCClient, nodeCtx context.Context, ta
}
}
}
case "userServersStateChanged":
if task.UserId > 0 {
resp, err := rpcClient.UserRPC.CheckUserServersState(rpcClient.Context(), &pb.CheckUserServersStateRequest{UserId: task.UserId})
if err != nil {
return err
}
SharedUserManager.UpdateUserServersIsEnabled(task.UserId, resp.IsEnabled)
if resp.IsEnabled {
err = this.syncUserServersConfig(task.UserId)
if err != nil {
return err
}
}
}
default:
remotelogs.Error("NODE", "task '"+types.String(task.Id)+"', type '"+task.Type+"' has not been handled")
}
@@ -438,39 +468,44 @@ func (this *Node) execTask(rpcClient *rpc.RPCClient, nodeCtx context.Context, ta
}
// 标记任务完成
func (this *Node) finishTask(taskId int64, err error) {
func (this *Node) finishTask(taskId int64, taskVersion int64, taskErr error) (success bool) {
if taskId <= 0 {
return
return true
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
logs.Println("[NODE]", "create rpc client failed: "+err.Error())
return
remotelogs.Debug("NODE", "create rpc client failed: "+err.Error())
return false
}
var nodeCtx = rpcClient.Context()
var isOk = taskErr == nil
if isOk && taskVersion > this.lastTaskVersion {
this.lastTaskVersion = taskVersion
}
var isOk = err == nil
var errMsg = ""
if err != nil {
errMsg = err.Error()
if taskErr != nil {
errMsg = taskErr.Error()
}
_, err = rpcClient.NodeTaskRPC.ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
_, err = rpcClient.NodeTaskRPC.ReportNodeTaskDone(rpcClient.Context(), &pb.ReportNodeTaskDoneRequest{
NodeTaskId: taskId,
IsOk: isOk,
Error: errMsg,
})
success = err == nil
if err != nil {
// 不需要上报到服务中心
// 连接错误不需要上报到服务中心
if rpc.IsConnError(err) {
logs.Println("[NODE]", "report task done failed: "+err.Error())
remotelogs.Debug("NODE", "report task done failed: "+err.Error())
} else {
remotelogs.Error("NODE", "report task done failed: "+err.Error())
}
}
return success
}
// 读取API配置
@@ -501,10 +536,8 @@ func (this *Node) syncConfig(taskVersion int64) error {
}
// 获取同步任务
var nodeCtx = rpcClient.Context()
// TODO 这里考虑只同步版本号有变更的
configResp, err := rpcClient.NodeRPC.FindCurrentNodeConfig(nodeCtx, &pb.FindCurrentNodeConfigRequest{
configResp, err := rpcClient.NodeRPC.FindCurrentNodeConfig(rpcClient.Context(), &pb.FindCurrentNodeConfigRequest{
Version: -1, // 更新所有版本
Compress: true,
NodeTaskVersion: taskVersion,
@@ -575,7 +608,7 @@ func (this *Node) syncConfig(taskVersion int64) error {
remotelogs.Println("NODE", "loading config ...")
}
this.onReload(nodeConfig)
this.onReload(nodeConfig, true)
// 发送事件
events.Notify(events.EventReload)
@@ -615,6 +648,36 @@ func (this *Node) syncServerConfig(serverId int64) error {
return nil
}
// 同步某个用户下的所有服务配置
func (this *Node) syncUserServersConfig(userId int64) error {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
serverConfigsResp, err := rpcClient.ServerRPC.ComposeAllUserServersConfig(rpcClient.Context(), &pb.ComposeAllUserServersConfigRequest{
UserId: userId,
})
if err != nil {
return err
}
if len(serverConfigsResp.ServersConfigJSON) == 0 {
return nil
}
var serverConfigs = []*serverconfigs.ServerConfig{}
err = json.Unmarshal(serverConfigsResp.ServersConfigJSON, &serverConfigs)
if err != nil {
return err
}
this.locker.Lock()
defer this.locker.Unlock()
for _, config := range serverConfigs {
this.updatingServerMap[config.Id] = config
}
return nil
}
// 启动同步计时器
func (this *Node) startSyncTimer() {
// TODO 这个时间间隔可以自行设置
@@ -676,12 +739,12 @@ func (this *Node) checkClusterConfig() error {
return err
}
logs.Println("[NODE]registering node to cluster ...")
remotelogs.Debug("NODE", "registering node to cluster ...")
resp, err := rpcClient.NodeRPC.RegisterClusterNode(rpcClient.ClusterContext(config.ClusterId, config.Secret), &pb.RegisterClusterNodeRequest{Name: HOSTNAME})
if err != nil {
return err
}
logs.Println("[NODE]registered successfully")
remotelogs.Debug("NODE", "registered successfully")
// 写入到配置文件中
if len(resp.Endpoints) == 0 {
@@ -689,8 +752,8 @@ func (this *Node) checkClusterConfig() error {
}
var apiConfig = &configs.APIConfig{
RPC: struct {
Endpoints []string `yaml:"endpoints"`
DisableUpdate bool `yaml:"disableUpdate"`
Endpoints []string `yaml:"endpoints" json:"endpoints"`
DisableUpdate bool `yaml:"disableUpdate" json:"disableUpdate"`
}{
Endpoints: resp.Endpoints,
DisableUpdate: false,
@@ -698,12 +761,12 @@ func (this *Node) checkClusterConfig() error {
NodeId: resp.UniqueId,
Secret: resp.Secret,
}
logs.Println("[NODE]writing 'configs/api.yaml' ...")
remotelogs.Debug("NODE", "writing 'configs/api.yaml' ...")
err = apiConfig.WriteFile(Tea.ConfigFile("api.yaml"))
if err != nil {
return err
}
logs.Println("[NODE]wrote 'configs/api.yaml' successfully")
remotelogs.Debug("NODE", "wrote 'configs/api.yaml' successfully")
return nil
}
@@ -817,16 +880,56 @@ func (this *Node) listenSock() error {
},
})
case "conns":
var addrs = []string{}
var connMaps = []maps.Map{}
var connMap = conns.SharedMap.AllConns()
for _, conn := range connMap {
addrs = append(addrs, conn.RemoteAddr().String())
var createdAt int64
var lastReadAt int64
var lastWriteAt int64
var lastErrString = ""
clientConn, ok := conn.(*ClientConn)
if ok {
createdAt = clientConn.CreatedAt()
lastReadAt = clientConn.LastReadAt()
lastWriteAt = clientConn.LastWriteAt()
var lastErr = clientConn.LastErr()
if lastErr != nil {
lastErrString = lastErr.Error()
}
}
var age int64 = -1
var lastReadAge int64 = -1
var lastWriteAge int64 = -1
var currentTime = time.Now().Unix()
if createdAt > 0 {
age = currentTime - createdAt
}
if lastReadAt > 0 {
lastReadAge = currentTime - lastReadAt
}
if lastWriteAt > 0 {
lastWriteAge = currentTime - lastWriteAt
}
connMaps = append(connMaps, maps.Map{
"addr": conn.RemoteAddr().String(),
"age": age,
"readAge": lastReadAge,
"writeAge": lastWriteAge,
"lastErr": lastErrString,
})
}
sort.Slice(connMaps, func(i, j int) bool {
var m1 = connMaps[i]
var m2 = connMaps[j]
return m1.GetInt64("age") < m2.GetInt64("age")
})
_ = cmd.Reply(&gosock.Command{
Params: map[string]interface{}{
"addrs": addrs,
"total": len(addrs),
"conns": connMaps,
"total": len(connMaps),
},
})
case "dropIP":
@@ -858,6 +961,11 @@ func (this *Node) listenSock() error {
} else {
_ = cmd.ReplyOk()
}
case "closeIP":
var m = maps.NewMap(cmd.Params)
var ip = m.GetString("ip")
conns.SharedMap.CloseIPConns(ip)
_ = cmd.ReplyOk()
case "removeIP":
var m = maps.NewMap(cmd.Params)
var ip = m.GetString("ip")
@@ -908,12 +1016,12 @@ func (this *Node) listenSock() error {
err := this.sock.Listen()
if err != nil {
logs.Println("NODE", err.Error())
remotelogs.Debug("NODE", err.Error())
}
})
events.OnKey(events.EventQuit, this, func() {
remotelogs.Println("NODE", "quit unix sock")
remotelogs.Debug("NODE", "quit unix sock")
_ = this.sock.Close()
})
@@ -921,97 +1029,165 @@ func (this *Node) listenSock() error {
}
// 重载配置调用
func (this *Node) onReload(config *nodeconfigs.NodeConfig) {
func (this *Node) onReload(config *nodeconfigs.NodeConfig, reloadAll bool) {
nodeconfigs.ResetNodeConfig(config)
sharedNodeConfig = config
// 缓存策略
caches.SharedManager.MaxDiskCapacity = config.MaxCacheDiskCapacity
caches.SharedManager.MaxMemoryCapacity = config.MaxCacheMemoryCapacity
caches.SharedManager.DiskDir = config.CacheDiskDir
if len(config.HTTPCachePolicies) > 0 {
caches.SharedManager.UpdatePolicies(config.HTTPCachePolicies)
} else {
caches.SharedManager.UpdatePolicies([]*serverconfigs.HTTPCachePolicy{})
if reloadAll {
// 缓存策略
var subDirs = config.CacheDiskSubDirs
for _, subDir := range subDirs {
subDir.Path = filepath.Clean(subDir.Path)
}
if len(subDirs) > 0 {
sort.Slice(subDirs, func(i, j int) bool {
return subDirs[i].Path < subDirs[j].Path
})
}
var cachePoliciesChanged = !jsonutils.Equal(caches.SharedManager.MaxDiskCapacity, config.MaxCacheDiskCapacity) ||
!jsonutils.Equal(caches.SharedManager.MaxMemoryCapacity, config.MaxCacheMemoryCapacity) ||
!jsonutils.Equal(caches.SharedManager.MainDiskDir, config.CacheDiskDir) ||
!jsonutils.Equal(caches.SharedManager.SubDiskDirs, subDirs) ||
!jsonutils.Equal(this.oldHTTPCachePolicies, config.HTTPCachePolicies)
caches.SharedManager.MaxDiskCapacity = config.MaxCacheDiskCapacity
caches.SharedManager.MaxMemoryCapacity = config.MaxCacheMemoryCapacity
caches.SharedManager.MainDiskDir = config.CacheDiskDir
caches.SharedManager.SubDiskDirs = subDirs
if cachePoliciesChanged {
// copy
this.oldHTTPCachePolicies = []*serverconfigs.HTTPCachePolicy{}
err := jsonutils.Copy(&this.oldHTTPCachePolicies, config.HTTPCachePolicies)
if err != nil {
remotelogs.Error("NODE", "onReload: copy HTTPCachePolicies failed: "+err.Error())
}
// update
if len(config.HTTPCachePolicies) > 0 {
caches.SharedManager.UpdatePolicies(config.HTTPCachePolicies)
} else {
caches.SharedManager.UpdatePolicies([]*serverconfigs.HTTPCachePolicy{})
}
}
}
// WAF策略
waf.SharedWAFManager.UpdatePolicies(config.FindAllFirewallPolicies())
iplibrary.SharedActionManager.UpdateActions(config.FirewallActions)
// 统计指标
metrics.SharedManager.Update(config.MetricItems)
// max cpu
if config.MaxCPU != this.maxCPU {
if config.MaxCPU > 0 && config.MaxCPU < int32(runtime.NumCPU()) {
runtime.GOMAXPROCS(int(config.MaxCPU))
remotelogs.Println("NODE", "[CPU]set max cpu to '"+types.String(config.MaxCPU)+"'")
} else {
var threads = runtime.NumCPU() * 4
runtime.GOMAXPROCS(threads)
remotelogs.Println("NODE", "[CPU]set max cpu to '"+types.String(threads)+"'")
}
this.maxCPU = config.MaxCPU
}
// max threads
if config.MaxThreads != this.maxThreads {
if config.MaxThreads > 0 {
debug.SetMaxThreads(config.MaxThreads)
remotelogs.Println("NODE", "[THREADS]set max threads to '"+types.String(config.MaxThreads)+"'")
} else {
debug.SetMaxThreads(nodeconfigs.DefaultMaxThreads)
remotelogs.Println("NODE", "[THREADS]set max threads to '"+types.String(nodeconfigs.DefaultMaxThreads)+"'")
}
this.maxThreads = config.MaxThreads
}
// timezone
var timeZone = config.TimeZone
if len(timeZone) == 0 {
timeZone = "Asia/Shanghai"
}
if this.timezone != timeZone {
location, err := time.LoadLocation(timeZone)
// 包含了服务里的WAF策略所以需要整体更新
var allFirewallPolicies = config.FindAllFirewallPolicies()
if !jsonutils.Equal(allFirewallPolicies, this.oldHTTPFirewallPolicies) {
// copy
this.oldHTTPFirewallPolicies = []*firewallconfigs.HTTPFirewallPolicy{}
err := jsonutils.Copy(&this.oldHTTPFirewallPolicies, allFirewallPolicies)
if err != nil {
remotelogs.Error("NODE", "[TIMEZONE]change time zone failed: "+err.Error())
return
remotelogs.Error("NODE", "onReload: copy HTTPFirewallPolicies failed: "+err.Error())
}
remotelogs.Println("NODE", "[TIMEZONE]change time zone to '"+timeZone+"'")
time.Local = location
this.timezone = timeZone
// update
waf.SharedWAFManager.UpdatePolicies(allFirewallPolicies)
}
// product information
if config.ProductConfig != nil {
teaconst.GlobalProductName = config.ProductConfig.Name
}
if reloadAll {
if !jsonutils.Equal(config.FirewallActions, this.oldFirewallActions) {
// copy
this.oldFirewallActions = []*firewallconfigs.FirewallActionConfig{}
err := jsonutils.Copy(&this.oldFirewallActions, config.FirewallActions)
if err != nil {
remotelogs.Error("NODE", "onReload: copy FirewallActionConfigs failed: "+err.Error())
}
// DNS resolver
if config.DNSResolver != nil {
var err error
switch config.DNSResolver.Type {
case nodeconfigs.DNSResolverTypeGoNative:
err = os.Setenv("GODEBUG", "netdns=go")
case nodeconfigs.DNSResolverTypeCGO:
err = os.Setenv("GODEBUG", "netdns=cgo")
default:
// update
iplibrary.SharedActionManager.UpdateActions(config.FirewallActions)
}
// 统计指标
if !jsonutils.Equal(this.oldMetricItems, config.MetricItems) {
// copy
this.oldMetricItems = []*serverconfigs.MetricItemConfig{}
err := jsonutils.Copy(&this.oldMetricItems, config.MetricItems)
if err != nil {
remotelogs.Error("NODE", "onReload: copy MetricItemConfigs failed: "+err.Error())
}
// update
metrics.SharedManager.Update(config.MetricItems)
}
// max cpu
if config.MaxCPU != this.oldMaxCPU {
if config.MaxCPU > 0 && config.MaxCPU < int32(runtime.NumCPU()) {
runtime.GOMAXPROCS(int(config.MaxCPU))
remotelogs.Println("NODE", "[CPU]set max cpu to '"+types.String(config.MaxCPU)+"'")
} else {
var threads = runtime.NumCPU() * 4
runtime.GOMAXPROCS(threads)
remotelogs.Println("NODE", "[CPU]set max cpu to '"+types.String(threads)+"'")
}
this.oldMaxCPU = config.MaxCPU
}
// max threads
if config.MaxThreads != this.oldMaxThreads {
if config.MaxThreads > 0 {
debug.SetMaxThreads(config.MaxThreads)
remotelogs.Println("NODE", "[THREADS]set max threads to '"+types.String(config.MaxThreads)+"'")
} else {
debug.SetMaxThreads(nodeconfigs.DefaultMaxThreads)
remotelogs.Println("NODE", "[THREADS]set max threads to '"+types.String(nodeconfigs.DefaultMaxThreads)+"'")
}
this.oldMaxThreads = config.MaxThreads
}
// timezone
var timeZone = config.TimeZone
if len(timeZone) == 0 {
timeZone = "Asia/Shanghai"
}
if this.oldTimezone != timeZone {
location, err := time.LoadLocation(timeZone)
if err != nil {
remotelogs.Error("NODE", "[TIMEZONE]change time zone failed: "+err.Error())
return
}
remotelogs.Println("NODE", "[TIMEZONE]change time zone to '"+timeZone+"'")
time.Local = location
this.oldTimezone = timeZone
}
// product information
if config.ProductConfig != nil {
teaconst.GlobalProductName = config.ProductConfig.Name
}
// DNS resolver
if config.DNSResolver != nil {
var err error
switch config.DNSResolver.Type {
case nodeconfigs.DNSResolverTypeGoNative:
err = os.Setenv("GODEBUG", "netdns=go")
case nodeconfigs.DNSResolverTypeCGO:
err = os.Setenv("GODEBUG", "netdns=cgo")
default:
// 默认使用go原生
err = os.Setenv("GODEBUG", "netdns=go")
}
if err != nil {
remotelogs.Error("NODE", "[DNS_RESOLVER]set env failed: "+err.Error())
}
} else {
// 默认使用go原生
err = os.Setenv("GODEBUG", "netdns=go")
}
if err != nil {
remotelogs.Error("NODE", "[DNS_RESOLVER]set env failed: "+err.Error())
}
} else {
// 默认使用go原生
err := os.Setenv("GODEBUG", "netdns=go")
if err != nil {
remotelogs.Error("NODE", "[DNS_RESOLVER]set env failed: "+err.Error())
err := os.Setenv("GODEBUG", "netdns=go")
if err != nil {
remotelogs.Error("NODE", "[DNS_RESOLVER]set env failed: "+err.Error())
}
}
// API Node地址这里不限制是否为空因为在为空时仍然要有对应的处理
this.changeAPINodeAddrs(config.APINodeAddrs)
}
}
@@ -1047,7 +1223,7 @@ func (this *Node) reloadServer() {
}
}
this.onReload(newNodeConfig)
this.onReload(newNodeConfig, false)
err = sharedListenerManager.Start(newNodeConfig)
if err != nil {
@@ -1056,6 +1232,7 @@ func (this *Node) reloadServer() {
}
}
// 检查硬盘
func (this *Node) checkDisk() {
if runtime.GOOS != "linux" {
return
@@ -1076,3 +1253,69 @@ func (this *Node) checkDisk() {
}
}
}
// 检查API节点地址
func (this *Node) changeAPINodeAddrs(apiNodeAddrs []*serverconfigs.NetworkAddressConfig) {
var addrs = []string{}
for _, addr := range apiNodeAddrs {
err := addr.Init()
if err != nil {
remotelogs.Error("NODE", "changeAPINodeAddrs: validate api node address '"+configutils.QuoteIP(addr.Host)+":"+addr.PortRange+"' failed: "+err.Error())
} else {
addrs = append(addrs, addr.FullAddresses()...)
}
}
sort.Strings(addrs)
if utils.EqualStrings(this.lastAPINodeAddrs, addrs) {
return
}
this.lastAPINodeAddrs = addrs
config, err := configs.LoadAPIConfig()
if err != nil {
remotelogs.Error("NODE", "changeAPINodeAddrs: "+err.Error())
return
}
if config == nil {
return
}
var oldEndpoints = config.RPC.Endpoints
rpcClient, err := rpc.SharedRPC()
if err != nil {
return
}
if len(addrs) > 0 {
this.lastAPINodeVersion++
var v = this.lastAPINodeVersion
// 异步检测,防止阻塞
go func(v int64) {
// 测试新的API节点地址
if rpcClient.TestEndpoints(addrs) {
config.RPC.Endpoints = addrs
} else {
config.RPC.Endpoints = oldEndpoints
this.lastAPINodeAddrs = nil // 恢复为空,以便于下次更新重试
}
// 检查测试中间有无新的变更
if v != this.lastAPINodeVersion {
return
}
err = rpcClient.UpdateConfig(config)
if err != nil {
remotelogs.Error("NODE", "changeAPINodeAddrs: update rpc config failed: "+err.Error())
}
}(v)
return
}
err = rpcClient.UpdateConfig(config)
if err != nil {
remotelogs.Error("NODE", "changeAPINodeAddrs: update rpc config failed: "+err.Error())
}
}

View File

@@ -31,12 +31,15 @@ type NodeStatusExecutor struct {
cpuLogicalCount int
cpuPhysicalCount int
apiCallStat *rpc.CallStat
ticker *time.Ticker
}
func NewNodeStatusExecutor() *NodeStatusExecutor {
return &NodeStatusExecutor{
ticker: time.NewTicker(30 * time.Second),
ticker: time.NewTicker(30 * time.Second),
apiCallStat: rpc.NewCallStat(10),
}
}
@@ -78,6 +81,11 @@ func (this *NodeStatusExecutor) update() {
status.CacheTotalMemorySize = caches.SharedManager.TotalMemorySize()
status.TrafficInBytes = teaconst.InTrafficBytes
status.TrafficOutBytes = teaconst.OutTrafficBytes
apiSuccessPercent, apiAvgCostSeconds := this.apiCallStat.Sum()
status.APISuccessPercent = apiSuccessPercent
status.APIAvgCostSeconds = apiAvgCostSeconds
var localFirewall = firewalls.Firewall()
if localFirewall != nil && !localFirewall.IsMock() {
status.LocalFirewallName = localFirewall.Name()
@@ -125,9 +133,13 @@ func (this *NodeStatusExecutor) update() {
remotelogs.Error("NODE_STATUS", "failed to open rpc: "+err.Error())
return
}
var before = time.Now()
_, err = rpcClient.NodeRPC.UpdateNodeStatus(rpcClient.Context(), &pb.UpdateNodeStatusRequest{
StatusJSON: jsonData,
})
var costSeconds = time.Since(before).Seconds()
this.apiCallStat.Add(err == nil, costSeconds)
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Warn("NODE_STATUS", "rpc UpdateNodeStatus() failed: "+err.Error())
@@ -140,7 +152,7 @@ func (this *NodeStatusExecutor) update() {
// 更新CPU
func (this *NodeStatusExecutor) updateCPU(status *nodeconfigs.NodeStatus) {
duration := time.Duration(0)
var duration = time.Duration(0)
if this.isFirstTime {
duration = 100 * time.Millisecond
}
@@ -195,8 +207,8 @@ func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
})
// 当前TeaWeb所在的fs
rootFS := ""
rootTotal := uint64(0)
var rootFS = ""
var rootTotal = uint64(0)
if lists.ContainsString([]string{"darwin", "linux", "freebsd"}, runtime.GOOS) {
for _, p := range partitions {
if p.Mountpoint == "/" {
@@ -210,9 +222,9 @@ func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
}
}
total := rootTotal
totalUsage := uint64(0)
maxUsage := float64(0)
var total = rootTotal
var totalUsage = uint64(0)
var maxUsage = float64(0)
for _, partition := range partitions {
if runtime.GOOS != "windows" && !strings.Contains(partition.Device, "/") && !strings.Contains(partition.Device, "\\") {
continue
@@ -252,7 +264,7 @@ func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
// 缓存空间
func (this *NodeStatusExecutor) updateCacheSpace(status *nodeconfigs.NodeStatus) {
var result = []maps.Map{}
cachePaths := caches.SharedManager.FindAllCachePaths()
var cachePaths = caches.SharedManager.FindAllCachePaths()
for _, path := range cachePaths {
var stat unix.Statfs_t
err := unix.Statfs(path, &stat)

View File

@@ -1,5 +1,4 @@
//go:build !windows
// +build !windows
package nodes

View File

@@ -1,23 +1,15 @@
package nodes
import (
"context"
"crypto/tls"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/configs"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/trackers"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/logs"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"net/url"
"sort"
"strings"
"sync"
"time"
)
@@ -64,6 +56,9 @@ func (this *SyncAPINodesTask) Stop() {
}
func (this *SyncAPINodesTask) Loop() error {
// 如果有节点定制的API节点地址
var hasCustomizedAPINodeAddrs = sharedNodeConfig != nil && len(sharedNodeConfig.APINodeAddrs) > 0
config, err := configs.LoadAPIConfig()
if err != nil {
return err
@@ -96,21 +91,25 @@ func (this *SyncAPINodesTask) Loop() error {
}
// 和现有的对比
if this.isSame(newEndpoints, config.RPC.Endpoints) {
if utils.EqualStrings(newEndpoints, config.RPC.Endpoints) {
return nil
}
// 测试是否有API节点可用
var hasOk = this.testEndpoints(newEndpoints)
var hasOk = rpcClient.TestEndpoints(newEndpoints)
if !hasOk {
return nil
}
// 修改RPC对象配置
config.RPC.Endpoints = newEndpoints
err = rpcClient.UpdateConfig(config)
if err != nil {
return err
// 更新当前RPC
if !hasCustomizedAPINodeAddrs {
err = rpcClient.UpdateConfig(config)
if err != nil {
return err
}
}
// 保存到文件
@@ -121,53 +120,3 @@ func (this *SyncAPINodesTask) Loop() error {
return nil
}
func (this *SyncAPINodesTask) isSame(endpoints1 []string, endpoints2 []string) bool {
sort.Strings(endpoints1)
sort.Strings(endpoints2)
return strings.Join(endpoints1, "&") == strings.Join(endpoints2, "&")
}
func (this *SyncAPINodesTask) testEndpoints(endpoints []string) bool {
if len(endpoints) == 0 {
return false
}
var wg = sync.WaitGroup{}
wg.Add(len(endpoints))
var ok = false
for _, endpoint := range endpoints {
go func(endpoint string) {
defer wg.Done()
u, err := url.Parse(endpoint)
if err != nil {
return
}
ctx, cancelFunc := context.WithTimeout(context.Background(), 5*time.Second)
defer func() {
cancelFunc()
}()
var conn *grpc.ClientConn
if u.Scheme == "http" {
conn, err = grpc.DialContext(ctx, u.Host, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock())
} else if u.Scheme == "https" {
conn, err = grpc.DialContext(ctx, u.Host, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
InsecureSkipVerify: true,
})), grpc.WithBlock())
}
if err != nil {
return
}
_ = conn.Close()
ok = true
}(endpoint)
}
wg.Wait()
return ok
}

View File

@@ -0,0 +1,53 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"sync"
)
var SharedUserManager = NewUserManager()
type User struct {
ServersEnabled bool
}
type UserManager struct {
userMap map[int64]*User // id => *User
locker sync.RWMutex
}
func NewUserManager() *UserManager {
return &UserManager{
userMap: map[int64]*User{},
}
}
func (this *UserManager) UpdateUserServersIsEnabled(userId int64, isEnabled bool) {
this.locker.Lock()
u, ok := this.userMap[userId]
if ok {
u.ServersEnabled = isEnabled
} else {
u = &User{ServersEnabled: isEnabled}
this.userMap[userId] = u
}
this.locker.Unlock()
}
func (this *UserManager) CheckUserServersIsEnabled(userId int64) (isEnabled bool) {
if userId <= 0 {
return true
}
this.locker.RLock()
u, ok := this.userMap[userId]
if ok {
isEnabled = u.ServersEnabled
} else {
isEnabled = true
}
this.locker.RUnlock()
return
}

View File

@@ -57,6 +57,79 @@ func TestRegexp_ParseKeywords(t *testing.T) {
}
}
func TestRegexp_Special(t *testing.T) {
var unescape = func(v string) string {
//replace urlencoded characters
var chars = [][2]string{
{`\s`, `(\s|%09|%0A|\+)`},
{`\(`, `(\(|%28)`},
{`=`, `(=|%3D)`},
{`<`, `(<|%3C)`},
{`\*`, `(\*|%2A)`},
{`\\`, `(\\|%2F)`},
{`!`, `(!|%21)`},
{`/`, `(/|%2F)`},
{`;`, `(;|%3B)`},
{`\+`, `(\+|%20)`},
}
for _, c := range chars {
if !strings.Contains(v, c[0]) {
continue
}
var pieces = strings.Split(v, c[0])
// 修复piece中错误的\
for pieceIndex, piece := range pieces {
var l = len(piece)
if l == 0 {
continue
}
if piece[l-1] != '\\' {
continue
}
// 计算\的数量
var countBackSlashes = 0
for i := l - 1; i >= 0; i-- {
if piece[i] == '\\' {
countBackSlashes++
} else {
break
}
}
if countBackSlashes%2 == 1 {
// 去掉最后一个
pieces[pieceIndex] = piece[:len(piece)-1]
}
}
v = strings.Join(pieces, c[1])
}
return v
}
for _, s := range []string{
`\\s`,
`\s\W`,
`aaaa/\W`,
`aaaa\/\W`,
`aaaa\=\W`,
`aaaa\\=\W`,
`aaaa\\\=\W`,
`aaaa\\\\=\W`,
} {
var es = unescape(s)
t.Log(s, "=>", es)
_, err := re.Compile(es)
if err != nil {
t.Fatal(err)
}
}
}
func TestRegexp_ParseKeywords2(t *testing.T) {
var a = assert.NewAssertion(t)

View File

@@ -17,7 +17,7 @@ import (
"time"
)
var logChan = make(chan *pb.NodeLog, 1024)
var logChan = make(chan *pb.NodeLog, 64) // 队列数量不需要太长,因为日志通常仅仅为调试用
func init() {
// 定期上传日志

71
internal/rpc/call_stat.go Normal file
View File

@@ -0,0 +1,71 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package rpc
import (
"sync"
)
type callStatItem struct {
ok bool
costSeconds float64
}
type CallStat struct {
size int
items []*callStatItem
locker sync.Mutex
}
func NewCallStat(size int) *CallStat {
return &CallStat{
size: size,
}
}
func (this *CallStat) Add(ok bool, costSeconds float64) {
var size = this.size
if size <= 0 {
size = 10
}
this.locker.Lock()
this.items = append(this.items, &callStatItem{
ok: ok,
costSeconds: costSeconds,
})
if len(this.items) > size {
this.items = this.items[1:]
}
this.locker.Unlock()
}
func (this *CallStat) Sum() (successPercent float64, avgCostSeconds float64) {
this.locker.Lock()
defer this.locker.Unlock()
var size = this.size
if size <= 0 {
size = 10
}
var totalItems = len(this.items)
if totalItems <= size/2 /** 低于一半的采样率,不计入统计 **/ {
successPercent = 100
return
}
var totalOkItems = 0
var totalCostSeconds float64
for _, item := range this.items {
if item.ok {
totalOkItems++
}
totalCostSeconds += item.costSeconds
}
successPercent = float64(totalOkItems) * 100 / float64(totalItems)
avgCostSeconds = totalCostSeconds / float64(totalItems)
return
}

View File

@@ -0,0 +1,19 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package rpc_test
import (
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"testing"
)
func TestNewCallStat(t *testing.T) {
var stat = rpc.NewCallStat(10)
stat.Add(true, 1)
stat.Add(true, 2)
stat.Add(true, 3)
stat.Add(false, 4)
stat.Add(true, 0)
stat.Add(true, 1)
t.Log(stat.Sum())
}

View File

@@ -49,6 +49,8 @@ type RPCClient struct {
FirewallRPC pb.FirewallServiceClient
SSLCertRPC pb.SSLCertServiceClient
ScriptRPC pb.ScriptServiceClient
UserRPC pb.UserServiceClient
ClientAgentIPRPC pb.ClientAgentIPServiceClient
}
func NewRPCClient(apiConfig *configs.APIConfig) (*RPCClient, error) {
@@ -81,6 +83,8 @@ func NewRPCClient(apiConfig *configs.APIConfig) (*RPCClient, error) {
client.FirewallRPC = pb.NewFirewallServiceClient(client)
client.SSLCertRPC = pb.NewSSLCertServiceClient(client)
client.ScriptRPC = pb.NewScriptServiceClient(client)
client.UserRPC = pb.NewUserServiceClient(client)
client.ClientAgentIPRPC = pb.NewClientAgentIPServiceClient(client)
err := client.init()
if err != nil {
@@ -92,7 +96,6 @@ func NewRPCClient(apiConfig *configs.APIConfig) (*RPCClient, error) {
// Context 节点上下文信息
func (this *RPCClient) Context() context.Context {
var ctx = context.Background()
var m = maps.Map{
"timestamp": time.Now().Unix(),
"type": "node",
@@ -109,6 +112,8 @@ func (this *RPCClient) Context() context.Context {
return context.Background()
}
var token = base64.StdEncoding.EncodeToString(data)
var ctx = context.Background()
ctx = metadata.AppendToOutgoingContext(ctx, "nodeId", this.apiConfig.NodeId, "token", token)
return ctx
}
@@ -157,6 +162,64 @@ func (this *RPCClient) UpdateConfig(config *configs.APIConfig) error {
return err
}
// TestEndpoints 测试Endpoints是否可用
func (this *RPCClient) TestEndpoints(endpoints []string) bool {
if len(endpoints) == 0 {
return false
}
var wg = sync.WaitGroup{}
wg.Add(len(endpoints))
var ok = false
for _, endpoint := range endpoints {
go func(endpoint string) {
defer wg.Done()
u, err := url.Parse(endpoint)
if err != nil {
return
}
ctx, cancelFunc := context.WithTimeout(context.Background(), 5*time.Second)
defer func() {
cancelFunc()
}()
var conn *grpc.ClientConn
if u.Scheme == "http" {
conn, err = grpc.DialContext(ctx, u.Host, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock())
} else if u.Scheme == "https" {
conn, err = grpc.DialContext(ctx, u.Host, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
InsecureSkipVerify: true,
})), grpc.WithBlock())
} else {
return
}
if err != nil {
return
}
if conn == nil {
return
}
defer func() {
_ = conn.Close()
}()
var pingService = pb.NewPingServiceClient(conn)
_, err = pingService.Ping(this.Context(), &pb.PingRequest{})
if err != nil {
return
}
ok = true
}(endpoint)
}
wg.Wait()
return ok
}
// 初始化
func (this *RPCClient) init() error {
// 重新连接
@@ -201,12 +264,18 @@ func (this *RPCClient) pickConn() *grpc.ClientConn {
defer this.locker.Unlock()
// 检查连接状态
if len(this.conns) > 0 {
var availableConns = []*grpc.ClientConn{}
var countConns = len(this.conns)
if countConns > 0 {
if countConns == 1 {
return this.conns[0]
}
for _, stateArray := range [][2]connectivity.State{
{connectivity.Ready, connectivity.Idle}, // 优先Ready和Idle
{connectivity.Connecting, connectivity.Connecting},
{connectivity.TransientFailure, connectivity.TransientFailure},
} {
var availableConns = []*grpc.ClientConn{}
for _, conn := range this.conns {
var state = conn.GetState()
if state == stateArray[0] || state == stateArray[1] {
@@ -217,26 +286,6 @@ func (this *RPCClient) pickConn() *grpc.ClientConn {
return this.randConn(availableConns)
}
}
if len(availableConns) > 0 {
return this.randConn(availableConns)
}
// 关闭
for _, conn := range this.conns {
_ = conn.Close()
}
}
// 重新初始化
err := this.init()
if err != nil {
// 错误提示已经在构造对象时打印过,所以这里不再重复打印
return nil
}
if len(this.conns) == 0 {
return nil
}
return this.randConn(this.conns)
@@ -245,14 +294,15 @@ func (this *RPCClient) pickConn() *grpc.ClientConn {
func (this *RPCClient) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error {
var conn = this.pickConn()
if conn == nil {
return errors.New("can not get available grpc connection")
return errors.New("could not get available grpc connection")
}
return conn.Invoke(ctx, method, args, reply, opts...)
}
func (this *RPCClient) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
var conn = this.pickConn()
if conn == nil {
return nil, errors.New("can not get available grpc connection")
return nil, errors.New("could not get available grpc connection")
}
return conn.NewStream(ctx, desc, method, opts...)
}

62
internal/rpc/rpc_test.go Normal file
View File

@@ -0,0 +1,62 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package rpc_test
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
_ "github.com/iwind/TeaGo/bootstrap"
timeutil "github.com/iwind/TeaGo/utils/time"
"sync"
"testing"
"time"
)
func TestRPCConcurrentCall(t *testing.T) {
rpcClient, err := rpc.SharedRPC()
if err != nil {
t.Fatal(err)
}
var before = time.Now()
defer func() {
t.Log("cost:", time.Since(before).Seconds()*1000, "ms")
}()
var concurrent = 3
var wg = sync.WaitGroup{}
wg.Add(concurrent)
for i := 0; i < concurrent; i++ {
go func() {
defer wg.Done()
_, err = rpcClient.NodeRPC.FindCurrentNodeConfig(rpcClient.Context(), &pb.FindCurrentNodeConfigRequest{})
if err != nil {
t.Log(err)
}
}()
}
wg.Wait()
}
func TestRPC_Retry(t *testing.T) {
rpcClient, err := rpc.SharedRPC()
if err != nil {
t.Fatal(err)
}
var ticker = time.NewTicker(1 * time.Second)
for range ticker.C {
go func() {
_, err = rpcClient.NodeRPC.FindCurrentNodeConfig(rpcClient.Context(), &pb.FindCurrentNodeConfigRequest{})
if err != nil {
t.Log(timeutil.Format("H:i:s"), err)
} else {
t.Log(timeutil.Format("H:i:s"), "success")
}
}()
}
}

View File

@@ -3,6 +3,7 @@
package stats
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
@@ -42,6 +43,8 @@ type BandwidthStat struct {
type BandwidthStatManager struct {
m map[string]*BandwidthStat // key => *BandwidthStat
pbStats []*pb.ServerBandwidthStat
lastTime string // 上一次执行的时间
ticker *time.Ticker
@@ -65,6 +68,12 @@ func (this *BandwidthStatManager) Start() {
}
func (this *BandwidthStatManager) Loop() error {
var regionId int64
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig != nil {
regionId = nodeConfig.RegionId
}
var now = time.Now()
var day = timeutil.Format("Ymd", now)
var currentTime = timeutil.FormatTime("Hi", now.Unix()/300*300)
@@ -76,16 +85,29 @@ func (this *BandwidthStatManager) Loop() error {
var pbStats = []*pb.ServerBandwidthStat{}
// 历史未提交记录
if len(this.pbStats) > 0 {
var expiredTime = timeutil.FormatTime("Hi", time.Now().Unix()-1200) // 只保留20分钟
for _, stat := range this.pbStats {
if stat.TimeAt > expiredTime {
pbStats = append(pbStats, stat)
}
}
this.pbStats = nil
}
this.locker.Lock()
for key, stat := range this.m {
if stat.Day < day || stat.TimeAt < currentTime {
pbStats = append(pbStats, &pb.ServerBandwidthStat{
Id: 0,
UserId: stat.UserId,
ServerId: stat.ServerId,
Day: stat.Day,
TimeAt: stat.TimeAt,
Bytes: stat.MaxBytes / bandwidthTimestampDelim,
Id: 0,
UserId: stat.UserId,
ServerId: stat.ServerId,
Day: stat.Day,
TimeAt: stat.TimeAt,
Bytes: stat.MaxBytes / bandwidthTimestampDelim,
NodeRegionId: regionId,
})
delete(this.m, key)
}
@@ -100,6 +122,8 @@ func (this *BandwidthStatManager) Loop() error {
}
_, err = rpcClient.ServerBandwidthStatRPC.UploadServerBandwidthStats(rpcClient.Context(), &pb.UploadServerBandwidthStatsRequest{ServerBandwidthStats: pbStats})
if err != nil {
this.pbStats = pbStats
return err
}
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/trackers"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/agents"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/maps"
@@ -146,11 +147,16 @@ func (this *HTTPRequestStatManager) AddRemoteAddr(serverId int64, remoteAddr str
}
// AddUserAgent 添加UserAgent
func (this *HTTPRequestStatManager) AddUserAgent(serverId int64, userAgent string) {
func (this *HTTPRequestStatManager) AddUserAgent(serverId int64, userAgent string, ip string) {
if len(userAgent) == 0 {
return
}
// 是否包含一些知名Agent
if len(userAgent) > 0 && len(ip) > 0 && agents.IsAgentFromUserAgent(userAgent) {
agents.SharedQueue.Push(ip)
}
select {
case this.userAgentChan <- strconv.FormatInt(serverId, 10) + "@" + userAgent:
default:

View File

@@ -37,11 +37,11 @@ func TestHTTPRequestStatManager_Loop_Region(t *testing.T) {
func TestHTTPRequestStatManager_Loop_UserAgent(t *testing.T) {
var manager = NewHTTPRequestStatManager()
manager.AddUserAgent(1, "Mozilla/5.0 (Macintosh; Intel Mac OS X 11_1_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36")
manager.AddUserAgent(1, "Mozilla/5.0 (Macintosh; Intel Mac OS X 11_1_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36")
manager.AddUserAgent(1, "Mozilla/5.0 (Macintosh; Intel Mac OS X 11) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/76 Safari/537.36")
manager.AddUserAgent(1, "Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36")
manager.AddUserAgent(1, "Mozilla/5.0 (Windows NT 6.1; WOW64; Trident/7.0; rv:11.0) like Gecko")
manager.AddUserAgent(1, "Mozilla/5.0 (Macintosh; Intel Mac OS X 11_1_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36", "")
manager.AddUserAgent(1, "Mozilla/5.0 (Macintosh; Intel Mac OS X 11_1_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36", "")
manager.AddUserAgent(1, "Mozilla/5.0 (Macintosh; Intel Mac OS X 11) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/76 Safari/537.36", "")
manager.AddUserAgent(1, "Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36", "")
manager.AddUserAgent(1, "Mozilla/5.0 (Windows NT 6.1; WOW64; Trident/7.0; rv:11.0) like Gecko", "")
err := manager.Loop()
if err != nil {
t.Fatal(err)

View File

@@ -31,19 +31,33 @@ type TrafficItem struct {
CheckingTrafficLimit bool
}
func (this *TrafficItem) Add(anotherItem *TrafficItem) {
this.Bytes += anotherItem.Bytes
this.CachedBytes += anotherItem.CachedBytes
this.CountRequests += anotherItem.CountRequests
this.CountCachedRequests += anotherItem.CountCachedRequests
this.CountAttackRequests += anotherItem.CountAttackRequests
this.AttackBytes += anotherItem.AttackBytes
}
const trafficStatsMaxLife = 1200 // 最大只保存20分钟内的数据
// TrafficStatManager 区域流量统计
type TrafficStatManager struct {
itemMap map[string]*TrafficItem // [timestamp serverId] => *TrafficItem
domainsMap map[string]*TrafficItem // timestamp @ serverId @ domain => *TrafficItem
locker sync.Mutex
configFunc func() *nodeconfigs.NodeConfig
pbItems []*pb.ServerDailyStat
pbDomainItems []*pb.UploadServerDailyStatsRequest_DomainStat
locker sync.Mutex
totalRequests int64
}
// NewTrafficStatManager 获取新对象
func NewTrafficStatManager() *TrafficStatManager {
manager := &TrafficStatManager{
var manager = &TrafficStatManager{
itemMap: map[string]*TrafficItem{},
domainsMap: map[string]*TrafficItem{},
}
@@ -52,9 +66,7 @@ func NewTrafficStatManager() *TrafficStatManager {
}
// Start 启动自动任务
func (this *TrafficStatManager) Start(configFunc func() *nodeconfigs.NodeConfig) {
this.configFunc = configFunc
func (this *TrafficStatManager) Start() {
// 上传请求总数
var monitorTicker = time.NewTicker(1 * time.Minute)
events.OnKey(events.EventQuit, this, func() {
@@ -70,7 +82,7 @@ func (this *TrafficStatManager) Start(configFunc func() *nodeconfigs.NodeConfig)
})
// 上传统计数据
duration := 5 * time.Minute
var duration = 5 * time.Minute
if Tea.IsTesting() {
// 测试环境缩短上传时间,方便我们调试
duration = 30 * time.Second
@@ -143,9 +155,10 @@ func (this *TrafficStatManager) Add(serverId int64, domain string, bytes int64,
// Upload 上传流量
func (this *TrafficStatManager) Upload() error {
var config = this.configFunc()
if config == nil {
return nil
var regionId int64
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig != nil {
regionId = nodeConfig.RegionId
}
client, err := rpc.SharedRPC()
@@ -154,10 +167,14 @@ func (this *TrafficStatManager) Upload() error {
}
this.locker.Lock()
var itemMap = this.itemMap
var domainMap = this.domainsMap
// reset
this.itemMap = map[string]*TrafficItem{}
this.domainsMap = map[string]*TrafficItem{}
this.locker.Unlock()
// 服务统计
@@ -174,7 +191,7 @@ func (this *TrafficStatManager) Upload() error {
pbServerStats = append(pbServerStats, &pb.ServerDailyStat{
ServerId: serverId,
RegionId: config.RegionId,
NodeRegionId: regionId,
Bytes: item.Bytes,
CachedBytes: item.CachedBytes,
CountRequests: item.CountRequests,
@@ -186,9 +203,6 @@ func (this *TrafficStatManager) Upload() error {
CreatedAt: timestamp,
})
}
if len(pbServerStats) == 0 {
return nil
}
// 域名统计
var pbDomainStats = []*pb.UploadServerDailyStatsRequest_DomainStat{}
@@ -210,9 +224,40 @@ func (this *TrafficStatManager) Upload() error {
})
}
// 历史未提交记录
if len(this.pbItems) > 0 || len(this.pbDomainItems) > 0 {
var expiredAt = time.Now().Unix() - 1200 // 只保留20分钟
for _, item := range this.pbItems {
if item.CreatedAt > expiredAt {
pbServerStats = append(pbServerStats, item)
}
}
this.pbItems = nil
for _, item := range this.pbDomainItems {
if item.CreatedAt > expiredAt {
pbDomainStats = append(pbDomainStats, item)
}
}
this.pbDomainItems = nil
}
if len(pbServerStats) == 0 && len(pbDomainStats) == 0 {
return nil
}
_, err = client.ServerDailyStatRPC.UploadServerDailyStats(client.Context(), &pb.UploadServerDailyStatsRequest{
Stats: pbServerStats,
DomainStats: pbDomainStats,
})
return err
if err != nil {
// 加回历史记录
this.pbItems = pbServerStats
this.pbDomainItems = pbDomainStats
return err
}
return nil
}

View File

@@ -73,6 +73,15 @@ func (this *UserAgentParser) Parse(userAgent string) (result UserAgentParserResu
result.BrowserName, result.BrowserVersion = this.parser.Browser()
result.IsMobile = this.parser.Mobile()
// 忽略特殊字符
if len(result.BrowserName) > 0 {
for _, r := range result.BrowserName {
if r == '$' || r == '"' || r == '\'' || r == '<' || r == '>' || r == ')' {
return
}
}
}
if this.cacheCursor == 0 {
this.cacheMap1[userAgent] = result
if len(this.cacheMap1) >= this.maxCacheItems {

View File

@@ -0,0 +1,39 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package agents
import (
"regexp"
"strings"
)
type Agent struct {
Code string
Keywords []string // user agent keywords
suffixes []string // PTR suffixes
reg *regexp.Regexp
}
func NewAgent(code string, suffixes []string, reg *regexp.Regexp, keywords []string) *Agent {
return &Agent{
Code: code,
suffixes: suffixes,
reg: reg,
Keywords: keywords,
}
}
func (this *Agent) Match(ptr string) bool {
if len(this.suffixes) > 0 {
for _, suffix := range this.suffixes {
if strings.HasSuffix(ptr, suffix) {
return true
}
}
}
if this.reg != nil {
return this.reg.MatchString(ptr)
}
return false
}

View File

@@ -0,0 +1,9 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package agents
type AgentIP struct {
Id int64
IP string
AgentCode string
}

View File

@@ -0,0 +1,31 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package agents
import "strings"
var AllAgents = []*Agent{
NewAgent("baidu", []string{".baidu.com."}, nil, []string{"Baidu"}),
NewAgent("google", []string{".googlebot.com."}, nil, []string{"Google"}),
NewAgent("bing", []string{".search.msn.com."}, nil, []string{"bingbot"}),
NewAgent("sogou", []string{".sogou.com."}, nil, []string{"Sogou"}),
NewAgent("youdao", []string{".163.com."}, nil, []string{"Youdao"}),
NewAgent("yahoo", []string{".yahoo.com."}, nil, []string{"Yahoo"}),
NewAgent("bytedance", []string{".bytedance.com."}, nil, []string{"Bytespider"}),
NewAgent("sm", []string{".sm.cn."}, nil, []string{"YisouSpider"}),
NewAgent("yandex", []string{".yandex.com.", ".yndx.net."}, nil, []string{"Yandex"}),
NewAgent("semrush", []string{".semrush.com."}, nil, []string{"SEMrush"}),
}
func IsAgentFromUserAgent(userAgent string) bool {
for _, agent := range AllAgents {
if len(agent.Keywords) > 0 {
for _, keyword := range agent.Keywords {
if strings.Contains(userAgent, keyword) {
return true
}
}
}
}
return false
}

View File

@@ -0,0 +1,19 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package agents_test
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/agents"
"testing"
)
func TestIsAgentFromUserAgent(t *testing.T) {
t.Log(agents.IsAgentFromUserAgent("Mozilla/5.0 (Linux;u;Android 4.2.2;zh-cn;) AppleWebKit/534.46 (KHTML,like Gecko) Version/5.1 Mobile Safari/10600.6.3 (compatible; Baiduspider/2.0; +http://www.baidu.com/search/spider.html)"))
t.Log(agents.IsAgentFromUserAgent("Mozilla/5.0 (Linux;u;Android 4.2.2;zh-cn;)"))
}
func BenchmarkIsAgentFromUserAgent(b *testing.B) {
for i := 0; i < b.N; i++ {
agents.IsAgentFromUserAgent("Mozilla/5.0 (Linux;u;Android 4.2.2;zh-cn;) AppleWebKit/534.46 (KHTML,like Gecko) Version/5.1 Mobile Safari/10600.6.3 (compatible; Yaho)")
}
}

156
internal/utils/agents/db.go Normal file
View File

@@ -0,0 +1,156 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package agents
import (
"database/sql"
"errors"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/types"
_ "github.com/mattn/go-sqlite3"
"log"
"os"
"path/filepath"
)
const (
tableAgentIPs = "agentIPs"
)
type DB struct {
db *sql.DB
path string
insertAgentIPStmt *sql.Stmt
listAgentIPsStmt *sql.Stmt
}
func NewDB(path string) *DB {
var db = &DB{path: path}
events.On(events.EventQuit, func() {
_ = db.Close()
})
return db
}
func (this *DB) Init() error {
// 检查目录是否存在
var dir = filepath.Dir(this.path)
_, err := os.Stat(dir)
if err != nil {
err = os.MkdirAll(dir, 0777)
if err != nil {
return err
}
remotelogs.Println("DB", "create database dir '"+dir+"'")
}
// TODO 思考 data.db 的数据安全性
db, err := sql.Open("sqlite3", "file:"+this.path+"?cache=shared&mode=rwc&_journal_mode=WAL")
if err != nil {
return err
}
db.SetMaxOpenConns(1)
/**_, err = db.Exec("VACUUM")
if err != nil {
return err
}**/
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + tableAgentIPs + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"ip" varchar(64),
"agentCode" varchar(128)
);`)
if err != nil {
return err
}
// 预编译语句
// agent ip record statements
this.insertAgentIPStmt, err = db.Prepare(`INSERT INTO "` + tableAgentIPs + `" ("id", "ip", "agentCode") VALUES (?, ?, ?)`)
if err != nil {
return err
}
this.listAgentIPsStmt, err = db.Prepare(`SELECT "id", "ip", "agentCode" FROM "` + tableAgentIPs + `" ORDER BY "id" ASC LIMIT ? OFFSET ?`)
if err != nil {
return err
}
this.db = db
return nil
}
func (this *DB) InsertAgentIP(ipId int64, ip string, agentCode string) error {
if this.db == nil {
return errors.New("db should not be nil")
}
this.log("InsertAgentIP", "id:", ipId, "ip:", ip, "agent:", agentCode)
_, err := this.insertAgentIPStmt.Exec(ipId, ip, agentCode)
if err != nil {
return err
}
return nil
}
func (this *DB) ListAgentIPs(offset int64, size int64) (agentIPs []*AgentIP, err error) {
if this.db == nil {
return nil, errors.New("db should not be nil")
}
rows, err := this.listAgentIPsStmt.Query(size, offset)
if err != nil {
return nil, err
}
defer func() {
_ = rows.Close()
}()
for rows.Next() {
var agentIP = &AgentIP{}
err = rows.Scan(&agentIP.Id, &agentIP.IP, &agentIP.AgentCode)
if err != nil {
return nil, err
}
agentIPs = append(agentIPs, agentIP)
}
return
}
func (this *DB) Close() error {
if this.db == nil {
return nil
}
for _, stmt := range []*sql.Stmt{
this.insertAgentIPStmt,
this.listAgentIPsStmt,
} {
if stmt != nil {
_ = stmt.Close()
}
}
return this.db.Close()
}
// 打印日志
func (this *DB) log(args ...any) {
if !Tea.IsTesting() {
return
}
if len(args) == 0 {
return
}
args[0] = "[" + types.String(args[0]) + "]"
log.Println(args...)
}

View File

@@ -0,0 +1,54 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package agents
import (
"github.com/TeaOSLab/EdgeNode/internal/zero"
"sync"
)
type IPCacheMap struct {
m map[string]zero.Zero
list []string
locker sync.RWMutex
maxLen int
}
func NewIPCacheMap(maxLen int) *IPCacheMap {
if maxLen <= 0 {
maxLen = 65535
}
return &IPCacheMap{
m: map[string]zero.Zero{},
maxLen: maxLen,
}
}
func (this *IPCacheMap) Add(ip string) {
this.locker.Lock()
defer this.locker.Unlock()
// 是否已经存在
_, ok := this.m[ip]
if ok {
return
}
// 超出长度删除第一个
if len(this.list) >= this.maxLen {
delete(this.m, this.list[0])
this.list = this.list[1:]
}
// 加入新数据
this.m[ip] = zero.Zero{}
this.list = append(this.list, ip)
}
func (this *IPCacheMap) Contains(ip string) bool {
this.locker.RLock()
defer this.locker.RUnlock()
_, ok := this.m[ip]
return ok
}

View File

@@ -0,0 +1,33 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package agents
import (
"github.com/iwind/TeaGo/logs"
"testing"
)
func TestNewIPCacheMap(t *testing.T) {
var cacheMap = NewIPCacheMap(3)
t.Log("====")
cacheMap.Add("1")
cacheMap.Add("2")
logs.PrintAsJSON(cacheMap.m, t)
logs.PrintAsJSON(cacheMap.list, t)
t.Log("====")
cacheMap.Add("3")
logs.PrintAsJSON(cacheMap.m, t)
logs.PrintAsJSON(cacheMap.list, t)
t.Log("====")
cacheMap.Add("4")
logs.PrintAsJSON(cacheMap.m, t)
logs.PrintAsJSON(cacheMap.list, t)
t.Log("====")
cacheMap.Add("3")
logs.PrintAsJSON(cacheMap.m, t)
logs.PrintAsJSON(cacheMap.list, t)
}

View File

@@ -0,0 +1,200 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package agents
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/iwind/TeaGo/Tea"
"sync"
"time"
)
var SharedManager = NewManager()
func init() {
events.On(events.EventLoaded, func() {
goman.New(func() {
SharedManager.Start()
})
})
}
// Manager Agent管理器
type Manager struct {
ipMap map[string]string // ip => agentCode
locker sync.RWMutex
db *DB
lastId int64
}
func NewManager() *Manager {
return &Manager{
ipMap: map[string]string{},
}
}
func (this *Manager) SetDB(db *DB) {
this.db = db
}
func (this *Manager) Start() {
remotelogs.Println("AGENT_MANAGER", "starting ...")
err := this.loadDB()
if err != nil {
remotelogs.Error("AGENT_MANAGER", "load database failed: "+err.Error())
return
}
// 从本地数据库中加载
err = this.Load()
if err != nil {
remotelogs.Error("AGENT_MANAGER", "load failed: "+err.Error())
}
// 先从API获取
err = this.LoopAll()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("AGENT_MANAGER", "retrieve latest agent ip failed: "+err.Error())
} else {
remotelogs.Error("AGENT_MANAGER", "retrieve latest agent ip failed: "+err.Error())
}
}
// 定时获取
var duration = 30 * time.Minute
if Tea.IsTesting() {
duration = 30 * time.Second
}
var ticker = time.NewTicker(duration)
for range ticker.C {
err = this.LoopAll()
if err != nil {
remotelogs.Error("AGENT_MANAGER", "retrieve latest agent ip failed: "+err.Error())
}
}
}
func (this *Manager) Load() error {
var offset int64 = 0
var size int64 = 10000
for {
agentIPs, err := this.db.ListAgentIPs(offset, size)
if err != nil {
return err
}
if len(agentIPs) == 0 {
break
}
for _, agentIP := range agentIPs {
this.locker.Lock()
this.ipMap[agentIP.IP] = agentIP.AgentCode
this.locker.Unlock()
if agentIP.Id > this.lastId {
this.lastId = agentIP.Id
}
}
offset += size
}
return nil
}
func (this *Manager) LoopAll() error {
for {
hasNext, err := this.Loop()
if err != nil {
return err
}
if !hasNext {
break
}
}
return nil
}
// Loop 单次循环获取数据
func (this *Manager) Loop() (hasNext bool, err error) {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return false, err
}
ipsResp, err := rpcClient.ClientAgentIPRPC.ListClientAgentIPsAfterId(rpcClient.Context(), &pb.ListClientAgentIPsAfterIdRequest{
Id: this.lastId,
Size: 10000,
})
if err != nil {
return false, err
}
if len(ipsResp.ClientAgentIPs) == 0 {
return false, nil
}
for _, agentIP := range ipsResp.ClientAgentIPs {
if agentIP.ClientAgent == nil {
// 设置ID
if agentIP.Id > this.lastId {
this.lastId = agentIP.Id
}
continue
}
// 写入到数据库
err = this.db.InsertAgentIP(agentIP.Id, agentIP.Ip, agentIP.ClientAgent.Code)
if err != nil {
return false, err
}
// 写入Map
this.locker.Lock()
this.ipMap[agentIP.Ip] = agentIP.ClientAgent.Code
this.locker.Unlock()
// 设置ID
if agentIP.Id > this.lastId {
this.lastId = agentIP.Id
}
}
return true, nil
}
// AddIP 添加记录
func (this *Manager) AddIP(ip string, agentCode string) {
this.locker.Lock()
this.ipMap[ip] = agentCode
this.locker.Unlock()
}
// LookupIP 查询IP所属Agent
func (this *Manager) LookupIP(ip string) (agentCode string) {
this.locker.RLock()
defer this.locker.RUnlock()
return this.ipMap[ip]
}
// ContainsIP 检查是否有IP相关数据
func (this *Manager) ContainsIP(ip string) bool {
this.locker.RLock()
defer this.locker.RUnlock()
_, ok := this.ipMap[ip]
return ok
}
func (this *Manager) loadDB() error {
var db = NewDB(Tea.Root + "/data/agents.db")
err := db.Init()
if err != nil {
return err
}
this.db = db
return nil
}

View File

@@ -0,0 +1,32 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package agents_test
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/agents"
"github.com/iwind/TeaGo/Tea"
_ "github.com/iwind/TeaGo/bootstrap"
"testing"
)
func TestNewManager(t *testing.T) {
var db = agents.NewDB(Tea.Root + "/data/agents.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
var manager = agents.NewManager()
manager.SetDB(db)
err = manager.Load()
if err != nil {
t.Fatal(err)
}
_, err = manager.Loop()
if err != nil {
t.Fatal(err)
}
t.Log(manager.LookupIP("192.168.3.100"))
}

View File

@@ -0,0 +1,133 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package agents
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/iwind/TeaGo/Tea"
"net"
)
func init() {
events.On(events.EventLoaded, func() {
goman.New(func() {
SharedQueue.Start()
})
})
}
var SharedQueue = NewQueue()
type Queue struct {
c chan string // chan ip
cacheMap *IPCacheMap
}
func NewQueue() *Queue {
return &Queue{
c: make(chan string, 128),
cacheMap: NewIPCacheMap(65535),
}
}
func (this *Queue) Start() {
for ip := range this.c {
err := this.Process(ip)
if err != nil {
// 不需要上报错误
if Tea.IsTesting() {
remotelogs.Debug("SharedParseQueue", err.Error())
}
continue
}
}
}
// Push 将IP加入到处理队列
func (this *Queue) Push(ip string) {
// 是否在处理中
if this.cacheMap.Contains(ip) {
return
}
this.cacheMap.Add(ip)
// 加入到队列
select {
case this.c <- ip:
default:
}
}
// Process 处理IP
func (this *Queue) Process(ip string) error {
// 是否已经在库中
if SharedManager.ContainsIP(ip) {
return nil
}
ptr, err := this.ParseIP(ip)
if err != nil {
return err
}
if len(ptr) == 0 || ptr == "." {
return nil
}
//remotelogs.Debug("AGENT", ip+" => "+ptr)
var agentCode = this.ParsePtr(ptr)
if len(agentCode) == 0 {
return nil
}
// 加入到本地
SharedManager.AddIP(ip, agentCode)
var pbAgentIP = &pb.CreateClientAgentIPsRequest_AgentIPInfo{
AgentCode: agentCode,
Ip: ip,
Ptr: ptr,
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
_, err = rpcClient.ClientAgentIPRPC.CreateClientAgentIPs(rpcClient.Context(), &pb.CreateClientAgentIPsRequest{AgentIPs: []*pb.CreateClientAgentIPsRequest_AgentIPInfo{pbAgentIP}})
if err != nil {
return err
}
return nil
}
// ParseIP 分析IP的PTR值
func (this *Queue) ParseIP(ip string) (ptr string, err error) {
if len(ip) == 0 {
return "", nil
}
names, err := net.LookupAddr(ip)
if err != nil {
return "", err
}
if len(names) == 0 {
return "", nil
}
return names[0], nil
}
// ParsePtr 分析PTR对应的Agent
func (this *Queue) ParsePtr(ptr string) (agentCode string) {
for _, agent := range AllAgents {
if agent.Match(ptr) {
return agent.Code
}
}
return ""
}

View File

@@ -0,0 +1,76 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package agents_test
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/agents"
"github.com/iwind/TeaGo/assert"
_ "github.com/iwind/TeaGo/bootstrap"
"testing"
"time"
)
func TestParseQueue_Process(t *testing.T) {
var queue = agents.NewQueue()
go queue.Start()
time.Sleep(1 * time.Second)
queue.Push("220.181.13.100")
time.Sleep(1 * time.Second)
}
func TestParseQueue_ParseIP(t *testing.T) {
var queue = agents.NewQueue()
for _, ip := range []string{
"192.168.1.100",
"42.120.160.1",
"42.236.10.98",
"124.115.0.100",
} {
ptr, err := queue.ParseIP(ip)
if err != nil {
t.Log(ip, "=>", err)
continue
}
t.Log(ip, "=>", ptr)
}
}
func TestParseQueue_ParsePtr(t *testing.T) {
var a = assert.NewAssertion(t)
var queue = agents.NewQueue()
for _, s := range [][]string{
{"baiduspider-220-181-108-101.crawl.baidu.com.", "baidu"},
{"crawl-66-249-71-219.googlebot.com.", "google"},
{"msnbot-40-77-167-31.search.msn.com.", "bing"},
{"sogouspider-49-7-20-129.crawl.sogou.com.", "sogou"},
{"m13102.mail.163.com.", "youdao"},
{"yeurosport.pat1.tc2.yahoo.com.", "yahoo"},
{"shenmaspider-42-120-160-1.crawl.sm.cn.", "sm"},
{"93-158-161-39.spider.yandex.com.", "yandex"},
{"25.bl.bot.semrush.com.", "semrush"},
} {
a.IsTrue(queue.ParsePtr(s[0]) == s[1])
}
}
func BenchmarkQueue_ParsePtr(b *testing.B) {
var queue = agents.NewQueue()
for i := 0; i < b.N; i++ {
for _, s := range [][]string{
{"baiduspider-220-181-108-101.crawl.baidu.com.", "baidu"},
{"crawl-66-249-71-219.googlebot.com.", "google"},
{"msnbot-40-77-167-31.search.msn.com.", "bing"},
{"sogouspider-49-7-20-129.crawl.sogou.com.", "sogou"},
{"m13102.mail.163.com.", "youdao"},
{"yeurosport.pat1.tc2.yahoo.com.", "yahoo"},
{"shenmaspider-42-120-160-1.crawl.sm.cn.", "sm"},
{"93-158-161-39.spider.yandex.com.", "yandex"},
{"93.158.164.218-red.dhcp.yndx.net.", "yandex"},
{"25.bl.bot.semrush.com.", "semrush"},
} {
queue.ParsePtr(s[0])
}
}
}

View File

@@ -2,47 +2,39 @@
package utils
import "bytes"
import (
"bytes"
"sync"
)
var SharedBufferPool = NewBufferPool()
// BufferPool pool for get byte slice
type BufferPool struct {
c chan *bytes.Buffer
rawPool *sync.Pool
}
// NewBufferPool 创建新对象
func NewBufferPool(maxSize int) *BufferPool {
if maxSize <= 0 {
maxSize = 1024
}
pool := &BufferPool{
c: make(chan *bytes.Buffer, maxSize),
func NewBufferPool() *BufferPool {
var pool = &BufferPool{}
pool.rawPool = &sync.Pool{
New: func() any {
return &bytes.Buffer{}
},
}
return pool
}
// Get 获取一个新的Buffer
func (this *BufferPool) Get() (b *bytes.Buffer) {
select {
case b = <-this.c:
b.Reset()
default:
b = &bytes.Buffer{}
var buffer = this.rawPool.Get().(*bytes.Buffer)
if buffer.Len() > 0 {
buffer.Reset()
}
return
return buffer
}
// Put 放回一个使用过的byte slice
func (this *BufferPool) Put(b *bytes.Buffer) {
b.Reset()
select {
case this.c <- b:
default:
// 已达最大容量,则抛弃
}
}
// Size 当前的数量
func (this *BufferPool) Size() int {
return len(this.c)
this.rawPool.Put(b)
}

View File

@@ -0,0 +1,47 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package utils_test
import (
"bytes"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"strings"
"testing"
)
func TestNewBufferPool(t *testing.T) {
var pool = utils.NewBufferPool()
var b = pool.Get()
b.WriteString("Hello, World")
t.Log(b.String())
pool.Put(b)
t.Log(b.String())
b = pool.Get()
t.Log(b.String())
}
func BenchmarkNewBufferPool1(b *testing.B) {
var data = []byte(strings.Repeat("Hello", 1024))
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
var buffer = &bytes.Buffer{}
buffer.Write(data)
}
})
}
func BenchmarkNewBufferPool2(b *testing.B) {
var pool = utils.NewBufferPool()
var data = []byte(strings.Repeat("Hello", 1024))
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
var buffer = pool.Get()
buffer.Write(data)
pool.Put(buffer)
}
})
}

View File

@@ -39,6 +39,7 @@ func init() {
}
type ClockManager struct {
lastFailAt int64
}
func NewClockManager() *ClockManager {
@@ -51,7 +52,13 @@ func (this *ClockManager) Start() {
for range ticker.C {
err := this.Sync()
if err != nil {
remotelogs.Warn("CLOCK", "sync clock failed: "+err.Error())
var currentTimestamp = time.Now().Unix()
// 每天只提醒一次错误
if currentTimestamp-this.lastFailAt > 86400 {
remotelogs.Warn("CLOCK", "sync clock failed: "+err.Error())
this.lastFailAt = currentTimestamp
}
}
}
}
@@ -72,6 +79,18 @@ func (this *ClockManager) Sync() error {
return nil
}
// check chrony
if config.CheckChrony {
chronycExe, err := exec.LookPath("chronyc")
if err == nil && len(chronycExe) > 0 {
var chronyCmd = executils.NewTimeoutCmd(3*time.Second, chronycExe, "tracking")
err = chronyCmd.Run()
if err == nil {
return nil
}
}
}
var server = config.Server
if len(server) == 0 {
server = "pool.ntp.org"
@@ -118,7 +137,7 @@ func (this *ClockManager) syncNtpdate(ntpdate string, server string) error {
return nil
}
// 参考自https://medium.com/learning-the-go-programming-language/lets-make-an-ntp-client-in-go-287c4b9a969f
// ReadServer 参考自https://medium.com/learning-the-go-programming-language/lets-make-an-ntp-client-in-go-287c4b9a969f
func (this *ClockManager) ReadServer(server string) (time.Time, error) {
conn, err := net.Dial("udp", server+":123")
if err != nil {

View File

@@ -6,7 +6,7 @@ import (
"regexp"
)
var RegexpDigitNumber = regexp.MustCompile("^\\d+$")
var RegexpDigitNumber = regexp.MustCompile(`^\d+$`)
func Get(object interface{}, keys []string) interface{} {
if len(keys) == 0 {

View File

@@ -7,7 +7,7 @@ import (
"github.com/iwind/TeaGo/maps"
)
func MapToObject(m maps.Map, ptr interface{}) error {
func MapToObject(m maps.Map, ptr any) error {
if m == nil {
return nil
}
@@ -18,7 +18,7 @@ func MapToObject(m maps.Map, ptr interface{}) error {
return json.Unmarshal(mJSON, ptr)
}
func ObjectToMap(ptr interface{}) (maps.Map, error) {
func ObjectToMap(ptr any) (maps.Map, error) {
if ptr == nil {
return maps.Map{}, nil
}
@@ -33,3 +33,12 @@ func ObjectToMap(ptr interface{}) (maps.Map, error) {
}
return result, nil
}
func Copy(destPtr any, srcPtr any) error {
data, err := json.Marshal(srcPtr)
if err != nil {
return err
}
err = json.Unmarshal(data, destPtr)
return err
}

View File

@@ -3,11 +3,12 @@
package jsonutils
import (
"bytes"
"encoding/json"
"testing"
)
func PrintT(obj interface{}, t *testing.T) {
func PrintT(obj any, t *testing.T) {
data, err := json.MarshalIndent(obj, "", " ")
if err != nil {
t.Log(err)
@@ -15,3 +16,17 @@ func PrintT(obj interface{}, t *testing.T) {
t.Log(string(data))
}
}
func Equal(obj1 any, obj2 any) bool {
data1, err := json.Marshal(obj1)
if err != nil {
return false
}
data2, err := json.Marshal(obj2)
if err != nil {
return false
}
return bytes.Equal(data1, data2)
}

View File

@@ -0,0 +1,26 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package jsonutils_test
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/jsonutils"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/maps"
"testing"
)
func TestEqual(t *testing.T) {
var a = assert.NewAssertion(t)
{
var m1 = maps.Map{"a": 1, "b2": true}
var m2 = maps.Map{"b2": true, "a": 1}
a.IsTrue(jsonutils.Equal(m1, m2))
}
{
var m1 = maps.Map{"a": 1, "b2": true, "c": nil}
var m2 = maps.Map{"b2": true, "a": 1}
a.IsFalse(jsonutils.Equal(m1, m2))
}
}

View File

@@ -0,0 +1,135 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package readers
import (
"errors"
"github.com/iwind/TeaGo/types"
"io"
"sync"
)
type concurrentSubReader struct {
main *ConcurrentReaderList
index int
}
func (this *concurrentSubReader) Read(p []byte) (n int, err error) {
n, err = this.main.readIndex(p, this.index)
this.index++
return
}
func (this *concurrentSubReader) Close() error {
this.main.removeSubReader(this)
err := this.main.Close()
if err != nil {
return err
}
return nil
}
// ConcurrentReaderList
// TODO 动态调整 pieces = pieces[minPieceIndex:] 以节约内存
type ConcurrentReaderList struct {
locker sync.RWMutex
readLocker sync.Mutex
mainReader io.ReadCloser
subReaderMap map[*concurrentSubReader]bool
pieces [][]byte
lastErr error
}
func NewConcurrentReaderList(mainReader io.ReadCloser) *ConcurrentReaderList {
return &ConcurrentReaderList{
mainReader: mainReader,
subReaderMap: map[*concurrentSubReader]bool{},
}
}
func (this *ConcurrentReaderList) NewReader() io.ReadCloser {
var subReader = &concurrentSubReader{
main: this,
}
this.locker.Lock()
this.subReaderMap[subReader] = true
this.locker.Unlock()
return subReader
}
func (this *ConcurrentReaderList) read(p []byte) (n int, err error) {
n, err = this.mainReader.Read(p)
this.lastErr = err
if n > 0 {
var piece = make([]byte, n)
copy(piece, p[:n])
this.locker.Lock()
this.pieces = append(this.pieces, piece)
this.locker.Unlock()
}
return
}
func (this *ConcurrentReaderList) readIndex(p []byte, index int) (n int, err error) {
// 如果已经有数据
this.locker.RLock()
var countPieces = len(this.pieces)
if index < countPieces {
var piece = this.pieces[index]
this.locker.RUnlock()
var pn = len(piece)
if len(p) < pn {
err = errors.New("invalid buffer length '" + types.String(len(p)) + "' vs '" + types.String(len(piece)) + "'")
return
}
n = pn
copy(p, piece)
return
}
this.locker.RUnlock()
if this.lastErr != nil {
return 0, this.lastErr
}
// 如果没有数据,则读取之
this.readLocker.Lock()
// 再次检查数据是否已更新
this.locker.RLock()
if len(this.pieces) > countPieces || this.lastErr != nil {
this.locker.RUnlock()
this.readLocker.Unlock()
return this.readIndex(p, index)
}
this.locker.RUnlock()
// 从原始Reader中读取
n, err = this.read(p)
this.readLocker.Unlock()
if n > 0 {
// 重新尝试
return this.readIndex(p, index)
}
return
}
func (this *ConcurrentReaderList) removeSubReader(subReader *concurrentSubReader) {
this.locker.Lock()
delete(this.subReaderMap, subReader)
this.locker.Unlock()
}
func (this *ConcurrentReaderList) Close() error {
this.locker.Lock()
if len(this.subReaderMap) == 0 {
this.locker.Unlock()
return this.mainReader.Close()
}
this.locker.Unlock()
return nil
}

View File

@@ -0,0 +1,81 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package readers_test
import (
"bytes"
"github.com/TeaOSLab/EdgeNode/internal/utils/readers"
"io"
"sync"
"testing"
"time"
)
type testReader struct {
t *testing.T
rawReader io.Reader
}
func (this *testReader) Read(p []byte) (n int, err error) {
time.Sleep(1 * time.Second) // 延迟
return this.rawReader.Read(p)
}
func (this *testReader) Close() error {
this.t.Log("close")
return nil
}
func TestNewConcurrentReader(t *testing.T) {
var originBuffer = &bytes.Buffer{}
originBuffer.Write([]byte("0123456789_hello_world"))
var originLength = originBuffer.Len()
var concurrentReader = readers.NewConcurrentReaderList(&testReader{
t: t,
rawReader: originBuffer,
})
var threads = 32
var wg = &sync.WaitGroup{}
wg.Add(threads)
var locker = &sync.Mutex{}
var m = map[int][]byte{} // i => []byte
for i := 0; i < threads; i++ {
go func(i int) {
defer wg.Done()
var reader = concurrentReader.NewReader()
var buf = make([]byte, 4)
for {
n, err := reader.Read(buf)
if n > 0 {
locker.Lock()
m[i] = append(m[i], buf[:n]...)
locker.Unlock()
//t.Log(i, string(buf[:n]))
}
if err != nil {
if err == io.EOF {
break
}
t.Log("ERROR:", err)
}
}
_ = reader.Close()
}(i)
}
wg.Wait()
for i, b := range m {
if len(b) != originLength {
t.Fatal("ERROR:", i, string(b))
}
t.Log(i, string(b))
}
}

View File

@@ -42,8 +42,8 @@ func ToValidUTF8string(v string) string {
return strings.ToValidUTF8(v, "")
}
// ContainsSameStrings 检查两个字符串slice内容是否一致
func ContainsSameStrings(s1 []string, s2 []string) bool {
// EqualStrings 检查两个字符串slice内容是否一致
func EqualStrings(s1 []string, s2 []string) bool {
if len(s1) != len(s2) {
return false
}

View File

@@ -59,9 +59,9 @@ func TestFormatAddressList(t *testing.T) {
func TestContainsSameStrings(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsFalse(utils.ContainsSameStrings([]string{"a"}, []string{"b"}))
a.IsFalse(utils.ContainsSameStrings([]string{"a", "b"}, []string{"b"}))
a.IsFalse(utils.ContainsSameStrings([]string{"a", "b"}, []string{"a", "b", "c"}))
a.IsTrue(utils.ContainsSameStrings([]string{"a", "b"}, []string{"a", "b"}))
a.IsTrue(utils.ContainsSameStrings([]string{"a", "b"}, []string{"b", "a"}))
a.IsFalse(utils.EqualStrings([]string{"a"}, []string{"b"}))
a.IsFalse(utils.EqualStrings([]string{"a", "b"}, []string{"b"}))
a.IsFalse(utils.EqualStrings([]string{"a", "b"}, []string{"a", "b", "c"}))
a.IsTrue(utils.EqualStrings([]string{"a", "b"}, []string{"a", "b"}))
a.IsTrue(utils.EqualStrings([]string{"a", "b"}, []string{"b", "a"}))
}

View File

@@ -3,6 +3,7 @@
package writers
import (
"context"
"github.com/iwind/TeaGo/types"
"io"
"time"
@@ -11,6 +12,7 @@ import (
// RateLimitWriter 限速写入
type RateLimitWriter struct {
rawWriter io.WriteCloser
ctx context.Context
rateBytes int
@@ -18,9 +20,10 @@ type RateLimitWriter struct {
before time.Time
}
func NewRateLimitWriter(rawWriter io.WriteCloser, rateBytes int64) io.WriteCloser {
func NewRateLimitWriter(ctx context.Context, rawWriter io.WriteCloser, rateBytes int64) io.WriteCloser {
return &RateLimitWriter{
rawWriter: rawWriter,
ctx: ctx,
rateBytes: types.Int(rateBytes),
before: time.Now(),
}
@@ -71,6 +74,14 @@ func (this *RateLimitWriter) write(p []byte) (n int, err error) {
n, err = this.rawWriter.Write(p)
if err == nil {
select {
case <-this.ctx.Done():
err = io.EOF
return
default:
}
this.written += n
if this.written >= this.rateBytes {

Some files were not shown because too many files have changed in this diff Show More