Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7763f26249 | ||
|
|
b7ae10e2d0 | ||
|
|
e54eddc961 | ||
|
|
93db9d4926 | ||
|
|
8c1af3e699 | ||
|
|
53c74553bc | ||
|
|
3eb9cade0e | ||
|
|
eeee3da941 | ||
|
|
6af59e0bd0 | ||
|
|
012233baf2 | ||
|
|
ac069fd7f3 | ||
|
|
40cb1916c2 | ||
|
|
749e0bd0b3 | ||
|
|
ae1a9abf5e | ||
|
|
7fc0394f10 | ||
|
|
4a169a2dbd | ||
|
|
44d8afeda8 | ||
|
|
6a0547abec | ||
|
|
6d002e2822 | ||
|
|
04271d77c2 | ||
|
|
5a6ead1dd7 | ||
|
|
9ac7b9b2c0 | ||
|
|
7ec916c1fb | ||
|
|
fadc580dff | ||
|
|
fb9e9fb94b | ||
|
|
9c6e4bb8c1 | ||
|
|
97b04777bc | ||
|
|
4daeca912a | ||
|
|
7e43324b53 | ||
|
|
b9b8472c3a | ||
|
|
6858380bb4 | ||
|
|
568ecadfc6 | ||
|
|
8210ece2b7 | ||
|
|
f8160e35b9 | ||
|
|
9e4a1212d2 | ||
|
|
6f52cffabd | ||
|
|
c546b9fc7d | ||
|
|
f7b961d256 | ||
|
|
71cbb2d695 | ||
|
|
2063015eeb | ||
|
|
87cc43b2e0 | ||
|
|
9812883b61 | ||
|
|
068c20e1b9 | ||
|
|
17d883a2de | ||
|
|
083bbb1460 | ||
|
|
4f7b9f4fc6 |
@@ -1,4 +1,4 @@
|
||||
rpc:
|
||||
endpoints: [ ${endpoints} ]
|
||||
nodeId: "${nodeId}"
|
||||
secret: "${nodeSecret}"
|
||||
endpoints: [ "" ]
|
||||
nodeId: ""
|
||||
secret: ""
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"sort"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -60,6 +61,33 @@ func main() {
|
||||
node := nodes.NewNode()
|
||||
node.Start()
|
||||
})
|
||||
app.On("trackers", func() {
|
||||
var sock = gosock.NewTmpSock(teaconst.ProcessName)
|
||||
reply, err := sock.Send(&gosock.Command{Code: "trackers"})
|
||||
if err != nil {
|
||||
fmt.Println("[ERROR]" + err.Error())
|
||||
} else {
|
||||
labelsMap, ok := reply.Params["labels"]
|
||||
if ok {
|
||||
labels, ok := labelsMap.(map[string]interface{})
|
||||
if ok {
|
||||
if len(labels) == 0 {
|
||||
fmt.Println("no labels yet")
|
||||
} else {
|
||||
var labelNames = []string{}
|
||||
for label := range labels {
|
||||
labelNames = append(labelNames, label)
|
||||
}
|
||||
sort.Strings(labelNames)
|
||||
|
||||
for _, labelName := range labelNames {
|
||||
fmt.Println(labelName + ": " + fmt.Sprintf("%.6f", labels[labelName]))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
app.Run(func() {
|
||||
node := nodes.NewNode()
|
||||
node.Start()
|
||||
|
||||
6
go.mod
6
go.mod
@@ -16,14 +16,16 @@ require (
|
||||
github.com/go-ole/go-ole v1.2.4 // indirect
|
||||
github.com/go-yaml/yaml v2.1.0+incompatible
|
||||
github.com/golang/protobuf v1.5.2
|
||||
github.com/iwind/TeaGo v0.0.0-20210628135026-38575a4ab060
|
||||
github.com/iwind/TeaGo v0.0.0-20211026123858-7de7a21cad24
|
||||
github.com/iwind/gofcgi v0.0.0-20210528023741-a92711d45f11
|
||||
github.com/iwind/gosock v0.0.0-20210722083328-12b2d66abec3
|
||||
github.com/iwind/gowebp v0.0.0-20211029040624-7331ecc78ed8
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
|
||||
github.com/lionsoul2014/ip2region v2.2.0-release+incompatible
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible
|
||||
github.com/mattn/go-sqlite3 v1.14.9
|
||||
github.com/miekg/dns v1.1.43
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/mssola/user_agent v0.5.2
|
||||
github.com/pires/go-proxyproto v0.6.1
|
||||
github.com/shirou/gopsutil v3.21.5+incompatible
|
||||
|
||||
10
go.sum
10
go.sum
@@ -78,6 +78,8 @@ github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/iwind/TeaGo v0.0.0-20210628135026-38575a4ab060 h1:qdLtK4PDXxk2vMKkTWl5Fl9xqYuRCukzWAgJbLHdfOo=
|
||||
github.com/iwind/TeaGo v0.0.0-20210628135026-38575a4ab060/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc=
|
||||
github.com/iwind/TeaGo v0.0.0-20211026123858-7de7a21cad24 h1:1cGulkD2SNJJRok5OKwyhP/Ddm+PgSWKOupn0cR36/A=
|
||||
github.com/iwind/TeaGo v0.0.0-20211026123858-7de7a21cad24/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc=
|
||||
github.com/iwind/gofcgi v0.0.0-20210528023741-a92711d45f11 h1:DaQjoWZhLNxjhIXedVg4/vFEtHkZhK4IjIwsWdyzBLg=
|
||||
github.com/iwind/gofcgi v0.0.0-20210528023741-a92711d45f11/go.mod h1:JtbX20untAjUVjZs1ZBtq80f5rJWvwtQNRL6EnuYRnY=
|
||||
github.com/iwind/gosock v0.0.0-20210722083328-12b2d66abec3 h1:aBSonas7vFcgTj9u96/bWGILGv1ZbUSTLiOzcI1ZT6c=
|
||||
@@ -85,6 +87,8 @@ github.com/iwind/gosock v0.0.0-20210722083328-12b2d66abec3/go.mod h1:H5Q7SXwbx3a
|
||||
github.com/iwind/gowebp v0.0.0-20211029040624-7331ecc78ed8 h1:AojsHz9Es9B3He2MQQxeRq3TyD//o9huxUo7r1wh44g=
|
||||
github.com/iwind/gowebp v0.0.0-20211029040624-7331ecc78ed8/go.mod h1:QJBY2txYhLMzwLO29iB5ujDJ3s3V7DsZ582nw4Ss+tM=
|
||||
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e h1:LvL4XsI70QxOGHed6yhQtAU34Kx3Qq2wwBzGFKY8zKk=
|
||||
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
@@ -96,12 +100,18 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lionsoul2014/ip2region v2.2.0-release+incompatible h1:1qp9iks+69h7IGLazAplzS9Ca14HAxuD5c0rbFdPGy4=
|
||||
github.com/lionsoul2014/ip2region v2.2.0-release+incompatible/go.mod h1:+ZBN7PBoh5gG6/y0ZQ85vJDBe21WnfbRrQQwTfliJJI=
|
||||
github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA=
|
||||
github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/miekg/dns v1.1.43 h1:JKfpVSCB84vrAmHzyrsxB5NAr5kLoMXZArPSw7Qlgyg=
|
||||
github.com/miekg/dns v1.1.43/go.mod h1:+evo5L0630/F6ca/Z9+GAqzhjGyn8/c+TBaOyfEl0V4=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/mssola/user_agent v0.5.2 h1:CZkTUahjL1+OcZ5zv3kZr8QiJ8jy2H08vZIEkBeRbxo=
|
||||
github.com/mssola/user_agent v0.5.2/go.mod h1:TTPno8LPY3wAIEKRpAtkdMT0f8SE24pLRGPahjCH4uw=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
"github.com/iwind/TeaGo/utils/time"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type LogWriter struct {
|
||||
@@ -38,18 +38,20 @@ func (this *LogWriter) Init() {
|
||||
}
|
||||
|
||||
func (this *LogWriter) Write(message string) {
|
||||
// 文件和行号
|
||||
var callDepth = 2
|
||||
var file string
|
||||
var line int
|
||||
var ok bool
|
||||
_, file, line, ok = runtime.Caller(callDepth)
|
||||
if ok {
|
||||
file = filepath.Base(file)
|
||||
}
|
||||
|
||||
backgroundEnv, _ := os.LookupEnv("EdgeBackground")
|
||||
if backgroundEnv != "on" {
|
||||
// 文件和行号
|
||||
var file string
|
||||
var line int
|
||||
if Tea.IsTesting() {
|
||||
var callDepth = 3
|
||||
var ok bool
|
||||
_, file, line, ok = runtime.Caller(callDepth)
|
||||
if ok {
|
||||
file = this.packagePath(file)
|
||||
}
|
||||
}
|
||||
|
||||
if len(file) > 0 {
|
||||
log.Println(message + " (" + file + ":" + strconv.Itoa(line) + ")")
|
||||
} else {
|
||||
@@ -70,3 +72,11 @@ func (this *LogWriter) Close() {
|
||||
_ = this.fileAppender.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (this *LogWriter) packagePath(path string) string {
|
||||
var pieces = strings.Split(path, "/")
|
||||
if len(pieces) >= 2 {
|
||||
return strings.Join(pieces[len(pieces)-2:], "/")
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
10
internal/caches/hot_item.go
Normal file
10
internal/caches/hot_item.go
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package caches
|
||||
|
||||
type HotItem struct {
|
||||
Key string
|
||||
ExpiresAt int64
|
||||
Hits uint32
|
||||
Status int
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package caches
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ItemType = int
|
||||
@@ -11,6 +12,12 @@ const (
|
||||
ItemTypeMemory ItemType = 2
|
||||
)
|
||||
|
||||
// 计算当前周
|
||||
// 不要用YW,因为需要计算两周是否临近
|
||||
func currentWeek() int32 {
|
||||
return int32(time.Now().Unix() / 86400)
|
||||
}
|
||||
|
||||
type Item struct {
|
||||
Type ItemType `json:"type"`
|
||||
Key string `json:"key"`
|
||||
@@ -20,6 +27,10 @@ type Item struct {
|
||||
MetaSize int64 `json:"metaSize"`
|
||||
Host string `json:"host"` // 主机名
|
||||
ServerId int64 `json:"serverId"` // 服务ID
|
||||
|
||||
Week1Hits int64 `json:"week1Hits"`
|
||||
Week2Hits int64 `json:"week2Hits"`
|
||||
Week int32 `json:"week"`
|
||||
}
|
||||
|
||||
func (this *Item) IsExpired() bool {
|
||||
@@ -27,9 +38,23 @@ func (this *Item) IsExpired() bool {
|
||||
}
|
||||
|
||||
func (this *Item) TotalSize() int64 {
|
||||
return this.Size() + this.MetaSize + int64(len(this.Key)) + int64(len(this.Host)) + 64
|
||||
return this.Size() + this.MetaSize + int64(len(this.Key)) + int64(len(this.Host))
|
||||
}
|
||||
|
||||
func (this *Item) Size() int64 {
|
||||
return this.HeaderSize + this.BodySize
|
||||
}
|
||||
|
||||
func (this *Item) IncreaseHit(week int32) {
|
||||
if this.Week == week {
|
||||
this.Week2Hits++
|
||||
} else {
|
||||
if week-this.Week == 1 {
|
||||
this.Week1Hits = this.Week2Hits
|
||||
} else {
|
||||
this.Week1Hits = 0
|
||||
}
|
||||
this.Week2Hits = 1
|
||||
this.Week = week
|
||||
}
|
||||
}
|
||||
|
||||
82
internal/caches/item_test.go
Normal file
82
internal/caches/item_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package caches
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestItem_IncreaseHit(t *testing.T) {
|
||||
var week = currentWeek()
|
||||
|
||||
var item = &Item{}
|
||||
//item.Week = 2704
|
||||
item.Week2Hits = 100
|
||||
item.IncreaseHit(week)
|
||||
t.Log("week:", item.Week, "week1:", item.Week1Hits, "week2:", item.Week2Hits)
|
||||
|
||||
item.IncreaseHit(week)
|
||||
t.Log("week:", item.Week, "week1:", item.Week1Hits, "week2:", item.Week2Hits)
|
||||
}
|
||||
|
||||
func TestItems_Memory(t *testing.T) {
|
||||
var stat = &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stat)
|
||||
var memory1 = stat.HeapInuse
|
||||
|
||||
var items = []*Item{}
|
||||
for i := 0; i < 10_000_000; i++ {
|
||||
items = append(items, &Item{
|
||||
Key: types.String(i),
|
||||
})
|
||||
}
|
||||
|
||||
runtime.ReadMemStats(stat)
|
||||
var memory2 = stat.HeapInuse
|
||||
|
||||
t.Log(memory1, memory2, (memory2-memory1)/1024/1024, "M")
|
||||
|
||||
var weekItems = make(map[string]*Item, 10_000_000)
|
||||
|
||||
for _, item := range items {
|
||||
weekItems[item.Key] = item
|
||||
}
|
||||
|
||||
runtime.ReadMemStats(stat)
|
||||
var memory3 = stat.HeapInuse
|
||||
t.Log(memory2, memory3, (memory3-memory2)/1024/1024, "M")
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
t.Log(len(items), len(weekItems))
|
||||
}
|
||||
|
||||
func TestItems_Memory2(t *testing.T) {
|
||||
var stat = &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stat)
|
||||
var memory1 = stat.HeapInuse
|
||||
|
||||
var items = map[int32]map[string]bool{}
|
||||
for i := 0; i < 10_000_000; i++ {
|
||||
var week = int32((time.Now().Unix() - int64(86400*rands.Int(0, 300))) / (86400 * 7))
|
||||
m, ok := items[week]
|
||||
if !ok {
|
||||
m = map[string]bool{}
|
||||
items[week] = m
|
||||
}
|
||||
m[types.String(int64(i)*1_000_000)] = true
|
||||
}
|
||||
|
||||
runtime.ReadMemStats(stat)
|
||||
var memory2 = stat.HeapInuse
|
||||
|
||||
t.Log(memory1, memory2, (memory2-memory1)/1024/1024, "M")
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
for w, i := range items {
|
||||
t.Log(w, len(i))
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
timeutil "github.com/iwind/TeaGo/utils/time"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -24,6 +25,7 @@ type FileList struct {
|
||||
onAdd func(item *Item)
|
||||
onRemove func(item *Item)
|
||||
|
||||
// cacheItems
|
||||
existsByHashStmt *sql.Stmt // 根据hash检查是否存在
|
||||
insertStmt *sql.Stmt // 写入数据
|
||||
selectByHashStmt *sql.Stmt // 使用hash查询数据
|
||||
@@ -32,8 +34,15 @@ type FileList struct {
|
||||
purgeStmt *sql.Stmt // 清理
|
||||
deleteAllStmt *sql.Stmt // 删除所有数据
|
||||
|
||||
// hits
|
||||
insertHitStmt *sql.Stmt // 写入数据
|
||||
increaseHitStmt *sql.Stmt // 增加点击量
|
||||
deleteHitByHashStmt *sql.Stmt // 根据hash删除数据
|
||||
lfuHitsStmt *sql.Stmt // 读取老的数据
|
||||
|
||||
oldTables []string
|
||||
itemsTableName string
|
||||
hitsTableName string
|
||||
|
||||
isClosed bool
|
||||
|
||||
@@ -59,6 +68,7 @@ func (this *FileList) Init() error {
|
||||
}
|
||||
|
||||
this.itemsTableName = "cacheItems_v2"
|
||||
this.hitsTableName = "hits"
|
||||
|
||||
var dir = this.dir
|
||||
if dir == "/" {
|
||||
@@ -141,6 +151,23 @@ func (this *FileList) Init() error {
|
||||
return err
|
||||
}
|
||||
|
||||
this.insertHitStmt, err = this.db.Prepare(`INSERT INTO "` + this.hitsTableName + `" ("hash", "week2Hits", "week") VALUES (?, 1, ?)`)
|
||||
|
||||
this.increaseHitStmt, err = this.db.Prepare(`INSERT INTO "` + this.hitsTableName + `" ("hash", "week2Hits", "week") VALUES (?, 1, ?) ON CONFLICT("hash") DO UPDATE SET "week1Hits"=IIF("week"=?, "week1Hits", "week2Hits"), "week2Hits"=IIF("week"=?, "week2Hits"+1, 1), "week"=?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.deleteHitByHashStmt, err = this.db.Prepare(`DELETE FROM "` + this.hitsTableName + `" WHERE "hash"=?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.lfuHitsStmt, err = this.db.Prepare(`SELECT "hash" FROM "` + this.hitsTableName + `" ORDER BY "week" ASC, "week1Hits"+"week2Hits" ASC LIMIT ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -159,6 +186,11 @@ func (this *FileList) Add(hash string, item *Item) error {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = this.insertHitStmt.Exec(hash, timeutil.Format("YW"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
atomic.AddInt64(&this.total, 1)
|
||||
|
||||
if this.onAdd != nil {
|
||||
@@ -253,6 +285,11 @@ func (this *FileList) Remove(hash string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = this.deleteHitByHashStmt.Exec(hash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
atomic.AddInt64(&this.total, -1)
|
||||
|
||||
if this.onRemove != nil {
|
||||
@@ -265,9 +302,9 @@ func (this *FileList) Remove(hash string) error {
|
||||
// Purge 清理过期的缓存
|
||||
// count 每次遍历的最大数量,控制此数字可以保证每次清理的时候不用花太多时间
|
||||
// callback 每次发现过期key的调用
|
||||
func (this *FileList) Purge(count int, callback func(hash string) error) error {
|
||||
func (this *FileList) Purge(count int, callback func(hash string) error) (int, error) {
|
||||
if this.isClosed {
|
||||
return nil
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if count <= 0 {
|
||||
@@ -275,11 +312,56 @@ func (this *FileList) Purge(count int, callback func(hash string) error) error {
|
||||
}
|
||||
|
||||
rows, err := this.purgeStmt.Query(time.Now().Unix(), count)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
hashStrings := []string{}
|
||||
var countFound = 0
|
||||
for rows.Next() {
|
||||
var hash string
|
||||
err = rows.Scan(&hash)
|
||||
if err != nil {
|
||||
_ = rows.Close()
|
||||
return 0, err
|
||||
}
|
||||
hashStrings = append(hashStrings, hash)
|
||||
countFound++
|
||||
}
|
||||
_ = rows.Close() // 不能使用defer,防止读写冲突
|
||||
|
||||
// 不在 rows.Next() 循环中操作是为了避免死锁
|
||||
for _, hash := range hashStrings {
|
||||
err = this.Remove(hash)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = callback(hash)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return countFound, nil
|
||||
}
|
||||
|
||||
func (this *FileList) PurgeLFU(count int, callback func(hash string) error) error {
|
||||
if this.isClosed {
|
||||
return nil
|
||||
}
|
||||
|
||||
if count <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
rows, err := this.lfuHitsStmt.Query(count)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hashStrings := []string{}
|
||||
var countFound = 0
|
||||
for rows.Next() {
|
||||
var hash string
|
||||
err = rows.Scan(&hash)
|
||||
@@ -288,6 +370,7 @@ func (this *FileList) Purge(count int, callback func(hash string) error) error {
|
||||
return err
|
||||
}
|
||||
hashStrings = append(hashStrings, hash)
|
||||
countFound++
|
||||
}
|
||||
_ = rows.Close() // 不能使用defer,防止读写冲突
|
||||
|
||||
@@ -303,7 +386,6 @@ func (this *FileList) Purge(count int, callback func(hash string) error) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -347,6 +429,13 @@ func (this *FileList) Count() (int64, error) {
|
||||
return atomic.LoadInt64(&this.total), nil
|
||||
}
|
||||
|
||||
// IncreaseHit 增加点击量
|
||||
func (this *FileList) IncreaseHit(hash string) error {
|
||||
var week = timeutil.Format("YW")
|
||||
_, err := this.increaseHitStmt.Exec(hash, week, week, week, week)
|
||||
return err
|
||||
}
|
||||
|
||||
// OnAdd 添加事件
|
||||
func (this *FileList) OnAdd(f func(item *Item)) {
|
||||
this.onAdd = f
|
||||
@@ -371,6 +460,11 @@ func (this *FileList) Close() error {
|
||||
_ = this.purgeStmt.Close()
|
||||
_ = this.deleteAllStmt.Close()
|
||||
|
||||
_ = this.insertHitStmt.Close()
|
||||
_ = this.increaseHitStmt.Close()
|
||||
_ = this.deleteHitByHashStmt.Close()
|
||||
_ = this.lfuHitsStmt.Close()
|
||||
|
||||
return this.db.Close()
|
||||
}
|
||||
return nil
|
||||
@@ -378,7 +472,8 @@ func (this *FileList) Close() error {
|
||||
|
||||
// 初始化
|
||||
func (this *FileList) initTables(db *sql.DB, times int) error {
|
||||
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemsTableName + `" (
|
||||
{
|
||||
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemsTableName + `" (
|
||||
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"hash" varchar(32),
|
||||
"key" varchar(1024),
|
||||
@@ -411,17 +506,46 @@ ON "` + this.itemsTableName + `" (
|
||||
"serverId" ASC
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
// 尝试删除重建
|
||||
if times < 3 {
|
||||
_, dropErr := db.Exec(`DROP TABLE "` + this.itemsTableName + `"`)
|
||||
if dropErr == nil {
|
||||
return this.initTables(db, times+1)
|
||||
if err != nil {
|
||||
// 尝试删除重建
|
||||
if times < 3 {
|
||||
_, dropErr := db.Exec(`DROP TABLE "` + this.itemsTableName + `"`)
|
||||
if dropErr == nil {
|
||||
return this.initTables(db, times+1)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
{
|
||||
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.hitsTableName + `" (
|
||||
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"hash" varchar(32),
|
||||
"week1Hits" integer DEFAULT 0,
|
||||
"week2Hits" integer DEFAULT 0,
|
||||
"week" varchar(6)
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS "hits_hash"
|
||||
ON "` + this.hitsTableName + `" (
|
||||
"hash" ASC
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
// 尝试删除重建
|
||||
if times < 3 {
|
||||
_, dropErr := db.Exec(`DROP TABLE "` + this.hitsTableName + `"`)
|
||||
if dropErr == nil {
|
||||
return this.initTables(db, times+1)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -5,6 +5,7 @@ package caches
|
||||
import (
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"strconv"
|
||||
"sync"
|
||||
@@ -184,7 +185,7 @@ func TestFileList_Purge(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = list.Purge(2, func(hash string) error {
|
||||
_, err = list.Purge(2, func(hash string) error {
|
||||
t.Log(hash)
|
||||
return nil
|
||||
})
|
||||
@@ -257,6 +258,50 @@ func TestFileList_Conflict(t *testing.T) {
|
||||
t.Log("after exists")
|
||||
}
|
||||
|
||||
func TestFileList_IIF(t *testing.T) {
|
||||
list := NewFileList(Tea.Root + "/data").(*FileList)
|
||||
err := list.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := list.db.Query("SELECT IIF(0, 2, 3)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
if rows.Next() {
|
||||
var result int
|
||||
err = rows.Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("result:", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileList_IncreaseHit(t *testing.T) {
|
||||
list := NewFileList(Tea.Root + "/data")
|
||||
err := list.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var before = time.Now()
|
||||
defer func() {
|
||||
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||
}()
|
||||
for i := 0; i < 1000_000; i++ {
|
||||
err = list.IncreaseHit(stringutil.Md5("abc" + types.String(i)))
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func BenchmarkFileList_Exist(b *testing.B) {
|
||||
list := NewFileList(Tea.Root + "/data")
|
||||
err := list.Init()
|
||||
|
||||
@@ -22,7 +22,10 @@ type ListInterface interface {
|
||||
Remove(hash string) error
|
||||
|
||||
// Purge 清理过期数据
|
||||
Purge(count int, callback func(hash string) error) error
|
||||
Purge(count int, callback func(hash string) error) (int, error)
|
||||
|
||||
// PurgeLFU 清理LFU数据
|
||||
PurgeLFU(count int, callback func(hash string) error) error
|
||||
|
||||
// CleanAll 清除所有缓存
|
||||
CleanAll() error
|
||||
@@ -41,4 +44,7 @@ type ListInterface interface {
|
||||
|
||||
// Close 关闭
|
||||
Close() error
|
||||
|
||||
// IncreaseHit 增加点击量
|
||||
IncreaseHit(hash string) error
|
||||
}
|
||||
|
||||
@@ -5,12 +5,19 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// MemoryList 内存缓存列表管理
|
||||
type MemoryList struct {
|
||||
count int64
|
||||
|
||||
itemMaps map[string]map[string]*Item // prefix => { hash => item }
|
||||
|
||||
weekItemMaps map[int32]map[string]bool // week => { hash => true }
|
||||
minWeek int32
|
||||
|
||||
prefixes []string
|
||||
locker sync.RWMutex
|
||||
onAdd func(item *Item)
|
||||
@@ -21,7 +28,9 @@ type MemoryList struct {
|
||||
|
||||
func NewMemoryList() ListInterface {
|
||||
return &MemoryList{
|
||||
itemMaps: map[string]map[string]*Item{},
|
||||
itemMaps: map[string]map[string]*Item{},
|
||||
weekItemMaps: map[int32]map[string]bool{},
|
||||
minWeek: currentWeek(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,11 +52,19 @@ func (this *MemoryList) Reset() error {
|
||||
for key := range this.itemMaps {
|
||||
this.itemMaps[key] = map[string]*Item{}
|
||||
}
|
||||
this.weekItemMaps = map[int32]map[string]bool{}
|
||||
this.locker.Unlock()
|
||||
|
||||
atomic.StoreInt64(&this.count, 0)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *MemoryList) Add(hash string, item *Item) error {
|
||||
if item.Week == 0 {
|
||||
item.Week = currentWeek()
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
|
||||
prefix := this.prefix(hash)
|
||||
@@ -60,9 +77,20 @@ func (this *MemoryList) Add(hash string, item *Item) error {
|
||||
// 先删除,为了可以正确触发统计
|
||||
oldItem, ok := itemMap[hash]
|
||||
if ok {
|
||||
// 从week map中删除
|
||||
if oldItem.Week > 0 {
|
||||
wm, ok := this.weekItemMaps[oldItem.Week]
|
||||
if ok {
|
||||
delete(wm, hash)
|
||||
}
|
||||
}
|
||||
|
||||
// 回调
|
||||
if this.onRemove != nil {
|
||||
this.onRemove(oldItem)
|
||||
}
|
||||
} else {
|
||||
atomic.AddInt64(&this.count, 1)
|
||||
}
|
||||
|
||||
// 添加
|
||||
@@ -71,6 +99,15 @@ func (this *MemoryList) Add(hash string, item *Item) error {
|
||||
}
|
||||
|
||||
itemMap[hash] = item
|
||||
|
||||
// week map
|
||||
wm, ok := this.weekItemMaps[item.Week]
|
||||
if ok {
|
||||
wm[hash] = true
|
||||
} else {
|
||||
this.weekItemMaps[item.Week] = map[string]bool{hash: true}
|
||||
}
|
||||
|
||||
this.locker.Unlock()
|
||||
return nil
|
||||
}
|
||||
@@ -122,7 +159,17 @@ func (this *MemoryList) Remove(hash string) error {
|
||||
if this.onRemove != nil {
|
||||
this.onRemove(item)
|
||||
}
|
||||
|
||||
atomic.AddInt64(&this.count, -1)
|
||||
delete(itemMap, hash)
|
||||
|
||||
// week map
|
||||
if item.Week > 0 {
|
||||
wm, ok := this.weekItemMaps[item.Week]
|
||||
if ok {
|
||||
delete(wm, hash)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.locker.Unlock()
|
||||
@@ -132,7 +179,7 @@ func (this *MemoryList) Remove(hash string) error {
|
||||
// Purge 清理过期的缓存
|
||||
// count 每次遍历的最大数量,控制此数字可以保证每次清理的时候不用花太多时间
|
||||
// callback 每次发现过期key的调用
|
||||
func (this *MemoryList) Purge(count int, callback func(hash string) error) error {
|
||||
func (this *MemoryList) Purge(count int, callback func(hash string) error) (int, error) {
|
||||
this.locker.Lock()
|
||||
deletedHashList := []string{}
|
||||
|
||||
@@ -146,8 +193,9 @@ func (this *MemoryList) Purge(count int, callback func(hash string) error) error
|
||||
itemMap, ok := this.itemMaps[prefix]
|
||||
if !ok {
|
||||
this.locker.Unlock()
|
||||
return nil
|
||||
return 0, nil
|
||||
}
|
||||
var countFound = 0
|
||||
for hash, item := range itemMap {
|
||||
if count <= 0 {
|
||||
break
|
||||
@@ -157,14 +205,100 @@ func (this *MemoryList) Purge(count int, callback func(hash string) error) error
|
||||
if this.onRemove != nil {
|
||||
this.onRemove(item)
|
||||
}
|
||||
|
||||
atomic.AddInt64(&this.count, -1)
|
||||
delete(itemMap, hash)
|
||||
deletedHashList = append(deletedHashList, hash)
|
||||
|
||||
// week map
|
||||
if item.Week > 0 {
|
||||
wm, ok := this.weekItemMaps[item.Week]
|
||||
if ok {
|
||||
delete(wm, hash)
|
||||
}
|
||||
}
|
||||
|
||||
countFound++
|
||||
}
|
||||
|
||||
count--
|
||||
}
|
||||
this.locker.Unlock()
|
||||
|
||||
// 执行外部操作
|
||||
for _, hash := range deletedHashList {
|
||||
if callback != nil {
|
||||
err := callback(hash)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return countFound, nil
|
||||
}
|
||||
|
||||
func (this *MemoryList) PurgeLFU(count int, callback func(hash string) error) error {
|
||||
if count <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var week = currentWeek()
|
||||
if this.minWeek > week {
|
||||
this.minWeek = week
|
||||
}
|
||||
|
||||
var deletedHashList = []string{}
|
||||
|
||||
Loop:
|
||||
for w := this.minWeek; w <= week; w++ {
|
||||
this.minWeek = w
|
||||
|
||||
this.locker.Lock()
|
||||
wm, ok := this.weekItemMaps[w]
|
||||
if ok {
|
||||
var wc = len(wm)
|
||||
if wc == 0 {
|
||||
delete(this.weekItemMaps, w)
|
||||
} else {
|
||||
if wc <= count {
|
||||
delete(this.weekItemMaps, w)
|
||||
}
|
||||
|
||||
// TODO 未来支持按照点击量排序
|
||||
for hash := range wm {
|
||||
count--
|
||||
|
||||
if count < 0 {
|
||||
this.locker.Unlock()
|
||||
break Loop
|
||||
}
|
||||
|
||||
delete(wm, hash)
|
||||
|
||||
itemMap, ok := this.itemMaps[this.prefix(hash)]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
item, ok := itemMap[hash]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if this.onRemove != nil {
|
||||
this.onRemove(item)
|
||||
}
|
||||
|
||||
atomic.AddInt64(&this.count, -1)
|
||||
delete(itemMap, hash)
|
||||
deletedHashList = append(deletedHashList, hash)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
delete(this.weekItemMaps, w)
|
||||
}
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
// 执行外部操作
|
||||
for _, hash := range deletedHashList {
|
||||
if callback != nil {
|
||||
@@ -174,6 +308,7 @@ func (this *MemoryList) Purge(count int, callback func(hash string) error) error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -206,13 +341,8 @@ func (this *MemoryList) Stat(check func(hash string) bool) (*Stat, error) {
|
||||
|
||||
// Count 总数量
|
||||
func (this *MemoryList) Count() (int64, error) {
|
||||
this.locker.RLock()
|
||||
var count = 0
|
||||
for _, itemMap := range this.itemMaps {
|
||||
count += len(itemMap)
|
||||
}
|
||||
this.locker.RUnlock()
|
||||
return int64(count), nil
|
||||
var count = atomic.LoadInt64(&this.count)
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// OnAdd 添加事件
|
||||
@@ -229,6 +359,41 @@ func (this *MemoryList) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// IncreaseHit 增加点击量
|
||||
func (this *MemoryList) IncreaseHit(hash string) error {
|
||||
this.locker.Lock()
|
||||
|
||||
itemMap, ok := this.itemMaps[this.prefix(hash)]
|
||||
if !ok {
|
||||
this.locker.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
item, ok := itemMap[hash]
|
||||
if ok {
|
||||
var week = currentWeek()
|
||||
|
||||
// 交换位置
|
||||
if item.Week > 0 && item.Week != week {
|
||||
wm, ok := this.weekItemMaps[item.Week]
|
||||
if ok {
|
||||
delete(wm, hash)
|
||||
}
|
||||
wm, ok = this.weekItemMaps[week]
|
||||
if ok {
|
||||
wm[hash] = true
|
||||
} else {
|
||||
this.weekItemMaps[week] = map[string]bool{hash: true}
|
||||
}
|
||||
}
|
||||
|
||||
item.IncreaseHit(week)
|
||||
}
|
||||
|
||||
this.locker.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *MemoryList) print(t *testing.T) {
|
||||
this.locker.Lock()
|
||||
for _, itemMap := range this.itemMaps {
|
||||
|
||||
@@ -31,6 +31,8 @@ func TestMemoryList_Add(t *testing.T) {
|
||||
})
|
||||
t.Log(list.prefixes)
|
||||
logs.PrintAsJSON(list.itemMaps, t)
|
||||
logs.PrintAsJSON(list.weekItemMaps, t)
|
||||
t.Log(list.Count())
|
||||
}
|
||||
|
||||
func TestMemoryList_Remove(t *testing.T) {
|
||||
@@ -48,6 +50,8 @@ func TestMemoryList_Remove(t *testing.T) {
|
||||
})
|
||||
_ = list.Remove("b")
|
||||
list.print(t)
|
||||
logs.PrintAsJSON(list.weekItemMaps, t)
|
||||
t.Log(list.Count())
|
||||
}
|
||||
|
||||
func TestMemoryList_Purge(t *testing.T) {
|
||||
@@ -73,19 +77,22 @@ func TestMemoryList_Purge(t *testing.T) {
|
||||
ExpiredAt: time.Now().Unix() - 2,
|
||||
HeaderSize: 1024,
|
||||
})
|
||||
_ = list.Purge(100, func(hash string) error {
|
||||
_, _ = list.Purge(100, func(hash string) error {
|
||||
t.Log("delete:", hash)
|
||||
return nil
|
||||
})
|
||||
list.print(t)
|
||||
logs.PrintAsJSON(list.weekItemMaps, t)
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
_ = list.Purge(100, func(hash string) error {
|
||||
_, _ = list.Purge(100, func(hash string) error {
|
||||
t.Log("delete:", hash)
|
||||
return nil
|
||||
})
|
||||
t.Log(list.purgeIndex)
|
||||
}
|
||||
|
||||
t.Log(list.Count())
|
||||
}
|
||||
|
||||
func TestMemoryList_Purge_Large_List(t *testing.T) {
|
||||
@@ -139,7 +146,7 @@ func TestMemoryList_CleanPrefix(t *testing.T) {
|
||||
_ = list.Init()
|
||||
before := time.Now()
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
key := "http://www.teaos.cn/hello/" + strconv.Itoa(i/10000) + "/" + strconv.Itoa(i) + ".html"
|
||||
key := "https://www.teaos.cn/hello/" + strconv.Itoa(i/10000) + "/" + strconv.Itoa(i) + ".html"
|
||||
_ = list.Add(fmt.Sprintf("%d", xxhash.Sum64String(key)), &Item{
|
||||
Key: key,
|
||||
ExpiredAt: time.Now().Unix() + 3600,
|
||||
@@ -150,7 +157,7 @@ func TestMemoryList_CleanPrefix(t *testing.T) {
|
||||
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||
|
||||
before = time.Now()
|
||||
err := list.CleanPrefix("http://www.teaos.cn/hello/10")
|
||||
err := list.CleanPrefix("https://www.teaos.cn/hello/10")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -162,11 +169,77 @@ func TestMemoryList_CleanPrefix(t *testing.T) {
|
||||
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||
}
|
||||
|
||||
func TestMemoryList_PurgeLFU(t *testing.T) {
|
||||
var list = NewMemoryList().(*MemoryList)
|
||||
list.minWeek = 2704
|
||||
|
||||
var before = time.Now()
|
||||
defer func() {
|
||||
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||
}()
|
||||
|
||||
t.Log("current week:", currentWeek())
|
||||
|
||||
_ = list.Add("1", &Item{})
|
||||
_ = list.Add("2", &Item{})
|
||||
_ = list.Add("3", &Item{})
|
||||
_ = list.Add("4", &Item{})
|
||||
_ = list.Add("5", &Item{})
|
||||
_ = list.Add("6", &Item{Week: 2704})
|
||||
_ = list.Add("7", &Item{Week: 2704})
|
||||
_ = list.Add("8", &Item{Week: 2705})
|
||||
|
||||
err := list.PurgeLFU(2, func(hash string) error {
|
||||
t.Log("purge lfu:", hash)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
|
||||
logs.PrintAsJSON(list.weekItemMaps, t)
|
||||
t.Log(list.Count())
|
||||
}
|
||||
|
||||
func TestMemoryList_IncreaseHit(t *testing.T) {
|
||||
var list = NewMemoryList().(*MemoryList)
|
||||
var item = &Item{}
|
||||
item.Week = 2705
|
||||
item.Week2Hits = 100
|
||||
|
||||
_ = list.Add("a", &Item{})
|
||||
_ = list.Add("a", item)
|
||||
t.Log("hits1:", item.Week1Hits, "hits2:", item.Week2Hits, "week:", item.Week)
|
||||
logs.PrintAsJSON(list.weekItemMaps, t)
|
||||
|
||||
_ = list.IncreaseHit("a")
|
||||
t.Log("hits1:", item.Week1Hits, "hits2:", item.Week2Hits, "week:", item.Week)
|
||||
logs.PrintAsJSON(list.weekItemMaps, t)
|
||||
|
||||
_ = list.IncreaseHit("a")
|
||||
t.Log("hits1:", item.Week1Hits, "hits2:", item.Week2Hits, "week:", item.Week)
|
||||
logs.PrintAsJSON(list.weekItemMaps, t)
|
||||
}
|
||||
|
||||
func TestMemoryList_CleanAll(t *testing.T) {
|
||||
var list = NewMemoryList().(*MemoryList)
|
||||
var item = &Item{}
|
||||
item.Week = 2705
|
||||
item.Week2Hits = 100
|
||||
|
||||
_ = list.Add("a", &Item{})
|
||||
_ = list.CleanAll()
|
||||
logs.PrintAsJSON(list.itemMaps, t)
|
||||
logs.PrintAsJSON(list.weekItemMaps, t)
|
||||
t.Log(list.Count())
|
||||
}
|
||||
|
||||
func TestMemoryList_GC(t *testing.T) {
|
||||
list := NewMemoryList().(*MemoryList)
|
||||
_ = list.Init()
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
key := "http://www.teaos.cn/hello" + strconv.Itoa(i/100000) + "/" + strconv.Itoa(i) + ".html"
|
||||
key := "https://www.teaos.cn/hello" + strconv.Itoa(i/100000) + "/" + strconv.Itoa(i) + ".html"
|
||||
_ = list.Add(fmt.Sprintf("%d", xxhash.Sum64String(key)), &Item{
|
||||
Key: key,
|
||||
ExpiredAt: 0,
|
||||
|
||||
@@ -135,7 +135,7 @@ func (this *Manager) NewStorageWithPolicy(policy *serverconfigs.HTTPCachePolicy)
|
||||
case serverconfigs.CachePolicyStorageFile:
|
||||
return NewFileStorage(policy)
|
||||
case serverconfigs.CachePolicyStorageMemory:
|
||||
return NewMemoryStorage(policy)
|
||||
return NewMemoryStorage(policy, nil)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,9 @@ type Reader interface {
|
||||
// TypeName 类型名称
|
||||
TypeName() string
|
||||
|
||||
// ExpiresAt 过期时间
|
||||
ExpiresAt() int64
|
||||
|
||||
// Status 状态码
|
||||
Status() int
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
type FileReader struct {
|
||||
fp *os.File
|
||||
|
||||
expiresAt int64
|
||||
status int
|
||||
headerOffset int64
|
||||
headerSize int
|
||||
@@ -43,6 +44,8 @@ func (this *FileReader) Init() error {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
this.expiresAt = int64(binary.BigEndian.Uint32(buf[:SizeExpiresAt]))
|
||||
|
||||
status := types.Int(string(buf[SizeExpiresAt : SizeExpiresAt+SizeStatus]))
|
||||
if status < 100 || status > 999 {
|
||||
return errors.New("invalid status")
|
||||
@@ -78,6 +81,10 @@ func (this *FileReader) TypeName() string {
|
||||
return "disk"
|
||||
}
|
||||
|
||||
func (this *FileReader) ExpiresAt() int64 {
|
||||
return this.expiresAt
|
||||
}
|
||||
|
||||
func (this *FileReader) Status() int {
|
||||
return this.status
|
||||
}
|
||||
|
||||
@@ -20,6 +20,10 @@ func (this *MemoryReader) TypeName() string {
|
||||
return "memory"
|
||||
}
|
||||
|
||||
func (this *MemoryReader) ExpiresAt() int64 {
|
||||
return this.item.ExpiredAt
|
||||
}
|
||||
|
||||
func (this *MemoryReader) Status() int {
|
||||
return this.item.Status
|
||||
}
|
||||
|
||||
@@ -9,15 +9,20 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/trackers"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"golang.org/x/text/language"
|
||||
"golang.org/x/text/message"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -36,6 +41,10 @@ const (
|
||||
SizeMeta = SizeExpiresAt + SizeStatus + SizeURLLength + SizeHeaderLength + SizeBodyLength
|
||||
)
|
||||
|
||||
const (
|
||||
HotItemSize = 1024
|
||||
)
|
||||
|
||||
// FileStorage 文件缓存
|
||||
// 文件结构:
|
||||
// [expires time] | [ status ] | [url length] | [header length] | [body length] | [url] [header data] [body data]
|
||||
@@ -48,13 +57,20 @@ type FileStorage struct {
|
||||
list ListInterface
|
||||
writingKeyMap map[string]bool // key => bool
|
||||
locker sync.RWMutex
|
||||
ticker *utils.Ticker
|
||||
purgeTicker *utils.Ticker
|
||||
|
||||
hotMap map[string]*HotItem // key => count
|
||||
hotMapLocker sync.Mutex
|
||||
lastHotSize int
|
||||
hotTicker *utils.Ticker
|
||||
}
|
||||
|
||||
func NewFileStorage(policy *serverconfigs.HTTPCachePolicy) *FileStorage {
|
||||
return &FileStorage{
|
||||
policy: policy,
|
||||
writingKeyMap: map[string]bool{},
|
||||
hotMap: map[string]*HotItem{},
|
||||
lastHotSize: -1,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,12 +181,16 @@ func (this *FileStorage) Init() error {
|
||||
Life: this.policy.Life,
|
||||
MinLife: this.policy.MinLife,
|
||||
MaxLife: this.policy.MaxLife,
|
||||
|
||||
MemoryAutoPurgeCount: this.policy.MemoryAutoPurgeCount,
|
||||
MemoryAutoPurgeInterval: this.policy.MemoryAutoPurgeInterval,
|
||||
MemoryLFUFreePercent: this.policy.MemoryLFUFreePercent,
|
||||
}
|
||||
err = memoryPolicy.Init()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
memoryStorage := NewMemoryStorage(memoryPolicy)
|
||||
memoryStorage := NewMemoryStorage(memoryPolicy, this)
|
||||
err = memoryStorage.Init()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -183,8 +203,12 @@ func (this *FileStorage) Init() error {
|
||||
}
|
||||
|
||||
func (this *FileStorage) OpenReader(key string) (Reader, error) {
|
||||
return this.openReader(key, true)
|
||||
}
|
||||
|
||||
func (this *FileStorage) openReader(key string, allowMemory bool) (Reader, error) {
|
||||
// 先尝试内存缓存
|
||||
if this.memoryStorage != nil {
|
||||
if allowMemory && this.memoryStorage != nil {
|
||||
reader, err := this.memoryStorage.OpenReader(key)
|
||||
if err == nil {
|
||||
return reader, err
|
||||
@@ -227,6 +251,45 @@ func (this *FileStorage) OpenReader(key string) (Reader, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 增加点击量
|
||||
// 1/1000采样
|
||||
if allowMemory {
|
||||
var rate = this.policy.PersistenceHitSampleRate
|
||||
if rate <= 0 {
|
||||
rate = 1000
|
||||
}
|
||||
if this.lastHotSize == 0 {
|
||||
// 自动降低采样率来增加热点数据的缓存几率
|
||||
rate = rate / 10
|
||||
}
|
||||
if rands.Int(0, rate) == 0 {
|
||||
var hitErr = this.list.IncreaseHit(hash)
|
||||
if hitErr != nil {
|
||||
// 此错误可以忽略
|
||||
remotelogs.Error("CACHE", "increase hit failed: "+hitErr.Error())
|
||||
}
|
||||
|
||||
// 增加到热点
|
||||
// 这里不收录缓存尺寸过大的文件
|
||||
if this.memoryStorage != nil && reader.BodySize() > 0 && reader.BodySize() < 128*1024*1024 {
|
||||
this.hotMapLocker.Lock()
|
||||
hotItem, ok := this.hotMap[key]
|
||||
if ok {
|
||||
hotItem.Hits++
|
||||
hotItem.ExpiresAt = reader.expiresAt
|
||||
} else if len(this.hotMap) < HotItemSize { // 控制数量
|
||||
this.hotMap[key] = &HotItem{
|
||||
Key: key,
|
||||
ExpiresAt: reader.ExpiresAt(),
|
||||
Status: reader.Status(),
|
||||
Hits: 1,
|
||||
}
|
||||
}
|
||||
this.hotMapLocker.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
isOk = true
|
||||
return reader, nil
|
||||
}
|
||||
@@ -398,7 +461,7 @@ func (this *FileStorage) AddToList(item *Item) {
|
||||
}
|
||||
}
|
||||
|
||||
item.MetaSize = SizeMeta
|
||||
item.MetaSize = SizeMeta + 128
|
||||
hash := stringutil.Md5(item.Key)
|
||||
err := this.list.Add(hash, item)
|
||||
if err != nil && !strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||||
@@ -555,8 +618,11 @@ func (this *FileStorage) Stop() {
|
||||
}
|
||||
|
||||
_ = this.list.Reset()
|
||||
if this.ticker != nil {
|
||||
this.ticker.Stop()
|
||||
if this.purgeTicker != nil {
|
||||
this.purgeTicker.Stop()
|
||||
}
|
||||
if this.hotTicker != nil {
|
||||
this.hotTicker.Stop()
|
||||
}
|
||||
|
||||
_ = this.list.Close()
|
||||
@@ -606,30 +672,47 @@ func (this *FileStorage) initList() error {
|
||||
}
|
||||
|
||||
// 使用异步防止阻塞主线程
|
||||
go func() {
|
||||
/**go func() {
|
||||
dir := this.dir()
|
||||
|
||||
// 清除tmp
|
||||
files, err := filepath.Glob(dir + "/*/*/*.cache.tmp")
|
||||
if err == nil && len(files) > 0 {
|
||||
for _, path := range files {
|
||||
_ = os.Remove(path)
|
||||
}
|
||||
}
|
||||
}()
|
||||
// TODO 需要一个更加高效的实现
|
||||
}()**/
|
||||
|
||||
// 启动定时清理任务
|
||||
this.ticker = utils.NewTicker(30 * time.Second)
|
||||
var autoPurgeInterval = this.policy.PersistenceAutoPurgeInterval
|
||||
if autoPurgeInterval <= 0 {
|
||||
autoPurgeInterval = 30
|
||||
if Tea.IsTesting() {
|
||||
autoPurgeInterval = 10
|
||||
}
|
||||
}
|
||||
this.purgeTicker = utils.NewTicker(time.Duration(autoPurgeInterval) * time.Second)
|
||||
events.On(events.EventQuit, func() {
|
||||
remotelogs.Println("CACHE", "quit clean timer")
|
||||
var ticker = this.ticker
|
||||
var ticker = this.purgeTicker
|
||||
if ticker != nil {
|
||||
ticker.Stop()
|
||||
}
|
||||
})
|
||||
go func() {
|
||||
for this.ticker.Next() {
|
||||
this.purgeLoop()
|
||||
for this.purgeTicker.Next() {
|
||||
trackers.Run("FILE_CACHE_STORAGE_PURGE_LOOP", func() {
|
||||
this.purgeLoop()
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
// 热点处理任务
|
||||
this.hotTicker = utils.NewTicker(1 * time.Minute)
|
||||
if Tea.IsTesting() {
|
||||
this.hotTicker = utils.NewTicker(10 * time.Second)
|
||||
}
|
||||
go func() {
|
||||
for this.hotTicker.Next() {
|
||||
trackers.Run("FILE_CACHE_STORAGE_HOT_LOOP", func() {
|
||||
this.hotLoop()
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -730,16 +813,183 @@ func (this *FileStorage) decodeFile(path string) (*Item, error) {
|
||||
|
||||
// 清理任务
|
||||
func (this *FileStorage) purgeLoop() {
|
||||
err := this.list.Purge(1000, func(hash string) error {
|
||||
path := this.hashPath(hash)
|
||||
err := os.Remove(path)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
remotelogs.Error("CACHE", "purge '"+path+"' error: "+err.Error())
|
||||
// 计算是否应该开启LFU清理
|
||||
var capacityBytes = this.policy.CapacityBytes()
|
||||
var startLFU = false
|
||||
var usedPercent = float32(this.TotalDiskSize()*100) / float32(capacityBytes)
|
||||
var lfuFreePercent = this.policy.PersistenceLFUFreePercent
|
||||
if lfuFreePercent <= 0 {
|
||||
lfuFreePercent = 5
|
||||
}
|
||||
if capacityBytes > 0 {
|
||||
if lfuFreePercent < 100 {
|
||||
if usedPercent >= 100-lfuFreePercent {
|
||||
startLFU = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 清理过期
|
||||
{
|
||||
var times = 1
|
||||
|
||||
// 空闲时间多清理
|
||||
if utils.SharedFreeHoursManager.IsFreeHour() {
|
||||
times = 5
|
||||
}
|
||||
|
||||
// 处于LFU阈值时,多清理
|
||||
if startLFU {
|
||||
times = 5
|
||||
}
|
||||
|
||||
var purgeCount = this.policy.PersistenceAutoPurgeCount
|
||||
if purgeCount <= 0 {
|
||||
purgeCount = 1000
|
||||
}
|
||||
for i := 0; i < times; i++ {
|
||||
countFound, err := this.list.Purge(purgeCount, func(hash string) error {
|
||||
path := this.hashPath(hash)
|
||||
err := os.Remove(path)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
remotelogs.Error("CACHE", "purge '"+path+"' error: "+err.Error())
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
remotelogs.Warn("CACHE", "purge file storage failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
if countFound < purgeCount {
|
||||
break
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
// 磁盘空间不足时,清除老旧的缓存
|
||||
if startLFU {
|
||||
var total, _ = this.list.Count()
|
||||
if total > 0 {
|
||||
var count = types.Int(math.Ceil(float64(total) * float64(lfuFreePercent*2) / 100))
|
||||
if count > 0 {
|
||||
// 限制单次清理的条数,防止占用太多系统资源
|
||||
if count > 2000 {
|
||||
count = 2000
|
||||
}
|
||||
|
||||
remotelogs.Println("CACHE", "LFU purge policy '"+this.policy.Name+"' id: "+types.String(this.policy.Id)+", count: "+types.String(count))
|
||||
err := this.list.PurgeLFU(count, func(hash string) error {
|
||||
path := this.hashPath(hash)
|
||||
err := os.Remove(path)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
remotelogs.Error("CACHE", "purge '"+path+"' error: "+err.Error())
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
remotelogs.Warn("CACHE", "purge file storage in LFU failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 热点数据任务
|
||||
func (this *FileStorage) hotLoop() {
|
||||
var memoryStorage = this.memoryStorage
|
||||
if memoryStorage == nil {
|
||||
return
|
||||
}
|
||||
|
||||
this.hotMapLocker.Lock()
|
||||
if len(this.hotMap) == 0 {
|
||||
this.hotMapLocker.Unlock()
|
||||
this.lastHotSize = 0
|
||||
return
|
||||
}
|
||||
|
||||
this.lastHotSize = len(this.hotMap)
|
||||
|
||||
var result = []*HotItem{} // [ {key: ..., hits: ...}, ... ]
|
||||
for _, v := range this.hotMap {
|
||||
result = append(result, v)
|
||||
}
|
||||
|
||||
this.hotMap = map[string]*HotItem{}
|
||||
this.hotMapLocker.Unlock()
|
||||
|
||||
// 取Top10
|
||||
if len(result) > 0 {
|
||||
sort.Slice(result, func(i, j int) bool {
|
||||
return result[i].Hits > result[j].Hits
|
||||
})
|
||||
var size = 1
|
||||
if len(result) < 10 {
|
||||
size = 1
|
||||
} else {
|
||||
size = len(result) / 10
|
||||
}
|
||||
|
||||
var buf = make([]byte, 32*1024)
|
||||
for _, item := range result[:size] {
|
||||
reader, err := this.openReader(item.Key, false)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if reader == nil {
|
||||
continue
|
||||
}
|
||||
if reader.ExpiresAt() <= time.Now().Unix() {
|
||||
continue
|
||||
}
|
||||
|
||||
writer, err := this.memoryStorage.openWriter(item.Key, item.ExpiresAt, item.Status, false)
|
||||
if err != nil {
|
||||
if !CanIgnoreErr(err) {
|
||||
remotelogs.Error("CACHE", "transfer hot item failed: "+err.Error())
|
||||
}
|
||||
_ = reader.Close()
|
||||
continue
|
||||
}
|
||||
if writer == nil {
|
||||
_ = reader.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
err = reader.ReadHeader(buf, func(n int) (goNext bool, err error) {
|
||||
_, err = writer.WriteHeader(buf[:n])
|
||||
return
|
||||
})
|
||||
if err != nil {
|
||||
_ = reader.Close()
|
||||
_ = writer.Discard()
|
||||
continue
|
||||
}
|
||||
|
||||
err = reader.ReadBody(buf, func(n int) (goNext bool, err error) {
|
||||
_, err = writer.Write(buf[:n])
|
||||
return
|
||||
})
|
||||
if err != nil {
|
||||
_ = reader.Close()
|
||||
_ = writer.Discard()
|
||||
continue
|
||||
}
|
||||
|
||||
this.memoryStorage.AddToList(&Item{
|
||||
Type: writer.ItemType(),
|
||||
Key: item.Key,
|
||||
ExpiredAt: item.ExpiresAt,
|
||||
HeaderSize: writer.HeaderSize(),
|
||||
BodySize: writer.BodySize(),
|
||||
})
|
||||
|
||||
_ = reader.Close()
|
||||
_ = writer.Close()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
remotelogs.Warn("CACHE", "purge file storage failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -517,3 +517,16 @@ func BenchmarkFileStorage_Read(b *testing.B) {
|
||||
_ = reader.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFileStorage_KeyPath(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
var storage = &FileStorage{
|
||||
cacheConfig: &serverconfigs.HTTPFileCacheStorage{},
|
||||
policy: &serverconfigs.HTTPCachePolicy{Id: 1},
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = storage.keyPath(strconv.Itoa(i))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,8 +4,13 @@ import (
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/trackers"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/cespare/xxhash"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"math"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -26,22 +31,37 @@ func (this *MemoryItem) IsExpired() bool {
|
||||
}
|
||||
|
||||
type MemoryStorage struct {
|
||||
policy *serverconfigs.HTTPCachePolicy
|
||||
list ListInterface
|
||||
locker *sync.RWMutex
|
||||
valuesMap map[uint64]*MemoryItem
|
||||
ticker *utils.Ticker
|
||||
purgeDuration time.Duration
|
||||
parentStorage StorageInterface
|
||||
|
||||
policy *serverconfigs.HTTPCachePolicy
|
||||
list ListInterface
|
||||
locker *sync.RWMutex
|
||||
|
||||
valuesMap map[uint64]*MemoryItem // hash => item
|
||||
dirtyChan chan string // hash chan
|
||||
|
||||
purgeTicker *utils.Ticker
|
||||
|
||||
totalSize int64
|
||||
writingKeyMap map[string]bool // key => bool
|
||||
}
|
||||
|
||||
func NewMemoryStorage(policy *serverconfigs.HTTPCachePolicy) *MemoryStorage {
|
||||
func NewMemoryStorage(policy *serverconfigs.HTTPCachePolicy, parentStorage StorageInterface) *MemoryStorage {
|
||||
var dirtyChan chan string
|
||||
if parentStorage != nil {
|
||||
var queueSize = policy.MemoryAutoFlushQueueSize
|
||||
if queueSize <= 0 {
|
||||
queueSize = 2048
|
||||
}
|
||||
dirtyChan = make(chan string, queueSize)
|
||||
}
|
||||
return &MemoryStorage{
|
||||
parentStorage: parentStorage,
|
||||
policy: policy,
|
||||
list: NewMemoryList(),
|
||||
locker: &sync.RWMutex{},
|
||||
valuesMap: map[uint64]*MemoryItem{},
|
||||
dirtyChan: dirtyChan,
|
||||
writingKeyMap: map[string]bool{},
|
||||
}
|
||||
}
|
||||
@@ -57,15 +77,25 @@ func (this *MemoryStorage) Init() error {
|
||||
atomic.AddInt64(&this.totalSize, -item.TotalSize())
|
||||
})
|
||||
|
||||
if this.purgeDuration <= 0 {
|
||||
this.purgeDuration = 10 * time.Second
|
||||
var autoPurgeInterval = this.policy.MemoryAutoPurgeInterval
|
||||
if autoPurgeInterval <= 0 {
|
||||
autoPurgeInterval = 5
|
||||
}
|
||||
|
||||
// 启动定时清理任务
|
||||
this.ticker = utils.NewTicker(this.purgeDuration)
|
||||
this.purgeTicker = utils.NewTicker(time.Duration(autoPurgeInterval) * time.Second)
|
||||
go func() {
|
||||
for this.ticker.Next() {
|
||||
for this.purgeTicker.Next() {
|
||||
var tr = trackers.Begin("MEMORY_CACHE_STORAGE_PURGE_LOOP")
|
||||
this.purgeLoop()
|
||||
tr.End()
|
||||
}
|
||||
}()
|
||||
|
||||
// 启动定时Flush memory to disk任务
|
||||
go func() {
|
||||
for hash := range this.dirtyChan {
|
||||
this.flushItem(hash)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -91,6 +121,18 @@ func (this *MemoryStorage) OpenReader(key string) (Reader, error) {
|
||||
return nil, err
|
||||
}
|
||||
this.locker.RUnlock()
|
||||
|
||||
// 增加点击量
|
||||
// 1/1000采样
|
||||
// TODO 考虑是否在缓存策略里设置
|
||||
if rands.Int(0, 1000) == 0 {
|
||||
var hitErr = this.list.IncreaseHit(types.String(hash))
|
||||
if hitErr != nil {
|
||||
// 此错误可以忽略
|
||||
remotelogs.Error("CACHE", "increase hit failed: "+hitErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return reader, nil
|
||||
}
|
||||
this.locker.RUnlock()
|
||||
@@ -102,6 +144,10 @@ func (this *MemoryStorage) OpenReader(key string) (Reader, error) {
|
||||
|
||||
// OpenWriter 打开缓存写入器等待写入
|
||||
func (this *MemoryStorage) OpenWriter(key string, expiredAt int64, status int) (Writer, error) {
|
||||
return this.openWriter(key, expiredAt, status, true)
|
||||
}
|
||||
|
||||
func (this *MemoryStorage) openWriter(key string, expiredAt int64, status int, isDirty bool) (Writer, error) {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
@@ -145,7 +191,7 @@ func (this *MemoryStorage) OpenWriter(key string, expiredAt int64, status int) (
|
||||
}
|
||||
|
||||
isWriting = true
|
||||
return NewMemoryWriter(this.valuesMap, key, expiredAt, status, this.locker, func() {
|
||||
return NewMemoryWriter(this, key, expiredAt, status, isDirty, func() {
|
||||
this.locker.Lock()
|
||||
delete(this.writingKeyMap, key)
|
||||
this.locker.Unlock()
|
||||
@@ -210,14 +256,21 @@ func (this *MemoryStorage) Stop() {
|
||||
this.valuesMap = map[uint64]*MemoryItem{}
|
||||
this.writingKeyMap = map[string]bool{}
|
||||
_ = this.list.Reset()
|
||||
if this.ticker != nil {
|
||||
this.ticker.Stop()
|
||||
if this.purgeTicker != nil {
|
||||
this.purgeTicker.Stop()
|
||||
}
|
||||
|
||||
if this.parentStorage != nil && this.dirtyChan != nil {
|
||||
close(this.dirtyChan)
|
||||
}
|
||||
|
||||
_ = this.list.Close()
|
||||
|
||||
this.locker.Unlock()
|
||||
|
||||
// 回收内存
|
||||
runtime.GC()
|
||||
|
||||
remotelogs.Println("CACHE", "close memory storage '"+strconv.FormatInt(this.policy.Id, 10)+"'")
|
||||
}
|
||||
|
||||
@@ -228,7 +281,7 @@ func (this *MemoryStorage) Policy() *serverconfigs.HTTPCachePolicy {
|
||||
|
||||
// AddToList 将缓存添加到列表
|
||||
func (this *MemoryStorage) AddToList(item *Item) {
|
||||
item.MetaSize = int64(len(item.Key)) + 32 /** 32是我们评估的数据结构的长度 **/
|
||||
item.MetaSize = int64(len(item.Key)) + 128 /** 128是我们评估的数据结构的长度 **/
|
||||
hash := fmt.Sprintf("%d", this.hash(item.Key))
|
||||
_ = this.list.Add(hash, item)
|
||||
}
|
||||
@@ -250,7 +303,28 @@ func (this *MemoryStorage) hash(key string) uint64 {
|
||||
|
||||
// 清理任务
|
||||
func (this *MemoryStorage) purgeLoop() {
|
||||
_ = this.list.Purge(2048, func(hash string) error {
|
||||
// 计算是否应该开启LFU清理
|
||||
var capacityBytes = this.policy.CapacityBytes()
|
||||
var startLFU = false
|
||||
var usedPercent = float32(this.TotalMemorySize()*100) / float32(capacityBytes)
|
||||
var lfuFreePercent = this.policy.MemoryLFUFreePercent
|
||||
if lfuFreePercent <= 0 {
|
||||
lfuFreePercent = 5
|
||||
}
|
||||
if capacityBytes > 0 {
|
||||
if lfuFreePercent < 100 {
|
||||
if usedPercent >= 100-lfuFreePercent {
|
||||
startLFU = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 清理过期
|
||||
var purgeCount = this.policy.MemoryAutoPurgeCount
|
||||
if purgeCount <= 0 {
|
||||
purgeCount = 2000
|
||||
}
|
||||
_, _ = this.list.Purge(purgeCount, func(hash string) error {
|
||||
uintHash, err := strconv.ParseUint(hash, 10, 64)
|
||||
if err == nil {
|
||||
this.locker.Lock()
|
||||
@@ -259,6 +333,92 @@ func (this *MemoryStorage) purgeLoop() {
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// LFU
|
||||
if startLFU {
|
||||
var total, _ = this.list.Count()
|
||||
if total > 0 {
|
||||
var count = types.Int(math.Ceil(float64(total) * float64(lfuFreePercent*2) / 100))
|
||||
if count > 0 {
|
||||
// 限制单次清理的条数,防止占用太多系统资源
|
||||
if count > 2000 {
|
||||
count = 2000
|
||||
}
|
||||
|
||||
// 这里不提示LFU,因为此事件将会非常频繁
|
||||
|
||||
err := this.list.PurgeLFU(count, func(hash string) error {
|
||||
uintHash, err := strconv.ParseUint(hash, 10, 64)
|
||||
if err == nil {
|
||||
this.locker.Lock()
|
||||
delete(this.valuesMap, uintHash)
|
||||
this.locker.Unlock()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
remotelogs.Warn("CACHE", "purge memory storage in LFU failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush任务
|
||||
func (this *MemoryStorage) flushItem(key string) {
|
||||
if this.parentStorage == nil {
|
||||
return
|
||||
}
|
||||
var hash = this.hash(key)
|
||||
|
||||
this.locker.RLock()
|
||||
item, ok := this.valuesMap[hash]
|
||||
this.locker.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if !item.IsDone || item.IsExpired() {
|
||||
return
|
||||
}
|
||||
|
||||
writer, err := this.parentStorage.OpenWriter(key, item.ExpiredAt, item.Status)
|
||||
if err != nil {
|
||||
if !CanIgnoreErr(err) {
|
||||
remotelogs.Error("CACHE", "flush items failed: open writer failed: "+err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
_, err = writer.WriteHeader(item.HeaderValue)
|
||||
if err != nil {
|
||||
_ = writer.Discard()
|
||||
remotelogs.Error("CACHE", "flush items failed: write header failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
_, err = writer.Write(item.BodyValue)
|
||||
if err != nil {
|
||||
_ = writer.Discard()
|
||||
remotelogs.Error("CACHE", "flush items failed: writer body failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
_ = writer.Discard()
|
||||
remotelogs.Error("CACHE", "flush items failed: close writer failed: "+err.Error())
|
||||
}
|
||||
|
||||
this.parentStorage.AddToList(&Item{
|
||||
Type: writer.ItemType(),
|
||||
Key: key,
|
||||
ExpiredAt: item.ExpiredAt,
|
||||
HeaderSize: writer.HeaderSize(),
|
||||
BodySize: writer.BodySize(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (this *MemoryStorage) memoryCapacityBytes() int64 {
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
)
|
||||
|
||||
func TestMemoryStorage_OpenWriter(t *testing.T) {
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{})
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
|
||||
|
||||
writer, err := storage.OpenWriter("abc", time.Now().Unix()+60, 200)
|
||||
if err != nil {
|
||||
@@ -88,7 +88,7 @@ func TestMemoryStorage_OpenWriter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMemoryStorage_OpenReaderLock(t *testing.T) {
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{})
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
|
||||
_ = storage.Init()
|
||||
|
||||
var h = storage.hash("test")
|
||||
@@ -101,7 +101,7 @@ func TestMemoryStorage_OpenReaderLock(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMemoryStorage_Delete(t *testing.T) {
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{})
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
|
||||
{
|
||||
writer, err := storage.OpenWriter("abc", time.Now().Unix()+60, 200)
|
||||
if err != nil {
|
||||
@@ -123,7 +123,7 @@ func TestMemoryStorage_Delete(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMemoryStorage_Stat(t *testing.T) {
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{})
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
|
||||
expiredAt := time.Now().Unix() + 60
|
||||
{
|
||||
writer, err := storage.OpenWriter("abc", expiredAt, 200)
|
||||
@@ -160,7 +160,7 @@ func TestMemoryStorage_Stat(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMemoryStorage_CleanAll(t *testing.T) {
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{})
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
|
||||
expiredAt := time.Now().Unix() + 60
|
||||
{
|
||||
writer, err := storage.OpenWriter("abc", expiredAt, 200)
|
||||
@@ -195,7 +195,7 @@ func TestMemoryStorage_CleanAll(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMemoryStorage_Purge(t *testing.T) {
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{})
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
|
||||
expiredAt := time.Now().Unix() + 60
|
||||
{
|
||||
writer, err := storage.OpenWriter("abc", expiredAt, 200)
|
||||
@@ -230,8 +230,9 @@ func TestMemoryStorage_Purge(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMemoryStorage_Expire(t *testing.T) {
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{})
|
||||
storage.purgeDuration = 5 * time.Second
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{
|
||||
MemoryAutoPurgeInterval: 5,
|
||||
}, nil)
|
||||
err := storage.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -255,7 +256,7 @@ func TestMemoryStorage_Expire(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMemoryStorage_Locker(t *testing.T) {
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{})
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
|
||||
err := storage.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -2,36 +2,36 @@ package caches
|
||||
|
||||
import (
|
||||
"github.com/cespare/xxhash"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MemoryWriter struct {
|
||||
storage *MemoryStorage
|
||||
|
||||
key string
|
||||
expiredAt int64
|
||||
m map[uint64]*MemoryItem
|
||||
locker *sync.RWMutex
|
||||
headerSize int64
|
||||
bodySize int64
|
||||
status int
|
||||
isDirty bool
|
||||
|
||||
hash uint64
|
||||
item *MemoryItem
|
||||
endFunc func()
|
||||
}
|
||||
|
||||
func NewMemoryWriter(m map[uint64]*MemoryItem, key string, expiredAt int64, status int, locker *sync.RWMutex, endFunc func()) *MemoryWriter {
|
||||
func NewMemoryWriter(memoryStorage *MemoryStorage, key string, expiredAt int64, status int, isDirty bool, endFunc func()) *MemoryWriter {
|
||||
w := &MemoryWriter{
|
||||
m: m,
|
||||
storage: memoryStorage,
|
||||
key: key,
|
||||
expiredAt: expiredAt,
|
||||
locker: locker,
|
||||
item: &MemoryItem{
|
||||
ExpiredAt: expiredAt,
|
||||
ModifiedAt: time.Now().Unix(),
|
||||
Status: status,
|
||||
},
|
||||
status: status,
|
||||
isDirty: isDirty,
|
||||
endFunc: endFunc,
|
||||
}
|
||||
w.hash = w.calculateHash(key)
|
||||
@@ -72,10 +72,19 @@ func (this *MemoryWriter) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
this.storage.locker.Lock()
|
||||
this.item.IsDone = true
|
||||
this.m[this.hash] = this.item
|
||||
this.locker.Unlock()
|
||||
this.storage.valuesMap[this.hash] = this.item
|
||||
if this.isDirty {
|
||||
if this.storage.parentStorage != nil {
|
||||
select {
|
||||
case this.storage.dirtyChan <- this.key:
|
||||
default:
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
this.storage.locker.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -85,9 +94,9 @@ func (this *MemoryWriter) Discard() error {
|
||||
// 需要在Locker之外
|
||||
defer this.endFunc()
|
||||
|
||||
this.locker.Lock()
|
||||
delete(this.m, this.hash)
|
||||
this.locker.Unlock()
|
||||
this.storage.locker.Lock()
|
||||
delete(this.storage.valuesMap, this.hash)
|
||||
this.storage.locker.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package teaconst
|
||||
|
||||
const (
|
||||
Version = "0.3.3"
|
||||
Version = "0.3.6"
|
||||
|
||||
ProductName = "Edge Node"
|
||||
ProcessName = "edge-node"
|
||||
|
||||
12
internal/const/vars.go
Normal file
12
internal/const/vars.go
Normal file
@@ -0,0 +1,12 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package teaconst
|
||||
|
||||
var (
|
||||
// 流量统计
|
||||
|
||||
InTrafficBytes = uint64(0)
|
||||
OutTrafficBytes = uint64(0)
|
||||
|
||||
NodeId int64 = 0
|
||||
)
|
||||
5
internal/iplibrary/README.md
Normal file
5
internal/iplibrary/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# IPList
|
||||
List Check Order:
|
||||
~~~
|
||||
Global List --> Node List--> Server List --> WAF List --> Bind List
|
||||
~~~
|
||||
@@ -13,7 +13,7 @@ func (this *BaseAction) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理HTTP请求
|
||||
// DoHTTP 处理HTTP请求
|
||||
func (this *BaseAction) DoHTTP(req *http.Request, resp http.ResponseWriter) (goNext bool, err error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package iplibrary
|
||||
|
||||
// 是否是致命错误
|
||||
// FataError 是否是致命错误
|
||||
type FataError struct {
|
||||
err string
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Firewalld动作管理
|
||||
// FirewalldAction Firewalld动作管理
|
||||
// 常用命令:
|
||||
// - 查询列表: firewall-cmd --list-all
|
||||
// - 添加IP:firewall-cmd --add-rich-rule="rule family='ipv4' source address='192.168.2.32' reject" --timeout=30s
|
||||
@@ -20,6 +20,8 @@ type FirewalldAction struct {
|
||||
BaseAction
|
||||
|
||||
config *firewallconfigs.FirewallActionFirewalldConfig
|
||||
|
||||
firewalldNotFound bool
|
||||
}
|
||||
|
||||
func NewFirewalldAction() *FirewalldAction {
|
||||
@@ -82,6 +84,10 @@ func (this *FirewalldAction) runActionSingleIP(action string, listType IPListTyp
|
||||
if len(path) == 0 {
|
||||
path, err = exec.LookPath("firewall-cmd")
|
||||
if err != nil {
|
||||
if this.firewalldNotFound {
|
||||
return nil
|
||||
}
|
||||
this.firewalldNotFound = true
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -126,10 +132,12 @@ func (this *FirewalldAction) runActionSingleIP(action string, listType IPListTyp
|
||||
}
|
||||
|
||||
args := []string{opt}
|
||||
if item.ExpiredAt > timestamp {
|
||||
args = append(args, "--timeout="+fmt.Sprintf("%d", item.ExpiredAt-timestamp)+"s")
|
||||
} else {
|
||||
// TODO 思考是否需要permanent,不然--reload之后会丢失
|
||||
if action == "addItem" {
|
||||
if item.ExpiredAt > timestamp {
|
||||
args = append(args, "--timeout="+fmt.Sprintf("%d", item.ExpiredAt-timestamp)+"s")
|
||||
} else {
|
||||
// TODO 思考是否需要permanent,不然--reload之后会丢失
|
||||
}
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
|
||||
@@ -6,19 +6,19 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// HTML动作
|
||||
// HTMLAction HTML动作
|
||||
type HTMLAction struct {
|
||||
BaseAction
|
||||
|
||||
config *firewallconfigs.FirewallActionHTMLConfig
|
||||
}
|
||||
|
||||
// 获取新对象
|
||||
// NewHTMLAction 获取新对象
|
||||
func NewHTMLAction() *HTMLAction {
|
||||
return &HTMLAction{}
|
||||
}
|
||||
|
||||
// 初始化
|
||||
// Init 初始化
|
||||
func (this *HTMLAction) Init(config *firewallconfigs.FirewallActionConfig) error {
|
||||
this.config = &firewallconfigs.FirewallActionHTMLConfig{}
|
||||
err := this.convertParams(config.Params, this.config)
|
||||
@@ -28,22 +28,22 @@ func (this *HTMLAction) Init(config *firewallconfigs.FirewallActionConfig) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// 添加
|
||||
// AddItem 添加
|
||||
func (this *HTMLAction) AddItem(listType IPListType, item *pb.IPItem) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 删除
|
||||
// DeleteItem 删除
|
||||
func (this *HTMLAction) DeleteItem(listType IPListType, item *pb.IPItem) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 关闭
|
||||
// Close 关闭
|
||||
func (this *HTMLAction) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理HTTP请求
|
||||
// DoHTTP 处理HTTP请求
|
||||
func (this *HTMLAction) DoHTTP(req *http.Request, resp http.ResponseWriter) (goNext bool, err error) {
|
||||
if this.config == nil {
|
||||
goNext = true
|
||||
|
||||
@@ -7,18 +7,18 @@ import (
|
||||
)
|
||||
|
||||
type ActionInterface interface {
|
||||
// 初始化
|
||||
// Init 初始化
|
||||
Init(config *firewallconfigs.FirewallActionConfig) error
|
||||
|
||||
// 添加
|
||||
// AddItem 添加
|
||||
AddItem(listType IPListType, item *pb.IPItem) error
|
||||
|
||||
// 删除
|
||||
// DeleteItem 删除
|
||||
DeleteItem(listType IPListType, item *pb.IPItem) error
|
||||
|
||||
// 关闭
|
||||
// Close 关闭
|
||||
Close() error
|
||||
|
||||
// 处理HTTP请求
|
||||
// DoHTTP 处理HTTP请求
|
||||
DoHTTP(req *http.Request, resp http.ResponseWriter) (goNext bool, err error)
|
||||
}
|
||||
|
||||
@@ -5,13 +5,15 @@ import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IPSet动作
|
||||
// IPSetAction IPSet动作
|
||||
// 相关命令:
|
||||
// - 利用Firewalld管理set:
|
||||
// - 添加:firewall-cmd --permanent --new-ipset=edge_ip_list --type=hash:ip --option="timeout=0"
|
||||
@@ -23,14 +25,21 @@ import (
|
||||
// - 添加Item:ipset add edge_ip_list 192.168.2.32 timeout 30
|
||||
// - 删除Item: ipset del edge_ip_list 192.168.2.32
|
||||
// - 创建set:ipset create edge_ip_list hash:ip timeout 0
|
||||
// - 查看统计:ipset -t list edge_black_list
|
||||
// - 删除set:ipset destroy edge_black_list
|
||||
type IPSetAction struct {
|
||||
BaseAction
|
||||
|
||||
config *firewallconfigs.FirewallActionIPSetConfig
|
||||
config *firewallconfigs.FirewallActionIPSetConfig
|
||||
errorBuf *bytes.Buffer
|
||||
|
||||
ipsetNotfound bool
|
||||
}
|
||||
|
||||
func NewIPSetAction() *IPSetAction {
|
||||
return &IPSetAction{}
|
||||
return &IPSetAction{
|
||||
errorBuf: &bytes.Buffer{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) error {
|
||||
@@ -54,7 +63,7 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
|
||||
return err
|
||||
}
|
||||
{
|
||||
cmd := exec.Command(path, "create", this.config.WhiteName, "hash:ip", "timeout", "0")
|
||||
cmd := exec.Command(path, "create", this.config.WhiteName, "hash:ip", "timeout", "0", "maxelem", "1000000")
|
||||
stderr := bytes.NewBuffer([]byte{})
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
@@ -68,7 +77,7 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
|
||||
}
|
||||
}
|
||||
{
|
||||
cmd := exec.Command(path, "create", this.config.BlackName, "hash:ip", "timeout", "0")
|
||||
cmd := exec.Command(path, "create", this.config.BlackName, "hash:ip", "timeout", "0", "maxelem", "1000000")
|
||||
stderr := bytes.NewBuffer([]byte{})
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
@@ -163,24 +172,39 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
|
||||
}
|
||||
|
||||
{
|
||||
cmd := exec.Command(path, "-A", "INPUT", "-m", "set", "--match-set", this.config.WhiteName, "src", "-j", "ACCEPT")
|
||||
stderr := bytes.NewBuffer([]byte{})
|
||||
cmd.Stderr = stderr
|
||||
// 检查规则是否存在
|
||||
var cmd = exec.Command(path, "-C", "INPUT", "-m", "set", "--match-set", this.config.WhiteName, "src", "-j", "ACCEPT")
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
output := stderr.Bytes()
|
||||
return errors.New("iptables add rule: " + err.Error() + ", output: " + string(output))
|
||||
var exists = err == nil
|
||||
|
||||
// 添加规则
|
||||
if !exists {
|
||||
var cmd = exec.Command(path, "-A", "INPUT", "-m", "set", "--match-set", this.config.WhiteName, "src", "-j", "ACCEPT")
|
||||
stderr := bytes.NewBuffer([]byte{})
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
output := stderr.Bytes()
|
||||
return errors.New("iptables add rule: " + err.Error() + ", output: " + string(output))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
cmd := exec.Command(path, "-A", "INPUT", "-m", "set", "--match-set", this.config.BlackName, "src", "-j", "REJECT")
|
||||
stderr := bytes.NewBuffer([]byte{})
|
||||
cmd.Stderr = stderr
|
||||
// 检查规则是否存在
|
||||
var cmd = exec.Command(path, "-C", "INPUT", "-m", "set", "--match-set", this.config.BlackName, "src", "-j", "REJECT")
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
output := stderr.Bytes()
|
||||
return errors.New("iptables add rule: " + err.Error() + ", output: " + string(output))
|
||||
var exists = err == nil
|
||||
|
||||
if !exists {
|
||||
var cmd = exec.Command(path, "-A", "INPUT", "-m", "set", "--match-set", this.config.BlackName, "src", "-j", "REJECT")
|
||||
stderr := bytes.NewBuffer([]byte{})
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
output := stderr.Bytes()
|
||||
return errors.New("iptables add rule: " + err.Error() + ", output: " + string(output))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -212,6 +236,16 @@ func (this *IPSetAction) runAction(action string, listType IPListType, item *pb.
|
||||
return nil
|
||||
}
|
||||
for _, cidr := range cidrList {
|
||||
index := strings.Index(cidr, "/")
|
||||
if index <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 只支持/24以下的
|
||||
if types.Int(cidr[index+1:]) < 24 {
|
||||
continue
|
||||
}
|
||||
|
||||
item.IpFrom = cidr
|
||||
item.IpTo = ""
|
||||
err := this.runActionSingleIP(action, listType, item)
|
||||
@@ -246,6 +280,11 @@ func (this *IPSetAction) runActionSingleIP(action string, listType IPListType, i
|
||||
if len(path) == 0 {
|
||||
path, err = exec.LookPath("ipset")
|
||||
if err != nil {
|
||||
// 找不到ipset命令错误只提示一次
|
||||
if this.ipsetNotfound {
|
||||
return nil
|
||||
}
|
||||
this.ipsetNotfound = true
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -258,19 +297,30 @@ func (this *IPSetAction) runActionSingleIP(action string, listType IPListType, i
|
||||
case "deleteItem":
|
||||
args = append(args, "del")
|
||||
}
|
||||
args = append(args, listName, item.IpFrom)
|
||||
timestamp := time.Now().Unix()
|
||||
if item.ExpiredAt > timestamp {
|
||||
args = append(args, "timeout", strconv.FormatInt(item.ExpiredAt-timestamp, 10))
|
||||
}
|
||||
|
||||
//logs.Println(args)
|
||||
args = append(args, listName, item.IpFrom)
|
||||
if action == "addItem" {
|
||||
timestamp := time.Now().Unix()
|
||||
if item.ExpiredAt > timestamp {
|
||||
args = append(args, "timeout", strconv.FormatInt(item.ExpiredAt-timestamp, 10))
|
||||
}
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
// MAC OS直接返回
|
||||
return nil
|
||||
}
|
||||
|
||||
this.errorBuf.Reset()
|
||||
cmd := exec.Command(path, args...)
|
||||
return cmd.Run()
|
||||
cmd.Stderr = this.errorBuf
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
var errString = this.errorBuf.String()
|
||||
if action == "deleteItem" && strings.Contains(errString, "not added") {
|
||||
return nil
|
||||
}
|
||||
return errors.New(strings.TrimSpace(errString))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -9,15 +9,18 @@ import (
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// IPTables动作
|
||||
// IPTablesAction IPTables动作
|
||||
// 相关命令:
|
||||
// iptables -A INPUT -s "192.168.2.32" -j ACCEPT
|
||||
// iptables -A INPUT -s "192.168.2.32" -j REJECT
|
||||
// iptables -D ...
|
||||
// iptables -D INPUT ...
|
||||
// iptables -F INPUT
|
||||
type IPTablesAction struct {
|
||||
BaseAction
|
||||
|
||||
config *firewallconfigs.FirewallActionIPTablesConfig
|
||||
|
||||
iptablesNotFound bool
|
||||
}
|
||||
|
||||
func NewIPTablesAction() *IPTablesAction {
|
||||
@@ -76,6 +79,10 @@ func (this *IPTablesAction) runActionSingleIP(action string, listType IPListType
|
||||
if len(path) == 0 {
|
||||
path, err = exec.LookPath("iptables")
|
||||
if err != nil {
|
||||
if this.iptablesNotFound {
|
||||
return nil
|
||||
}
|
||||
this.iptablesNotFound = true
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
var SharedActionManager = NewActionManager()
|
||||
|
||||
// 动作管理器定义
|
||||
// ActionManager 动作管理器定义
|
||||
type ActionManager struct {
|
||||
locker sync.Mutex
|
||||
|
||||
@@ -23,7 +23,7 @@ type ActionManager struct {
|
||||
instanceMap map[int64]ActionInterface // id => instance
|
||||
}
|
||||
|
||||
// 获取动作管理对象
|
||||
// NewActionManager 获取动作管理对象
|
||||
func NewActionManager() *ActionManager {
|
||||
return &ActionManager{
|
||||
configMap: map[int64]*firewallconfigs.FirewallActionConfig{},
|
||||
@@ -31,7 +31,7 @@ func NewActionManager() *ActionManager {
|
||||
}
|
||||
}
|
||||
|
||||
// 更新配置
|
||||
// UpdateActions 更新配置
|
||||
func (this *ActionManager) UpdateActions(actions []*firewallconfigs.FirewallActionConfig) {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
@@ -108,14 +108,14 @@ func (this *ActionManager) UpdateActions(actions []*firewallconfigs.FirewallActi
|
||||
}
|
||||
}
|
||||
|
||||
// 查找事件对应的动作
|
||||
// FindEventActions 查找事件对应的动作
|
||||
func (this *ActionManager) FindEventActions(eventLevel string) []ActionInterface {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
return this.eventMap[eventLevel]
|
||||
}
|
||||
|
||||
// 执行添加IP动作
|
||||
// AddItem 执行添加IP动作
|
||||
func (this *ActionManager) AddItem(listType IPListType, item *pb.IPItem) {
|
||||
instances, ok := this.eventMap[item.EventLevel]
|
||||
if ok {
|
||||
@@ -128,7 +128,7 @@ func (this *ActionManager) AddItem(listType IPListType, item *pb.IPItem) {
|
||||
}
|
||||
}
|
||||
|
||||
// 执行删除IP动作
|
||||
// DeleteItem 执行删除IP动作
|
||||
func (this *ActionManager) DeleteItem(listType IPListType, item *pb.IPItem) {
|
||||
instances, ok := this.eventMap[item.EventLevel]
|
||||
if ok {
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// 脚本命令动作
|
||||
// ScriptAction 脚本命令动作
|
||||
type ScriptAction struct {
|
||||
BaseAction
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ func (this *IPItem) Contains(ip uint64) bool {
|
||||
case IPItemTypeIPv6:
|
||||
return this.containsIPv6(ip)
|
||||
case IPItemTypeAll:
|
||||
return this.containsAll(ip)
|
||||
return this.containsAll()
|
||||
default:
|
||||
return this.containsIPv4(ip)
|
||||
}
|
||||
@@ -63,7 +63,7 @@ func (this *IPItem) containsIPv6(ip uint64) bool {
|
||||
}
|
||||
|
||||
// 检查是否包所有IP
|
||||
func (this *IPItem) containsAll(ip uint64) bool {
|
||||
func (this *IPItem) containsAll() bool {
|
||||
if this.ExpiredAt > 0 && this.ExpiredAt < utils.UnixTime() {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -7,6 +7,9 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
var GlobalBlackIPList = NewIPList()
|
||||
var GlobalWhiteIPList = NewIPList()
|
||||
|
||||
// IPList IP名单
|
||||
// TODO IP名单可以分片关闭,这样让每一片的数据量减少,查询更快
|
||||
type IPList struct {
|
||||
|
||||
145
internal/iplibrary/ip_list_db.go
Normal file
145
internal/iplibrary/ip_list_db.go
Normal file
@@ -0,0 +1,145 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type IPListDB struct {
|
||||
db *sql.DB
|
||||
|
||||
itemTableName string
|
||||
deleteItemStmt *sql.Stmt
|
||||
insertItemStmt *sql.Stmt
|
||||
selectItemsStmt *sql.Stmt
|
||||
|
||||
dir string
|
||||
}
|
||||
|
||||
func NewIPListDB() (*IPListDB, error) {
|
||||
var db = &IPListDB{
|
||||
itemTableName: "ipItems",
|
||||
dir: filepath.Clean(Tea.Root + "/data"),
|
||||
}
|
||||
err := db.init()
|
||||
return db, err
|
||||
}
|
||||
|
||||
func (this *IPListDB) init() error {
|
||||
// 检查目录是否存在
|
||||
_, err := os.Stat(this.dir)
|
||||
if err != nil {
|
||||
err = os.MkdirAll(this.dir, 0777)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
remotelogs.Println("CACHE", "create cache dir '"+this.dir+"'")
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", "file:"+this.dir+"/ip_list.db?cache=shared&mode=rwc&_journal_mode=WAL")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.SetMaxOpenConns(1)
|
||||
this.db = db
|
||||
|
||||
// 初始化数据库
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemTableName + `" (
|
||||
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"listId" integer DEFAULT 0,
|
||||
"listType" varchar(32),
|
||||
"isGlobal" integer(1) DEFAULT 0,
|
||||
"type" varchar(16),
|
||||
"itemId" integer DEFAULT 0,
|
||||
"ipFrom" varchar(64) DEFAULT 0,
|
||||
"ipTo" varchar(64) DEFAULT 0,
|
||||
"expiredAt" integer DEFAULT 0,
|
||||
"eventLevel" varchar(32),
|
||||
"isDeleted" integer(1) DEFAULT 0,
|
||||
"version" integer DEFAULT 0,
|
||||
"nodeId" integer DEFAULT 0,
|
||||
"serverId" integer DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "ip_list_itemId"
|
||||
ON "` + this.itemTableName + `" (
|
||||
"itemId" ASC
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "ip_list_expiredAt"
|
||||
ON "` + this.itemTableName + `" (
|
||||
"expiredAt" ASC
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 初始化SQL语句
|
||||
this.deleteItemStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "itemId"=?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.insertItemStmt, err = this.db.Prepare(`INSERT INTO "` + this.itemTableName + `" ("listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.selectItemsStmt, err = this.db.Prepare(`SELECT "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId" FROM "` + this.itemTableName + `" ORDER BY "version" ASC, "itemId" ASC LIMIT ?, ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.db = db
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *IPListDB) AddItem(item *pb.IPItem) error {
|
||||
_, err := this.deleteItemStmt.Exec(item.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = this.insertItemStmt.Exec(item.ListId, item.ListType, item.IsGlobal, item.Type, item.Id, item.IpFrom, item.IpTo, item.ExpiredAt, item.EventLevel, item.IsDeleted, item.Version, item.NodeId, item.ServerId)
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *IPListDB) ReadItems(offset int64, size int64) (items []*pb.IPItem, err error) {
|
||||
rows, err := this.selectItemsStmt.Query(offset, size)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
// "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId"
|
||||
var pbItem = &pb.IPItem{}
|
||||
err = rows.Scan(&pbItem.ListId, &pbItem.ListType, &pbItem.IsGlobal, &pbItem.Type, &pbItem.Id, &pbItem.IpFrom, &pbItem.IpTo, &pbItem.ExpiredAt, &pbItem.EventLevel, &pbItem.IsDeleted, &pbItem.Version, &pbItem.NodeId, &pbItem.ServerId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, pbItem)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (this *IPListDB) Close() error {
|
||||
if this.db != nil {
|
||||
_ = this.deleteItemStmt.Close()
|
||||
_ = this.insertItemStmt.Close()
|
||||
_ = this.selectItemsStmt.Close()
|
||||
|
||||
return this.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
60
internal/iplibrary/ip_list_db_test.go
Normal file
60
internal/iplibrary/ip_list_db_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIPListDB_AddItem(t *testing.T) {
|
||||
db, err := NewIPListDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.AddItem(&pb.IPItem{
|
||||
Id: 1,
|
||||
IpFrom: "192.168.1.101",
|
||||
IpTo: "",
|
||||
Version: 1024,
|
||||
ExpiredAt: time.Now().Unix(),
|
||||
Reason: "",
|
||||
ListId: 2,
|
||||
IsDeleted: true,
|
||||
Type: "ipv4",
|
||||
EventLevel: "error",
|
||||
ListType: "black",
|
||||
IsGlobal: true,
|
||||
CreatedAt: 0,
|
||||
NodeId: 11,
|
||||
ServerId: 22,
|
||||
SourceNodeId: 0,
|
||||
SourceServerId: 0,
|
||||
SourceHTTPFirewallPolicyId: 0,
|
||||
SourceHTTPFirewallRuleGroupId: 0,
|
||||
SourceHTTPFirewallRuleSetId: 0,
|
||||
SourceServer: nil,
|
||||
SourceHTTPFirewallPolicy: nil,
|
||||
SourceHTTPFirewallRuleGroup: nil,
|
||||
SourceHTTPFirewallRuleSet: nil,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestIPListDB_ReadItems(t *testing.T) {
|
||||
db, err := NewIPListDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
items, err := db.ReadItems(0, 2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
logs.PrintAsJSON(items, t)
|
||||
}
|
||||
55
internal/iplibrary/list_utils.go
Normal file
55
internal/iplibrary/list_utils.go
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
)
|
||||
|
||||
// AllowIP 检查IP是否被允许访问
|
||||
func AllowIP(ip string, serverId int64) bool {
|
||||
var ipLong = utils.IP2Long(ip)
|
||||
if ipLong == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// check white lists
|
||||
if GlobalWhiteIPList.Contains(ipLong) {
|
||||
return true
|
||||
}
|
||||
|
||||
if serverId > 0 {
|
||||
var list = SharedServerListManager.FindWhiteList(serverId, false)
|
||||
if list != nil && list.Contains(ipLong) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// check black lists
|
||||
if GlobalBlackIPList.Contains(ipLong) {
|
||||
return false
|
||||
}
|
||||
|
||||
if serverId > 0 {
|
||||
var list = SharedServerListManager.FindBlackList(serverId, false)
|
||||
if list != nil && list.Contains(ipLong) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AllowIPStrings 检查一组IP是否被允许访问
|
||||
func AllowIPStrings(ipStrings []string, serverId int64) bool {
|
||||
if len(ipStrings) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, ip := range ipStrings {
|
||||
isAllowed := AllowIP(ip, serverId)
|
||||
if !isAllowed {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
20
internal/iplibrary/list_utils_test.go
Normal file
20
internal/iplibrary/list_utils_test.go
Normal file
@@ -0,0 +1,20 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIPIsAllowed(t *testing.T) {
|
||||
manager := NewIPListManager()
|
||||
manager.init()
|
||||
|
||||
var before = time.Now()
|
||||
defer func() {
|
||||
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||
}()
|
||||
t.Log(AllowIP("127.0.0.1", 0))
|
||||
t.Log(AllowIP("127.0.0.1", 23))
|
||||
}
|
||||
@@ -21,7 +21,7 @@ func init() {
|
||||
// 初始化
|
||||
library, err := SharedManager.Load()
|
||||
if err != nil {
|
||||
remotelogs.Error("IP_LIBRARY", err.Error())
|
||||
remotelogs.ErrorObject("IP_LIBRARY", err)
|
||||
return
|
||||
}
|
||||
SharedLibrary = library
|
||||
|
||||
@@ -46,13 +46,13 @@ func (this *CountryManager) Start() {
|
||||
// 从缓存中读取
|
||||
err := this.load()
|
||||
if err != nil {
|
||||
remotelogs.Error("COUNTRY_MANAGER", err.Error())
|
||||
remotelogs.ErrorObject("COUNTRY_MANAGER", err)
|
||||
}
|
||||
|
||||
// 第一次更新
|
||||
err = this.loop()
|
||||
if err != nil {
|
||||
remotelogs.Error("COUNTRY_MANAGER", err.Error())
|
||||
remotelogs.ErrorObject("COUNTRY_MANAGER", err)
|
||||
}
|
||||
|
||||
// 定时更新
|
||||
@@ -63,7 +63,7 @@ func (this *CountryManager) Start() {
|
||||
for range ticker.C {
|
||||
err := this.loop()
|
||||
if err != nil {
|
||||
remotelogs.Error("COUNTRY_MANAGER", err.Error())
|
||||
remotelogs.ErrorObject("COUNTRY_MANAGER", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -22,9 +23,7 @@ func init() {
|
||||
|
||||
// IPListManager IP名单管理
|
||||
type IPListManager struct {
|
||||
// 缓存文件
|
||||
// 每行一个数据:id|from|to|expiredAt
|
||||
cacheFile string
|
||||
db *IPListDB
|
||||
|
||||
version int64
|
||||
pageSize int64
|
||||
@@ -35,22 +34,24 @@ type IPListManager struct {
|
||||
|
||||
func NewIPListManager() *IPListManager {
|
||||
return &IPListManager{
|
||||
cacheFile: Tea.Root + "/configs/ip_list.cache",
|
||||
pageSize: 1000,
|
||||
listMap: map[int64]*IPList{},
|
||||
pageSize: 500,
|
||||
listMap: map[int64]*IPList{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *IPListManager) Start() {
|
||||
// TODO 从缓存当中读取数据
|
||||
this.init()
|
||||
|
||||
// 第一次读取
|
||||
err := this.loop()
|
||||
if err != nil {
|
||||
remotelogs.Error("IP_LIST_MANAGER", err.Error())
|
||||
remotelogs.ErrorObject("IP_LIST_MANAGER", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(60 * time.Second)
|
||||
if Tea.IsTesting() {
|
||||
ticker = time.NewTicker(10 * time.Second)
|
||||
}
|
||||
events.On(events.EventQuit, func() {
|
||||
ticker.Stop()
|
||||
})
|
||||
@@ -64,7 +65,7 @@ func (this *IPListManager) Start() {
|
||||
if err != nil {
|
||||
countErrors++
|
||||
|
||||
remotelogs.Error("IP_LIST_MANAGER", err.Error())
|
||||
remotelogs.ErrorObject("IP_LIST_MANAGER", err)
|
||||
|
||||
// 连续错误小于3次的我们立即重试
|
||||
if countErrors <= 3 {
|
||||
@@ -79,6 +80,31 @@ func (this *IPListManager) Start() {
|
||||
}
|
||||
}
|
||||
|
||||
func (this *IPListManager) init() {
|
||||
// 从数据库中当中读取数据
|
||||
db, err := NewIPListDB()
|
||||
if err != nil {
|
||||
remotelogs.Error("IP_LIST_MANAGER", "create ip list local database failed: "+err.Error())
|
||||
} else {
|
||||
this.db = db
|
||||
|
||||
var offset int64 = 0
|
||||
var size int64 = 1000
|
||||
for {
|
||||
items, err := db.ReadItems(offset, size)
|
||||
if err != nil {
|
||||
remotelogs.Error("IP_LIST_MANAGER", "read ip list from local database failed: "+err.Error())
|
||||
} else {
|
||||
if len(items) == 0 {
|
||||
break
|
||||
}
|
||||
this.processItems(items, false)
|
||||
}
|
||||
offset += int64(len(items))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *IPListManager) loop() error {
|
||||
for {
|
||||
hasNext, err := this.fetch()
|
||||
@@ -88,10 +114,9 @@ func (this *IPListManager) loop() error {
|
||||
if !hasNext {
|
||||
break
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
// TODO 写入到缓存当中
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -111,11 +136,53 @@ func (this *IPListManager) fetch() (hasNext bool, err error) {
|
||||
if len(items) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 保存到本地数据库
|
||||
if this.db != nil {
|
||||
for _, item := range items {
|
||||
err = this.db.AddItem(item)
|
||||
if err != nil {
|
||||
remotelogs.Error("IP_LIST_MANAGER", "insert item to local database failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.processItems(items, true)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (this *IPListManager) FindList(listId int64) *IPList {
|
||||
this.locker.Lock()
|
||||
list, _ := this.listMap[listId]
|
||||
this.locker.Unlock()
|
||||
return list
|
||||
}
|
||||
|
||||
func (this *IPListManager) processItems(items []*pb.IPItem, shouldExecute bool) {
|
||||
this.locker.Lock()
|
||||
var changedLists = map[*IPList]bool{}
|
||||
for _, item := range items {
|
||||
list, ok := this.listMap[item.ListId]
|
||||
if !ok {
|
||||
var list *IPList
|
||||
// TODO 实现节点专有List
|
||||
if item.ServerId > 0 { // 服务专有List
|
||||
switch item.ListType {
|
||||
case "black":
|
||||
list = SharedServerListManager.FindBlackList(item.ServerId, true)
|
||||
case "white":
|
||||
list = SharedServerListManager.FindWhiteList(item.ServerId, true)
|
||||
}
|
||||
} else if item.IsGlobal { // 全局List
|
||||
switch item.ListType {
|
||||
case "black":
|
||||
list = GlobalBlackIPList
|
||||
case "white":
|
||||
list = GlobalWhiteIPList
|
||||
}
|
||||
} else { // 其他List
|
||||
list = this.listMap[item.ListId]
|
||||
}
|
||||
if list == nil {
|
||||
list = NewIPList()
|
||||
this.listMap[item.ListId] = list
|
||||
}
|
||||
@@ -125,8 +192,13 @@ func (this *IPListManager) fetch() (hasNext bool, err error) {
|
||||
if item.IsDeleted {
|
||||
list.Delete(item.Id)
|
||||
|
||||
// 从WAF名单中删除
|
||||
waf.SharedIPBlackList.RemoveIP(item.IpFrom, item.ServerId)
|
||||
|
||||
// 操作事件
|
||||
SharedActionManager.DeleteItem(item.ListType, item)
|
||||
if shouldExecute {
|
||||
SharedActionManager.DeleteItem(item.ListType, item)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
@@ -141,8 +213,10 @@ func (this *IPListManager) fetch() (hasNext bool, err error) {
|
||||
})
|
||||
|
||||
// 事件操作
|
||||
SharedActionManager.DeleteItem(item.ListType, item)
|
||||
SharedActionManager.AddItem(item.ListType, item)
|
||||
if shouldExecute {
|
||||
SharedActionManager.DeleteItem(item.ListType, item)
|
||||
SharedActionManager.AddItem(item.ListType, item)
|
||||
}
|
||||
}
|
||||
|
||||
for changedList := range changedLists {
|
||||
@@ -151,13 +225,4 @@ func (this *IPListManager) fetch() (hasNext bool, err error) {
|
||||
|
||||
this.locker.Unlock()
|
||||
this.version = items[len(items)-1].Version
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (this *IPListManager) FindList(listId int64) *IPList {
|
||||
this.locker.Lock()
|
||||
list, _ := this.listMap[listId]
|
||||
this.locker.Unlock()
|
||||
return list
|
||||
}
|
||||
|
||||
@@ -1,10 +1,36 @@
|
||||
package iplibrary
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIPListManager_init(t *testing.T) {
|
||||
manager := NewIPListManager()
|
||||
manager.init()
|
||||
t.Log(manager.listMap)
|
||||
t.Log(SharedServerListManager.blackMap)
|
||||
logs.PrintAsJSON(GlobalBlackIPList.sortedItems, t)
|
||||
}
|
||||
|
||||
func TestIPListManager_check(t *testing.T) {
|
||||
manager := NewIPListManager()
|
||||
manager.init()
|
||||
|
||||
var before = time.Now()
|
||||
defer func() {
|
||||
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||
}()
|
||||
t.Log(SharedServerListManager.FindBlackList(23, true).Contains(utils.IP2Long("127.0.0.2")))
|
||||
t.Log(GlobalBlackIPList.Contains(utils.IP2Long("127.0.0.6")))
|
||||
}
|
||||
|
||||
func TestIPListManager_loop(t *testing.T) {
|
||||
manager := NewIPListManager()
|
||||
manager.pageSize = 2
|
||||
manager.Start()
|
||||
manager.pageSize = 10
|
||||
err := manager.loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -50,13 +50,13 @@ func (this *ProvinceManager) Start() {
|
||||
// 从缓存中读取
|
||||
err := this.load()
|
||||
if err != nil {
|
||||
remotelogs.Error("PROVINCE_MANAGER", err.Error())
|
||||
remotelogs.ErrorObject("PROVINCE_MANAGER", err)
|
||||
}
|
||||
|
||||
// 第一次更新
|
||||
err = this.loop()
|
||||
if err != nil {
|
||||
remotelogs.Error("PROVINCE_MANAGER", err.Error())
|
||||
remotelogs.ErrorObject("PROVINCE_MANAGER", err)
|
||||
}
|
||||
|
||||
// 定时更新
|
||||
@@ -67,7 +67,7 @@ func (this *ProvinceManager) Start() {
|
||||
for range ticker.C {
|
||||
err := this.loop()
|
||||
if err != nil {
|
||||
remotelogs.Error("PROVINCE_MANAGER", err.Error())
|
||||
remotelogs.ErrorObject("PROVINCE_MANAGER", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
61
internal/iplibrary/server_list_manager.go
Normal file
61
internal/iplibrary/server_list_manager.go
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package iplibrary
|
||||
|
||||
import "sync"
|
||||
|
||||
var SharedServerListManager = NewServerListManager()
|
||||
|
||||
// ServerListManager 服务相关名单
|
||||
type ServerListManager struct {
|
||||
whiteMap map[int64]*IPList // serverId => *List
|
||||
blackMap map[int64]*IPList // serverId => *List
|
||||
|
||||
locker sync.RWMutex
|
||||
}
|
||||
|
||||
func NewServerListManager() *ServerListManager {
|
||||
return &ServerListManager{
|
||||
whiteMap: map[int64]*IPList{},
|
||||
blackMap: map[int64]*IPList{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *ServerListManager) FindWhiteList(serverId int64, autoCreate bool) *IPList {
|
||||
this.locker.RLock()
|
||||
list, ok := this.whiteMap[serverId]
|
||||
this.locker.RUnlock()
|
||||
if ok {
|
||||
return list
|
||||
}
|
||||
|
||||
if autoCreate {
|
||||
list = NewIPList()
|
||||
this.locker.Lock()
|
||||
this.whiteMap[serverId] = list
|
||||
this.locker.Unlock()
|
||||
|
||||
return list
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *ServerListManager) FindBlackList(serverId int64, autoCreate bool) *IPList {
|
||||
this.locker.RLock()
|
||||
list, ok := this.blackMap[serverId]
|
||||
this.locker.RUnlock()
|
||||
if ok {
|
||||
return list
|
||||
}
|
||||
|
||||
if autoCreate {
|
||||
list = NewIPList()
|
||||
this.locker.Lock()
|
||||
this.blackMap[serverId] = list
|
||||
this.locker.Unlock()
|
||||
|
||||
return list
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -38,7 +38,7 @@ func (this *Updater) Start() {
|
||||
for range ticker.C {
|
||||
err := this.loop()
|
||||
if err != nil {
|
||||
remotelogs.Error("IP_LIBRARY", err.Error())
|
||||
remotelogs.ErrorObject("IP_LIBRARY", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/trackers"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
@@ -164,6 +165,8 @@ func (this *Task) Start() error {
|
||||
this.statsTicker = utils.NewTicker(1 * time.Minute)
|
||||
go func() {
|
||||
for this.statsTicker.Next() {
|
||||
var tr = trackers.Begin("[METRIC]DUMP_STATS_TO_LOCAL_DATABASE")
|
||||
|
||||
this.statsLocker.Lock()
|
||||
var statsMap = this.statsMap
|
||||
this.statsMap = map[string]*Stat{}
|
||||
@@ -175,6 +178,8 @@ func (this *Task) Start() error {
|
||||
remotelogs.Error("METRIC", "insert stat failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
tr.End()
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -182,7 +187,9 @@ func (this *Task) Start() error {
|
||||
this.cleanTicker = utils.NewTicker(24 * time.Hour)
|
||||
go func() {
|
||||
for this.cleanTicker.Next() {
|
||||
var tr = trackers.Begin("[METRIC]CLEAN_EXPIRED")
|
||||
err := this.CleanExpired()
|
||||
tr.End()
|
||||
if err != nil {
|
||||
remotelogs.Error("METRIC", "clean expired stats failed: "+err.Error())
|
||||
}
|
||||
@@ -193,7 +200,9 @@ func (this *Task) Start() error {
|
||||
this.uploadTicker = utils.NewTicker(this.item.UploadDuration())
|
||||
go func() {
|
||||
for this.uploadTicker.Next() {
|
||||
var tr = trackers.Begin("[METRIC]UPLOAD_STATS")
|
||||
err := this.Upload(1 * time.Second)
|
||||
tr.End()
|
||||
if err != nil {
|
||||
remotelogs.Error("METRIC", "upload stats failed: "+err.Error())
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ func (this *ValueQueue) Start() {
|
||||
// 这里单次循环就行,因为Loop里已经使用了Range通道
|
||||
err := this.Loop()
|
||||
if err != nil {
|
||||
remotelogs.Error("MONITOR_QUEUE", err.Error())
|
||||
remotelogs.ErrorObject("MONITOR_QUEUE", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,7 +72,11 @@ func (this *ValueQueue) Loop() error {
|
||||
CreatedAt: value.CreatedAt,
|
||||
})
|
||||
if err != nil {
|
||||
remotelogs.Error("MONITOR", err.Error())
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Warn("MONITOR", err.Error())
|
||||
} else {
|
||||
remotelogs.Error("MONITOR", err.Error())
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
28
internal/monitor/value_queue_test.go
Normal file
28
internal/monitor/value_queue_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"google.golang.org/grpc/status"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValueQueue_RPC(t *testing.T) {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = rpcClient.NodeValueRPC().CreateNodeValue(rpcClient.Context(), &pb.CreateNodeValueRequest{})
|
||||
if err != nil {
|
||||
statusErr, ok:= status.FromError(err)
|
||||
if ok {
|
||||
logs.Println(statusErr.Code())
|
||||
}
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
@@ -9,16 +9,18 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -45,7 +47,7 @@ func (this *APIStream) Start() {
|
||||
}
|
||||
err := this.loop()
|
||||
if err != nil {
|
||||
remotelogs.Error("API_STREAM", err.Error())
|
||||
remotelogs.Warn("API_STREAM", err.Error())
|
||||
time.Sleep(10 * time.Second)
|
||||
continue
|
||||
}
|
||||
@@ -79,7 +81,7 @@ func (this *APIStream) loop() error {
|
||||
|
||||
for {
|
||||
if isQuiting {
|
||||
logs.Println("API_STREAM", "quit")
|
||||
remotelogs.Println("API_STREAM", "quit")
|
||||
break
|
||||
}
|
||||
|
||||
@@ -112,6 +114,8 @@ func (this *APIStream) loop() error {
|
||||
err = this.handleNewNodeTask(message)
|
||||
case messageconfigs.MessageCodeCheckSystemdService: // 检查Systemd服务
|
||||
err = this.handleCheckSystemdService(message)
|
||||
case messageconfigs.MessageCodeChangeAPINode: // 修改API节点地址
|
||||
err = this.handleChangeAPINode(message)
|
||||
default:
|
||||
err = this.handleUnknownMessage(message)
|
||||
}
|
||||
@@ -571,6 +575,65 @@ func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 修改API地址
|
||||
func (this *APIStream) handleChangeAPINode(message *pb.NodeStreamMessage) error {
|
||||
config, err := configs.LoadAPIConfig()
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "read config error: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
var messageData = &messageconfigs.ChangeAPINodeMessage{}
|
||||
err = json.Unmarshal(message.DataJSON, messageData)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "unmarshal message failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = url.Parse(messageData.Addr)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "invalid new api node address: '"+messageData.Addr+"'")
|
||||
return nil
|
||||
}
|
||||
|
||||
config.RPC.Endpoints = []string{messageData.Addr}
|
||||
|
||||
// 保存到文件
|
||||
err = config.WriteFile(Tea.ConfigFile("api.yaml"))
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "save config file failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
this.replyOk(message.RequestId, "")
|
||||
|
||||
go func() {
|
||||
// 延后生效,防止变更前的API无法读取到状态
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
remotelogs.Error("API_STREAM", "change rpc endpoint to '"+
|
||||
messageData.Addr+"' failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
rpcClient.Close()
|
||||
|
||||
err = rpcClient.UpdateConfig(config)
|
||||
if err != nil {
|
||||
remotelogs.Error("API_STREAM", "change rpc endpoint to '"+
|
||||
messageData.Addr+"' failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
remotelogs.Println("API_STREAM", "change rpc endpoint to '"+
|
||||
messageData.Addr+"' successfully")
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理未知消息
|
||||
func (this *APIStream) handleUnknownMessage(message *pb.NodeStreamMessage) error {
|
||||
this.replyFail(message.RequestId, "unknown message code '"+message.Code+"'")
|
||||
|
||||
@@ -4,6 +4,7 @@ package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/monitor"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
@@ -12,10 +13,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// 流量统计
|
||||
var inTrafficBytes = uint64(0)
|
||||
var outTrafficBytes = uint64(0)
|
||||
|
||||
// 发送监控流量
|
||||
func init() {
|
||||
events.On(events.EventStart, func() {
|
||||
@@ -23,20 +20,20 @@ func init() {
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
// 加入到数据队列中
|
||||
if inTrafficBytes > 0 {
|
||||
if teaconst.InTrafficBytes > 0 {
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemTrafficIn, maps.Map{
|
||||
"total": inTrafficBytes,
|
||||
"total": teaconst.InTrafficBytes,
|
||||
})
|
||||
}
|
||||
if outTrafficBytes > 0 {
|
||||
if teaconst.OutTrafficBytes > 0 {
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemTrafficOut, maps.Map{
|
||||
"total": outTrafficBytes,
|
||||
"total": teaconst.OutTrafficBytes,
|
||||
})
|
||||
}
|
||||
|
||||
// 重置数据
|
||||
atomic.StoreUint64(&inTrafficBytes, 0)
|
||||
atomic.StoreUint64(&outTrafficBytes, 0)
|
||||
atomic.StoreUint64(&teaconst.InTrafficBytes, 0)
|
||||
atomic.StoreUint64(&teaconst.OutTrafficBytes, 0)
|
||||
}
|
||||
}()
|
||||
})
|
||||
@@ -63,7 +60,7 @@ func NewClientConn(conn net.Conn, quickClose bool) net.Conn {
|
||||
func (this *ClientConn) Read(b []byte) (n int, err error) {
|
||||
n, err = this.rawConn.Read(b)
|
||||
if n > 0 {
|
||||
atomic.AddUint64(&inTrafficBytes, uint64(n))
|
||||
atomic.AddUint64(&teaconst.InTrafficBytes, uint64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -71,7 +68,7 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
|
||||
func (this *ClientConn) Write(b []byte) (n int, err error) {
|
||||
n, err = this.rawConn.Write(b)
|
||||
if n > 0 {
|
||||
atomic.AddUint64(&outTrafficBytes, uint64(n))
|
||||
atomic.AddUint64(&teaconst.OutTrafficBytes, uint64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"net"
|
||||
)
|
||||
@@ -29,9 +30,8 @@ func (this *ClientListener) Accept() (net.Conn, error) {
|
||||
// 是否在WAF名单中
|
||||
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
if err == nil {
|
||||
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) &&
|
||||
waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
|
||||
|
||||
if !iplibrary.AllowIP(ip, 0) || (!waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) &&
|
||||
waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)) {
|
||||
tcpConn, ok := conn.(*net.TCPConn)
|
||||
if ok {
|
||||
_ = tcpConn.SetLinger(0)
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -12,13 +14,15 @@ var sharedHTTPAccessLogQueue = NewHTTPAccessLogQueue()
|
||||
// HTTPAccessLogQueue HTTP访问日志队列
|
||||
type HTTPAccessLogQueue struct {
|
||||
queue chan *pb.HTTPAccessLog
|
||||
|
||||
rpcClient *rpc.RPCClient
|
||||
}
|
||||
|
||||
// NewHTTPAccessLogQueue 获取新对象
|
||||
func NewHTTPAccessLogQueue() *HTTPAccessLogQueue {
|
||||
// 队列中最大的值,超出此数量的访问日志会被丢弃
|
||||
// TODO 需要可以在界面中设置
|
||||
maxSize := 10000
|
||||
maxSize := 20000
|
||||
queue := &HTTPAccessLogQueue{
|
||||
queue: make(chan *pb.HTTPAccessLog, maxSize),
|
||||
}
|
||||
@@ -49,17 +53,31 @@ func (this *HTTPAccessLogQueue) Push(accessLog *pb.HTTPAccessLog) {
|
||||
|
||||
// 上传访问日志
|
||||
func (this *HTTPAccessLogQueue) loop() error {
|
||||
accessLogs := []*pb.HTTPAccessLog{}
|
||||
count := 0
|
||||
var accessLogs = []*pb.HTTPAccessLog{}
|
||||
var count = 0
|
||||
var timestamp int64
|
||||
var requestId = 1_000_000
|
||||
|
||||
Loop:
|
||||
for {
|
||||
select {
|
||||
case accessLog := <-this.queue:
|
||||
var unixTime = utils.UnixTime()
|
||||
if unixTime > timestamp {
|
||||
requestId = 1_000_000
|
||||
timestamp = unixTime
|
||||
} else {
|
||||
requestId++
|
||||
}
|
||||
|
||||
// timestamp + requestId + nodeId
|
||||
accessLog.RequestId = strconv.FormatInt(unixTime, 10) + strconv.Itoa(requestId) + strconv.FormatInt(accessLog.NodeId, 10)
|
||||
|
||||
accessLogs = append(accessLogs, accessLog)
|
||||
count++
|
||||
|
||||
// 每次只提交 N 条访问日志,防止网络拥堵
|
||||
if count > 1000 {
|
||||
if count > 2000 {
|
||||
break Loop
|
||||
}
|
||||
default:
|
||||
@@ -72,12 +90,15 @@ Loop:
|
||||
}
|
||||
|
||||
// 发送到API
|
||||
client, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
if this.rpcClient == nil {
|
||||
client, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.rpcClient = client
|
||||
}
|
||||
|
||||
_, err = client.HTTPAccessLogRPC().CreateHTTPAccessLogs(client.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: accessLogs})
|
||||
_, err := this.rpcClient.HTTPAccessLogRPC().CreateHTTPAccessLogs(this.rpcClient.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: accessLogs})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -86,13 +86,21 @@ type HTTPRequest struct {
|
||||
// 初始化
|
||||
func (this *HTTPRequest) init() {
|
||||
this.writer = NewHTTPWriter(this, this.RawWriter)
|
||||
this.web = &serverconfigs.HTTPWebConfig{IsOn: true}
|
||||
this.web = &serverconfigs.HTTPWebConfig{
|
||||
IsOn: true,
|
||||
}
|
||||
|
||||
// this.uri = this.RawReq.URL.RequestURI()
|
||||
// 之所以不使用RequestURI(),是不想让URL中的Path被Encode
|
||||
var urlPath = this.RawReq.URL.Path
|
||||
if this.Server.Web != nil && this.Server.Web.MergeSlashes {
|
||||
urlPath = utils.CleanPath(urlPath)
|
||||
this.web.MergeSlashes = true
|
||||
}
|
||||
if len(this.RawReq.URL.RawQuery) > 0 {
|
||||
this.uri = this.RawReq.URL.Path + "?" + this.RawReq.URL.RawQuery
|
||||
this.uri = urlPath + "?" + this.RawReq.URL.RawQuery
|
||||
} else {
|
||||
this.uri = this.RawReq.URL.Path
|
||||
this.uri = urlPath
|
||||
}
|
||||
|
||||
this.rawURI = this.uri
|
||||
@@ -137,9 +145,16 @@ func (this *HTTPRequest) Do() {
|
||||
}
|
||||
}
|
||||
|
||||
// 带宽限制
|
||||
if this.Server.BandwidthLimit != nil && this.Server.BandwidthLimit.IsOn && !this.Server.BandwidthLimit.IsEmpty() && this.Server.BandwidthLimitStatus != nil && this.Server.BandwidthLimitStatus.IsValid() {
|
||||
this.doBandwidthLimit()
|
||||
// 套餐
|
||||
if this.Server.UserPlan != nil && !this.Server.UserPlan.IsAvailable() {
|
||||
this.doPlanExpires()
|
||||
this.doEnd()
|
||||
return
|
||||
}
|
||||
|
||||
// 流量限制
|
||||
if this.Server.TrafficLimit != nil && this.Server.TrafficLimit.IsOn && !this.Server.TrafficLimit.IsEmpty() && this.Server.TrafficLimitStatus != nil && this.Server.TrafficLimitStatus.IsValid() {
|
||||
this.doTrafficLimit()
|
||||
this.doEnd()
|
||||
return
|
||||
}
|
||||
@@ -262,14 +277,15 @@ func (this *HTTPRequest) doEnd() {
|
||||
|
||||
// 流量统计
|
||||
// TODO 增加是否开启开关
|
||||
// TODO 增加Header统计,考虑从Conn中读取
|
||||
if this.Server != nil {
|
||||
if this.isCached {
|
||||
stats.SharedTrafficStatManager.Add(this.Server.Id, this.Host, this.writer.sentBodyBytes, this.writer.sentBodyBytes, 1, 1, 0, 0)
|
||||
stats.SharedTrafficStatManager.Add(this.Server.Id, this.Host, this.writer.sentBodyBytes, this.writer.sentBodyBytes, 1, 1, 0, 0, this.Server.ShouldCheckTrafficLimit(), this.Server.PlanId())
|
||||
} else {
|
||||
if this.isAttack {
|
||||
stats.SharedTrafficStatManager.Add(this.Server.Id, this.Host, this.writer.sentBodyBytes, 0, 1, 0, 1, this.writer.sentBodyBytes)
|
||||
stats.SharedTrafficStatManager.Add(this.Server.Id, this.Host, this.writer.sentBodyBytes, 0, 1, 0, 1, this.writer.sentBodyBytes, this.Server.ShouldCheckTrafficLimit(), this.Server.PlanId())
|
||||
} else {
|
||||
stats.SharedTrafficStatManager.Add(this.Server.Id, this.Host, this.writer.sentBodyBytes, 0, 1, 0, 0, 0)
|
||||
stats.SharedTrafficStatManager.Add(this.Server.Id, this.Host, this.writer.sentBodyBytes, 0, 1, 0, 0, 0, this.Server.ShouldCheckTrafficLimit(), this.Server.PlanId())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -8,7 +9,11 @@ import (
|
||||
|
||||
// 主机地址快速跳转
|
||||
func (this *HTTPRequest) doHostRedirect() (blocked bool) {
|
||||
fullURL := this.requestScheme() + "://" + this.Host + this.RawReq.URL.Path
|
||||
var urlPath = this.RawReq.URL.Path
|
||||
if this.web.MergeSlashes {
|
||||
urlPath = utils.CleanPath(urlPath)
|
||||
}
|
||||
fullURL := this.requestScheme() + "://" + this.Host + urlPath
|
||||
for _, u := range this.web.HostRedirects {
|
||||
if !u.IsOn {
|
||||
continue
|
||||
|
||||
@@ -3,14 +3,10 @@ package nodes
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var requestId int64 = 1_0000_0000_0000_0000
|
||||
|
||||
// 日志
|
||||
func (this *HTTPRequest) log() {
|
||||
if this.disableLog {
|
||||
@@ -85,7 +81,7 @@ func (this *HTTPRequest) log() {
|
||||
}
|
||||
|
||||
accessLog := &pb.HTTPAccessLog{
|
||||
RequestId: strconv.FormatInt(this.requestFromTime.UnixNano(), 10) + strconv.FormatInt(atomic.AddInt64(&requestId, 1), 10) + sharedNodeConfig.PaddedId(),
|
||||
RequestId: "",
|
||||
NodeId: sharedNodeConfig.Id,
|
||||
ServerId: this.Server.Id,
|
||||
RemoteAddr: this.requestRemoteAddr(true),
|
||||
|
||||
@@ -23,7 +23,11 @@ func (this *HTTPRequest) doPage(status int) (shouldStop bool) {
|
||||
if page.Match(status) {
|
||||
if len(page.BodyType) == 0 || page.BodyType == shared.BodyTypeURL {
|
||||
if urlPrefixRegexp.MatchString(page.URL) {
|
||||
this.doURL(http.MethodGet, page.URL, "", page.NewStatus, true)
|
||||
var newStatus = page.NewStatus
|
||||
if newStatus <= 0 {
|
||||
newStatus = status
|
||||
}
|
||||
this.doURL(http.MethodGet, page.URL, "", newStatus, true)
|
||||
return true
|
||||
} else {
|
||||
file := Tea.Root + Tea.DS + page.URL
|
||||
|
||||
19
internal/nodes/http_request_plan_expires.go
Normal file
19
internal/nodes/http_request_plan_expires.go
Normal file
@@ -0,0 +1,19 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// 套餐过期
|
||||
func (this *HTTPRequest) doPlanExpires() {
|
||||
this.tags = append(this.tags, "plan")
|
||||
|
||||
var statusCode = http.StatusNotFound
|
||||
this.processResponseHeaders(statusCode)
|
||||
|
||||
this.writer.WriteHeader(statusCode)
|
||||
_, _ = this.writer.WriteString(serverconfigs.DefaultPlanExpireNoticePageBody)
|
||||
}
|
||||
@@ -39,8 +39,8 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
origin := this.reverseProxy.NextOrigin(requestCall)
|
||||
requestCall.CallResponseCallbacks(this.writer)
|
||||
if origin == nil {
|
||||
err := errors.New(this.requestPath() + ": no available backends for reverse proxy")
|
||||
remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", err.Error())
|
||||
err := errors.New(this.requestFullURL() + ": no available origin sites for reverse proxy")
|
||||
remotelogs.ServerError(this.Server.Id, "HTTP_REQUEST_REVERSE_PROXY", err.Error())
|
||||
this.write50x(err, http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
@@ -59,7 +59,7 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
|
||||
// 处理Scheme
|
||||
if origin.Addr == nil {
|
||||
err := errors.New(this.requestPath() + ": origin '" + strconv.FormatInt(origin.Id, 10) + "' does not has a address")
|
||||
err := errors.New(this.requestFullURL() + ": origin '" + strconv.FormatInt(origin.Id, 10) + "' does not has a address")
|
||||
remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", err.Error())
|
||||
this.write50x(err, http.StatusBadGateway)
|
||||
return
|
||||
@@ -281,7 +281,7 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
|
||||
closeErr := resp.Body.Close()
|
||||
if closeErr != nil {
|
||||
if !this.canIgnore(err) {
|
||||
if !this.canIgnore(closeErr) {
|
||||
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", closeErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
)
|
||||
|
||||
// 带宽限制
|
||||
func (this *HTTPRequest) doBandwidthLimit() {
|
||||
var config = this.Server.BandwidthLimit
|
||||
// 流量限制
|
||||
func (this *HTTPRequest) doTrafficLimit() {
|
||||
var config = this.Server.TrafficLimit
|
||||
|
||||
this.tags = append(this.tags, "bandwidth")
|
||||
|
||||
@@ -19,6 +19,6 @@ func (this *HTTPRequest) doBandwidthLimit() {
|
||||
if len(config.NoticePageBody) != 0 {
|
||||
_, _ = this.writer.WriteString(config.NoticePageBody)
|
||||
} else {
|
||||
_, _ = this.writer.WriteString(serverconfigs.DefaultBandwidthLimitNoticePageBody)
|
||||
_, _ = this.writer.WriteString(serverconfigs.DefaultTrafficLimitNoticePageBody)
|
||||
}
|
||||
}
|
||||
@@ -26,8 +26,17 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// 是否在全局名单中
|
||||
var remoteAddr = this.requestRemoteAddr(true)
|
||||
if !iplibrary.AllowIP(remoteAddr, this.Server.Id) {
|
||||
this.disableLog = true
|
||||
if conn != nil {
|
||||
_ = conn.(net.Conn).Close()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查是否在临时黑名单中
|
||||
var remoteAddr = this.WAFRemoteIP()
|
||||
if waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeService, this.Server.Id, remoteAddr) || waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteAddr) {
|
||||
this.disableLog = true
|
||||
if conn != nil {
|
||||
|
||||
@@ -552,11 +552,6 @@ func (this *HTTPWriter) prepareCache(size int64) {
|
||||
return
|
||||
}
|
||||
|
||||
// 不支持Range
|
||||
if len(this.Header().Get("Content-Range")) > 0 {
|
||||
return
|
||||
}
|
||||
|
||||
cachePolicy := this.req.Server.HTTPCachePolicy
|
||||
if cachePolicy == nil || !cachePolicy.IsOn {
|
||||
return
|
||||
@@ -567,17 +562,36 @@ func (this *HTTPWriter) prepareCache(size int64) {
|
||||
return
|
||||
}
|
||||
|
||||
var addStatusHeader = this.req.web != nil && this.req.web.Cache != nil && this.req.web.Cache.AddStatusHeader
|
||||
|
||||
// 不支持Range
|
||||
if len(this.Header().Get("Content-Range")) > 0 {
|
||||
if addStatusHeader {
|
||||
this.Header().Set("X-Cache", "BYPASS, not supported Content-Range")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 如果允许 ChunkedEncoding,就无需尺寸的判断,因为此时的 size 为 -1
|
||||
if !cacheRef.AllowChunkedEncoding && size < 0 {
|
||||
if addStatusHeader {
|
||||
this.Header().Set("X-Cache", "BYPASS, ChunkedEncoding")
|
||||
}
|
||||
return
|
||||
}
|
||||
if size >= 0 && ((cacheRef.MaxSizeBytes() > 0 && size > cacheRef.MaxSizeBytes()) ||
|
||||
(cachePolicy.MaxSizeBytes() > 0 && size > cachePolicy.MaxSizeBytes()) || (cacheRef.MinSizeBytes() > size)) {
|
||||
if addStatusHeader {
|
||||
this.Header().Set("X-Cache", "BYPASS, Content-Length")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 检查状态
|
||||
if len(cacheRef.Status) > 0 && !lists.ContainsInt(cacheRef.Status, this.StatusCode()) {
|
||||
if addStatusHeader {
|
||||
this.Header().Set("X-Cache", "BYPASS, Status: "+types.String(this.StatusCode()))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -588,6 +602,9 @@ func (this *HTTPWriter) prepareCache(size int64) {
|
||||
values := strings.Split(cacheControl, ",")
|
||||
for _, value := range values {
|
||||
if cacheRef.ContainsCacheControl(strings.TrimSpace(value)) {
|
||||
if addStatusHeader {
|
||||
this.Header().Set("X-Cache", "BYPASS, Cache-Control: "+cacheControl)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -596,19 +613,29 @@ func (this *HTTPWriter) prepareCache(size int64) {
|
||||
|
||||
// Set-Cookie
|
||||
if cacheRef.SkipResponseSetCookie && len(this.writer.Header().Get("Set-Cookie")) > 0 {
|
||||
if addStatusHeader {
|
||||
this.Header().Set("X-Cache", "BYPASS, Set-Cookie")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 校验其他条件
|
||||
if cacheRef.Conds != nil && cacheRef.Conds.HasResponseConds() && !cacheRef.Conds.MatchResponse(this.req.Format) {
|
||||
if addStatusHeader {
|
||||
this.Header().Set("X-Cache", "BYPASS, ResponseConds")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 打开缓存写入
|
||||
storage := caches.SharedManager.FindStorageWithPolicy(cachePolicy.Id)
|
||||
if storage == nil {
|
||||
if addStatusHeader {
|
||||
this.Header().Set("X-Cache", "BYPASS, Storage")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
this.cacheStorage = storage
|
||||
life := cacheRef.LifeSeconds()
|
||||
if life <= 60 { // 最小不能少于1分钟
|
||||
|
||||
@@ -5,17 +5,13 @@ import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"golang.org/x/net/http2"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type BaseListener struct {
|
||||
serversLocker sync.RWMutex
|
||||
namedServersLocker sync.RWMutex
|
||||
namedServers map[string]*NamedServer // 域名 => server
|
||||
|
||||
Group *serverconfigs.ServerAddressGroup
|
||||
|
||||
countActiveConnections int64 // 当前活跃的连接数
|
||||
@@ -23,14 +19,11 @@ type BaseListener struct {
|
||||
|
||||
// Init 初始化
|
||||
func (this *BaseListener) Init() {
|
||||
this.namedServers = map[string]*NamedServer{}
|
||||
}
|
||||
|
||||
// Reset 清除既有配置
|
||||
func (this *BaseListener) Reset() {
|
||||
this.namedServersLocker.Lock()
|
||||
this.namedServers = map[string]*NamedServer{}
|
||||
this.namedServersLocker.Unlock()
|
||||
|
||||
}
|
||||
|
||||
// CountActiveListeners 获取当前活跃连接数
|
||||
@@ -67,7 +60,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
|
||||
return nil, err
|
||||
}
|
||||
if cert == nil {
|
||||
return nil, errors.New("[proxy]no certs found for '" + info.ServerName + "'")
|
||||
return nil, errors.New("no ssl certs found for '" + info.ServerName + "'")
|
||||
}
|
||||
return cert, nil
|
||||
},
|
||||
@@ -83,7 +76,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
|
||||
return nil, err
|
||||
}
|
||||
if cert == nil {
|
||||
return nil, errors.New("[proxy]no certs found for '" + info.ServerName + "'")
|
||||
return nil, errors.New("no ssl certs found for '" + info.ServerName + "'")
|
||||
}
|
||||
return cert, nil
|
||||
},
|
||||
@@ -92,9 +85,6 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
|
||||
|
||||
// 根据域名匹配证书
|
||||
func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) {
|
||||
this.serversLocker.RLock()
|
||||
defer this.serversLocker.RUnlock()
|
||||
|
||||
group := this.Group
|
||||
|
||||
if group == nil {
|
||||
@@ -108,9 +98,9 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
|
||||
return nil, nil, errors.New("no tls server name matched")
|
||||
}
|
||||
|
||||
firstServer := group.FirstServer()
|
||||
firstServer := group.FirstTLSServer()
|
||||
if firstServer == nil {
|
||||
return nil, nil, errors.New("no server available")
|
||||
return nil, nil, errors.New("no tls server available")
|
||||
}
|
||||
sslConfig := firstServer.SSLPolicy()
|
||||
|
||||
@@ -119,14 +109,14 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
|
||||
|
||||
}
|
||||
return nil, nil, errors.New("no tls server name found")
|
||||
|
||||
}
|
||||
|
||||
// 通过代理服务域名配置匹配
|
||||
server, _ := this.findNamedServer(domain)
|
||||
if server == nil || server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
|
||||
// 搜索所有的Server,通过SSL证书内容中的DNSName匹配
|
||||
for _, server := range group.Servers {
|
||||
// 找不到或者此时的服务没有配置证书,需要搜索所有的Server,通过SSL证书内容中的DNSName匹配
|
||||
// TODO 需要思考这种情况下是否允许访问
|
||||
for _, server := range group.Servers() {
|
||||
if server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
|
||||
continue
|
||||
}
|
||||
@@ -136,7 +126,7 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, errors.New("[proxy]no server found for '" + domain + "'")
|
||||
return nil, nil, errors.New("no server found for '" + domain + "'")
|
||||
}
|
||||
|
||||
// 证书是否匹配
|
||||
@@ -146,6 +136,10 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
|
||||
return sslConfig, cert, nil
|
||||
}
|
||||
|
||||
if len(sslConfig.Certs) == 0 {
|
||||
remotelogs.ServerError(server.Id, "BASE_LISTENER", "no ssl certs found for '"+domain+"', server id: "+types.String(server.Id))
|
||||
}
|
||||
|
||||
return sslConfig, sslConfig.FirstCert(), nil
|
||||
}
|
||||
|
||||
@@ -173,11 +167,8 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf
|
||||
}
|
||||
|
||||
// 如果没有找到,则匹配到第一个
|
||||
this.serversLocker.RLock()
|
||||
defer this.serversLocker.RUnlock()
|
||||
|
||||
group := this.Group
|
||||
currentServers := group.Servers
|
||||
currentServers := group.Servers()
|
||||
countServers := len(currentServers)
|
||||
if countServers == 0 {
|
||||
return nil, ""
|
||||
@@ -192,71 +183,27 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
// 读取缓存
|
||||
this.namedServersLocker.RLock()
|
||||
namedServer, found := this.namedServers[name]
|
||||
if found {
|
||||
this.namedServersLocker.RUnlock()
|
||||
return namedServer.Server, namedServer.Name
|
||||
server := group.MatchServerName(name)
|
||||
if server != nil {
|
||||
return server, name
|
||||
}
|
||||
this.namedServersLocker.RUnlock()
|
||||
|
||||
this.serversLocker.RLock()
|
||||
defer this.serversLocker.RUnlock()
|
||||
|
||||
currentServers := group.Servers
|
||||
countServers := len(currentServers)
|
||||
if countServers == 0 {
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
// 只记录N个记录,防止内存耗尽
|
||||
maxNamedServers := 100_0000
|
||||
|
||||
// 是否严格匹配域名
|
||||
matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly
|
||||
|
||||
// 如果只有一个server,则默认为这个
|
||||
var currentServers = group.Servers()
|
||||
var countServers = len(currentServers)
|
||||
if countServers == 1 && !matchDomainStrictly {
|
||||
return currentServers[0], name
|
||||
}
|
||||
|
||||
// 精确查找
|
||||
for _, server := range currentServers {
|
||||
if server.MatchNameStrictly(name) {
|
||||
this.namedServersLocker.Lock()
|
||||
if len(this.namedServers) < maxNamedServers {
|
||||
this.namedServers[name] = &NamedServer{
|
||||
Name: name,
|
||||
Server: server,
|
||||
}
|
||||
}
|
||||
this.namedServersLocker.Unlock()
|
||||
return server, name
|
||||
}
|
||||
}
|
||||
|
||||
// 模糊查找
|
||||
for _, server := range currentServers {
|
||||
if matched := server.MatchName(name); matched {
|
||||
this.namedServersLocker.Lock()
|
||||
if len(this.namedServers) < maxNamedServers {
|
||||
this.namedServers[name] = &NamedServer{
|
||||
Name: name,
|
||||
Server: server,
|
||||
}
|
||||
}
|
||||
this.namedServersLocker.Unlock()
|
||||
return server, name
|
||||
}
|
||||
}
|
||||
|
||||
return nil, name
|
||||
}
|
||||
|
||||
// 使用CNAME来查找服务
|
||||
// TODO 防止单IP随机生成域名攻击
|
||||
func (this *BaseListener) findServerWithCname(domain string) *serverconfigs.ServerConfig {
|
||||
func (this *BaseListener) findServerWithCNAME(domain string) *serverconfigs.ServerConfig {
|
||||
if !sharedNodeConfig.SupportCNAME {
|
||||
return nil
|
||||
}
|
||||
@@ -266,25 +213,10 @@ func (this *BaseListener) findServerWithCname(domain string) *serverconfigs.Serv
|
||||
return nil
|
||||
}
|
||||
|
||||
this.serversLocker.Lock()
|
||||
defer this.serversLocker.Unlock()
|
||||
|
||||
group := this.Group
|
||||
if group == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
currentServers := group.Servers
|
||||
countServers := len(currentServers)
|
||||
if countServers == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, server := range currentServers {
|
||||
if server.SupportCNAME && lists.ContainsString(server.AliasServerNames, realName) {
|
||||
return server
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return group.MatchServerCNAME(realName)
|
||||
}
|
||||
|
||||
36
internal/nodes/listener_base_test.go
Normal file
36
internal/nodes/listener_base_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBaseListener_FindServer(t *testing.T) {
|
||||
sharedNodeConfig = &nodeconfigs.NodeConfig{}
|
||||
|
||||
var listener = &BaseListener{namedServers: map[string]*NamedServer{}}
|
||||
listener.Group = &serverconfigs.ServerAddressGroup{}
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
var server = &serverconfigs.ServerConfig{
|
||||
IsOn: true,
|
||||
Name: types.String(i) + ".hello.com",
|
||||
ServerNames: []*serverconfigs.ServerNameConfig{
|
||||
{Name: types.String(i) + ".hello.com"},
|
||||
},
|
||||
}
|
||||
_ = server.Init()
|
||||
listener.Group.Servers = append(listener.Group.Servers, server)
|
||||
}
|
||||
|
||||
var before = time.Now()
|
||||
defer func() {
|
||||
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||
}()
|
||||
|
||||
t.Log(listener.findNamedServerMatched("855555.hello.com"))
|
||||
}
|
||||
@@ -37,18 +37,13 @@ type HTTPListener struct {
|
||||
}
|
||||
|
||||
func (this *HTTPListener) Serve() error {
|
||||
handler := http.NewServeMux()
|
||||
handler.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {
|
||||
this.handleHTTP(writer, request)
|
||||
})
|
||||
|
||||
this.addr = this.Group.Addr()
|
||||
this.isHTTP = this.Group.IsHTTP()
|
||||
this.isHTTPS = this.Group.IsHTTPS()
|
||||
|
||||
this.httpServer = &http.Server{
|
||||
Addr: this.addr,
|
||||
Handler: handler,
|
||||
Handler: this,
|
||||
ReadHeaderTimeout: 2 * time.Second, // TODO 改成可以配置
|
||||
IdleTimeout: 2 * time.Minute, // TODO 改成可以配置
|
||||
ErrorLog: httpErrorLogger,
|
||||
@@ -118,8 +113,8 @@ func (this *HTTPListener) Reload(group *serverconfigs.ServerAddressGroup) {
|
||||
this.Reset()
|
||||
}
|
||||
|
||||
// 处理HTTP请求
|
||||
func (this *HTTPListener) handleHTTP(rawWriter http.ResponseWriter, rawReq *http.Request) {
|
||||
// ServerHTTP 处理HTTP请求
|
||||
func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.Request) {
|
||||
// 域名
|
||||
reqHost := rawReq.Host
|
||||
|
||||
@@ -157,7 +152,7 @@ func (this *HTTPListener) handleHTTP(rawWriter http.ResponseWriter, rawReq *http
|
||||
|
||||
server, serverName := this.findNamedServer(domain)
|
||||
if server == nil {
|
||||
server = this.findServerWithCname(domain)
|
||||
server = this.findServerWithCNAME(domain)
|
||||
if server == nil {
|
||||
// 严格匹配域名模式下,我们拒绝用户访问
|
||||
if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {
|
||||
|
||||
@@ -65,17 +65,17 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
var serverName = tlsConn.ConnectionState().ServerName
|
||||
if len(serverName) > 0 {
|
||||
// 统计
|
||||
stats.SharedTrafficStatManager.Add(firstServer.Id, serverName, 0, 0, 1, 0, 0, 0)
|
||||
stats.SharedTrafficStatManager.Add(firstServer.Id, serverName, 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId())
|
||||
recordStat = true
|
||||
}
|
||||
}
|
||||
|
||||
// 统计
|
||||
if !recordStat {
|
||||
stats.SharedTrafficStatManager.Add(firstServer.Id, "", 0, 0, 1, 0, 0, 0)
|
||||
stats.SharedTrafficStatManager.Add(firstServer.Id, "", 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId())
|
||||
}
|
||||
|
||||
originConn, err := this.connectOrigin(firstServer.ReverseProxy, conn.RemoteAddr().String())
|
||||
originConn, err := this.connectOrigin(firstServer.Id, firstServer.ReverseProxy, conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -125,7 +125,9 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
}
|
||||
|
||||
// 记录流量
|
||||
stats.SharedTrafficStatManager.Add(firstServer.Id, "", int64(n), 0, 0, 0, 0, 0)
|
||||
if firstServer != nil {
|
||||
stats.SharedTrafficStatManager.Add(firstServer.Id, "", int64(n), 0, 0, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId())
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
closer()
|
||||
@@ -162,7 +164,7 @@ func (this *TCPListener) Close() error {
|
||||
return this.Listener.Close()
|
||||
}
|
||||
|
||||
func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
|
||||
func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
|
||||
if reverseProxy == nil {
|
||||
return nil, errors.New("no reverse proxy config")
|
||||
}
|
||||
@@ -175,7 +177,7 @@ func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyC
|
||||
}
|
||||
conn, err = OriginConnect(origin, remoteAddr)
|
||||
if err != nil {
|
||||
remotelogs.Error("TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
|
||||
remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
|
||||
continue
|
||||
} else {
|
||||
return
|
||||
|
||||
@@ -55,7 +55,7 @@ func (this *UDPListener) Serve() error {
|
||||
ok = false
|
||||
}
|
||||
if !ok {
|
||||
originConn, err := this.connectOrigin(this.reverseProxy, addr)
|
||||
originConn, err := this.connectOrigin(firstServer.Id, this.reverseProxy, addr)
|
||||
if err != nil {
|
||||
remotelogs.Error("UDP_LISTENER", "unable to connect to origin server: "+err.Error())
|
||||
continue
|
||||
@@ -64,7 +64,7 @@ func (this *UDPListener) Serve() error {
|
||||
remotelogs.Error("UDP_LISTENER", "unable to find a origin server")
|
||||
continue
|
||||
}
|
||||
conn = NewUDPConn(firstServer.Id, addr, this.Listener, originConn.(*net.UDPConn))
|
||||
conn = NewUDPConn(firstServer, addr, this.Listener, originConn.(*net.UDPConn))
|
||||
this.connLocker.Lock()
|
||||
this.connMap[addr.String()] = conn
|
||||
this.connLocker.Unlock()
|
||||
@@ -101,7 +101,7 @@ func (this *UDPListener) Reload(group *serverconfigs.ServerAddressGroup) {
|
||||
this.reverseProxy = firstServer.ReverseProxy
|
||||
}
|
||||
|
||||
func (this *UDPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr net.Addr) (conn net.Conn, err error) {
|
||||
func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr net.Addr) (conn net.Conn, err error) {
|
||||
if reverseProxy == nil {
|
||||
return nil, errors.New("no reverse proxy config")
|
||||
}
|
||||
@@ -114,7 +114,7 @@ func (this *UDPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyC
|
||||
}
|
||||
conn, err = OriginConnect(origin, remoteAddr.String())
|
||||
if err != nil {
|
||||
remotelogs.Error("UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
|
||||
remotelogs.ServerError(serverId, "UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
|
||||
continue
|
||||
} else {
|
||||
// PROXY Protocol
|
||||
@@ -174,7 +174,7 @@ type UDPConn struct {
|
||||
isClosed bool
|
||||
}
|
||||
|
||||
func NewUDPConn(serverId int64, addr net.Addr, proxyConn *net.UDPConn, serverConn *net.UDPConn) *UDPConn {
|
||||
func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyConn *net.UDPConn, serverConn *net.UDPConn) *UDPConn {
|
||||
conn := &UDPConn{
|
||||
addr: addr,
|
||||
proxyConn: proxyConn,
|
||||
@@ -184,7 +184,9 @@ func NewUDPConn(serverId int64, addr net.Addr, proxyConn *net.UDPConn, serverCon
|
||||
}
|
||||
|
||||
// 统计
|
||||
stats.SharedTrafficStatManager.Add(serverId, "", 0, 0, 1, 0, 0, 0)
|
||||
if server != nil {
|
||||
stats.SharedTrafficStatManager.Add(server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
}
|
||||
|
||||
go func() {
|
||||
buffer := bytePool32k.Get()
|
||||
@@ -203,7 +205,9 @@ func NewUDPConn(serverId int64, addr net.Addr, proxyConn *net.UDPConn, serverCon
|
||||
}
|
||||
|
||||
// 记录流量
|
||||
stats.SharedTrafficStatManager.Add(serverId, "", int64(n), 0, 0, 0, 0, 0)
|
||||
if server != nil {
|
||||
stats.SharedTrafficStatManager.Add(server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
conn.isOk = false
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
package nodes
|
||||
|
||||
import "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
|
||||
// 域名和服务映射
|
||||
type NamedServer struct {
|
||||
Name string // 匹配后的域名
|
||||
Server *serverconfigs.ServerConfig // 匹配后的服务配置
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
@@ -15,7 +16,9 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/trackers"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/go-yaml/yaml"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
@@ -89,14 +92,14 @@ func (this *Node) Start() {
|
||||
}
|
||||
|
||||
// 读取API配置
|
||||
err = this.syncConfig()
|
||||
err = this.syncConfig(0)
|
||||
if err != nil {
|
||||
_, err := nodeconfigs.SharedNodeConfig()
|
||||
if err != nil {
|
||||
// 无本地数据时,会尝试多次读取
|
||||
tryTimes := 0
|
||||
for {
|
||||
err := this.syncConfig()
|
||||
err := this.syncConfig(0)
|
||||
if err != nil {
|
||||
tryTimes++
|
||||
|
||||
@@ -128,6 +131,7 @@ func (this *Node) Start() {
|
||||
remotelogs.Error("NODE", "start failed: read node config failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
teaconst.NodeId = nodeConfig.Id
|
||||
err = nodeConfig.Init()
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "init node config failed: "+err.Error())
|
||||
@@ -229,6 +233,9 @@ func (this *Node) InstallSystemService() error {
|
||||
|
||||
// 循环
|
||||
func (this *Node) loop() error {
|
||||
var tr = trackers.Begin("CHECK_NODE_CONFIG_CHANGES")
|
||||
defer tr.End()
|
||||
|
||||
// 检查api.yaml是否存在
|
||||
apiConfigFile := Tea.ConfigFile("api.yaml")
|
||||
_, err := os.Stat(apiConfigFile)
|
||||
@@ -261,7 +268,11 @@ func (this *Node) loop() error {
|
||||
return err
|
||||
}
|
||||
case "configChanged":
|
||||
err := this.syncConfig()
|
||||
if !task.IsPrimary {
|
||||
// 我们等等主节点配置准备完毕
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
err := this.syncConfig(task.Version)
|
||||
if err != nil {
|
||||
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
|
||||
NodeTaskId: task.Id,
|
||||
@@ -287,7 +298,7 @@ func (this *Node) loop() error {
|
||||
}
|
||||
|
||||
// 读取API配置
|
||||
func (this *Node) syncConfig() error {
|
||||
func (this *Node) syncConfig(taskVersion int64) error {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
@@ -318,7 +329,9 @@ func (this *Node) syncConfig() error {
|
||||
|
||||
// TODO 这里考虑只同步版本号有变更的
|
||||
configResp, err := rpcClient.NodeRPC().FindCurrentNodeConfig(nodeCtx, &pb.FindCurrentNodeConfigRequest{
|
||||
Version: -1, // 更新所有版本
|
||||
Version: -1, // 更新所有版本
|
||||
Compress: true,
|
||||
NodeTaskVersion: taskVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.New("read config from rpc failed: " + err.Error())
|
||||
@@ -326,14 +339,32 @@ func (this *Node) syncConfig() error {
|
||||
if !configResp.IsChanged {
|
||||
return nil
|
||||
}
|
||||
nodeConfigUpdatedAt = time.Now().Unix()
|
||||
|
||||
configJSON := configResp.NodeJSON
|
||||
if configResp.IsCompressed {
|
||||
var reader = brotli.NewReader(bytes.NewReader(configJSON))
|
||||
var configBuf = &bytes.Buffer{}
|
||||
var buf = make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := reader.Read(buf)
|
||||
if n > 0 {
|
||||
configBuf.Write(buf[:n])
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
configJSON = configBuf.Bytes()
|
||||
}
|
||||
|
||||
nodeConfigUpdatedAt = time.Now().Unix()
|
||||
|
||||
nodeConfig := &nodeconfigs.NodeConfig{}
|
||||
err = json.Unmarshal(configJSON, nodeConfig)
|
||||
if err != nil {
|
||||
return errors.New("decode config failed: " + err.Error())
|
||||
}
|
||||
teaconst.NodeId = nodeConfig.Id
|
||||
|
||||
// 写入到文件中
|
||||
err = nodeConfig.Save()
|
||||
@@ -411,7 +442,7 @@ func (this *Node) startSyncTimer() {
|
||||
continue
|
||||
}
|
||||
case <-nodeConfigChangedNotify:
|
||||
err := this.syncConfig()
|
||||
err := this.syncConfig(0)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "sync config error: "+err.Error())
|
||||
continue
|
||||
@@ -496,6 +527,16 @@ func (this *Node) listenSock() error {
|
||||
"pid": os.Getpid(),
|
||||
},
|
||||
})
|
||||
case "info":
|
||||
exePath, _ := os.Executable()
|
||||
_ = cmd.Reply(&gosock.Command{
|
||||
Code: "info",
|
||||
Params: map[string]interface{}{
|
||||
"pid": os.Getpid(),
|
||||
"version": teaconst.Version,
|
||||
"path": exePath,
|
||||
},
|
||||
})
|
||||
case "stop":
|
||||
_ = cmd.ReplyOk()
|
||||
|
||||
@@ -519,6 +560,12 @@ func (this *Node) listenSock() error {
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}()
|
||||
case "trackers":
|
||||
_ = cmd.Reply(&gosock.Command{
|
||||
Params: map[string]interface{}{
|
||||
"labels": trackers.SharedManager.Labels(),
|
||||
},
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/monitor"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/trackers"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
@@ -58,6 +59,9 @@ func (this *NodeStatusExecutor) update() {
|
||||
return
|
||||
}
|
||||
|
||||
var tr = trackers.Begin("UPLOAD_NODE_STATUS")
|
||||
defer tr.End()
|
||||
|
||||
status := &nodeconfigs.NodeStatus{}
|
||||
status.BuildVersion = teaconst.Version
|
||||
status.BuildVersionCode = utils.VersionToLong(teaconst.Version)
|
||||
@@ -68,8 +72,8 @@ func (this *NodeStatusExecutor) update() {
|
||||
status.ConnectionCount = sharedListenerManager.TotalActiveConnections()
|
||||
status.CacheTotalDiskSize = caches.SharedManager.TotalDiskSize()
|
||||
status.CacheTotalMemorySize = caches.SharedManager.TotalMemorySize()
|
||||
status.TrafficInBytes = inTrafficBytes
|
||||
status.TrafficOutBytes = outTrafficBytes
|
||||
status.TrafficInBytes = teaconst.InTrafficBytes
|
||||
status.TrafficOutBytes = teaconst.OutTrafficBytes
|
||||
|
||||
// 记录监控数据
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemConnections, maps.Map{
|
||||
@@ -79,11 +83,26 @@ func (this *NodeStatusExecutor) update() {
|
||||
hostname, _ := os.Hostname()
|
||||
status.Hostname = hostname
|
||||
|
||||
var cpuTR = tr.Begin("cpu")
|
||||
this.updateCPU(status)
|
||||
cpuTR.End()
|
||||
|
||||
var memTR = tr.Begin("memory")
|
||||
this.updateMem(status)
|
||||
memTR.End()
|
||||
|
||||
var loadTR = tr.Begin("load")
|
||||
this.updateLoad(status)
|
||||
loadTR.End()
|
||||
|
||||
var diskTR = tr.Begin("disk")
|
||||
this.updateDisk(status)
|
||||
diskTR.End()
|
||||
|
||||
var cacheSpaceTR = tr.Begin("cache space")
|
||||
this.updateCacheSpace(status)
|
||||
cacheSpaceTR.End()
|
||||
|
||||
status.UpdatedAt = time.Now().Unix()
|
||||
|
||||
// 发送数据
|
||||
@@ -101,7 +120,11 @@ func (this *NodeStatusExecutor) update() {
|
||||
StatusJSON: jsonData,
|
||||
})
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE_STATUS", "rpc UpdateNodeStatus() failed: "+err.Error())
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Warn("NODE_STATUS", "rpc UpdateNodeStatus() failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("NODE_STATUS", "rpc UpdateNodeStatus() failed: "+err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/trackers"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -60,6 +61,9 @@ func (this *OriginStateManager) Loop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var tr = trackers.Begin("CHECK_ORIGIN_STATES")
|
||||
defer tr.End()
|
||||
|
||||
var currentStates = []*OriginState{}
|
||||
this.locker.Lock()
|
||||
for originId, state := range this.stateMap {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/trackers"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"google.golang.org/grpc"
|
||||
@@ -53,6 +54,9 @@ func (this *SyncAPINodesTask) Start() {
|
||||
}
|
||||
|
||||
func (this *SyncAPINodesTask) Loop() error {
|
||||
var tr = trackers.Begin("SYNC_API_NODES")
|
||||
defer tr.End()
|
||||
|
||||
// 获取所有可用的节点
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
@@ -64,7 +68,7 @@ func (this *SyncAPINodesTask) Loop() error {
|
||||
}
|
||||
|
||||
newEndpoints := []string{}
|
||||
for _, node := range resp.Nodes {
|
||||
for _, node := range resp.ApiNodes {
|
||||
if !node.IsOn {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
|
||||
policy.Mode = firewallconfigs.FirewallModeDefend
|
||||
}
|
||||
w := &waf.WAF{
|
||||
Id: strconv.FormatInt(policy.Id, 10),
|
||||
Id: policy.Id,
|
||||
IsOn: policy.IsOn,
|
||||
Name: policy.Name,
|
||||
Mode: policy.Mode,
|
||||
@@ -71,7 +71,7 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
|
||||
if policy.Inbound != nil && policy.Inbound.IsOn {
|
||||
for _, group := range policy.Inbound.Groups {
|
||||
g := &waf.RuleGroup{
|
||||
Id: strconv.FormatInt(group.Id, 10),
|
||||
Id: group.Id,
|
||||
IsOn: group.IsOn,
|
||||
Name: group.Name,
|
||||
Description: group.Description,
|
||||
@@ -82,7 +82,7 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
|
||||
// rule sets
|
||||
for _, set := range group.Sets {
|
||||
s := &waf.RuleSet{
|
||||
Id: strconv.FormatInt(set.Id, 10),
|
||||
Id: set.Id,
|
||||
Code: set.Code,
|
||||
IsOn: set.IsOn,
|
||||
Name: set.Name,
|
||||
@@ -126,7 +126,7 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
|
||||
if policy.Outbound != nil && policy.Outbound.IsOn {
|
||||
for _, group := range policy.Outbound.Groups {
|
||||
g := &waf.RuleGroup{
|
||||
Id: strconv.FormatInt(group.Id, 10),
|
||||
Id: group.Id,
|
||||
IsOn: group.IsOn,
|
||||
Name: group.Name,
|
||||
Description: group.Description,
|
||||
@@ -137,7 +137,7 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
|
||||
// rule sets
|
||||
for _, set := range group.Sets {
|
||||
s := &waf.RuleSet{
|
||||
Id: strconv.FormatInt(set.Id, 10),
|
||||
Id: set.Id,
|
||||
Code: set.Code,
|
||||
IsOn: set.IsOn,
|
||||
Name: set.Name,
|
||||
|
||||
@@ -5,7 +5,10 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/trackers"
|
||||
"github.com/cespare/xxhash"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -16,7 +19,9 @@ func init() {
|
||||
ticker := time.NewTicker(60 * time.Second)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
var tr = trackers.Begin("UPLOAD_REMOTE_LOGS")
|
||||
err := uploadLogs()
|
||||
tr.End()
|
||||
if err != nil {
|
||||
logs.Println("[LOG]" + err.Error())
|
||||
}
|
||||
@@ -93,6 +98,18 @@ func Error(tag string, description string) {
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorObject 打印错误对象
|
||||
func ErrorObject(tag string, err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if rpc.IsConnError(err) {
|
||||
Warn(tag, err.Error())
|
||||
} else {
|
||||
Error(tag, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// ServerError 打印服务相关错误信息
|
||||
func ServerError(serverId int64, tag string, description string) {
|
||||
logs.Println("[" + tag + "]" + description)
|
||||
@@ -144,11 +161,33 @@ func ServerSuccess(serverId int64, tag string, description string) {
|
||||
// 上传日志
|
||||
func uploadLogs() error {
|
||||
logList := []*pb.NodeLog{}
|
||||
|
||||
const hashSize = 5
|
||||
var hashList = []uint64{}
|
||||
|
||||
Loop:
|
||||
for {
|
||||
select {
|
||||
case log := <-logChan:
|
||||
logList = append(logList, log)
|
||||
// 是否已存在
|
||||
var hash = xxhash.Sum64String(types.String(log.ServerId) + "_" + log.Description)
|
||||
var found = false
|
||||
for _, h := range hashList {
|
||||
if h == hash {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 加入
|
||||
if !found {
|
||||
hashList = append(hashList, hash)
|
||||
if len(hashList) > hashSize {
|
||||
hashList = hashList[1:]
|
||||
}
|
||||
|
||||
logList = append(logList, log)
|
||||
}
|
||||
default:
|
||||
break Loop
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ type RPCClient struct {
|
||||
apiConfig *configs.APIConfig
|
||||
conns []*grpc.ClientConn
|
||||
|
||||
locker sync.Mutex
|
||||
locker sync.RWMutex
|
||||
}
|
||||
|
||||
func NewRPCClient(apiConfig *configs.APIConfig) (*RPCClient, error) {
|
||||
@@ -177,15 +177,23 @@ func (this *RPCClient) ClusterContext(clusterId string, clusterSecret string) co
|
||||
|
||||
// Close 关闭连接
|
||||
func (this *RPCClient) Close() {
|
||||
this.locker.Lock()
|
||||
|
||||
for _, conn := range this.conns {
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
// UpdateConfig 修改配置
|
||||
func (this *RPCClient) UpdateConfig(config *configs.APIConfig) error {
|
||||
this.apiConfig = config
|
||||
return this.init()
|
||||
|
||||
this.locker.Lock()
|
||||
err := this.init()
|
||||
this.locker.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// 初始化
|
||||
|
||||
@@ -2,12 +2,15 @@ package rpc
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var sharedRPC *RPCClient = nil
|
||||
var locker = &sync.Mutex{}
|
||||
|
||||
// SharedRPC RPC对象
|
||||
func SharedRPC() (*RPCClient, error) {
|
||||
locker.Lock()
|
||||
defer locker.Unlock()
|
||||
@@ -28,3 +31,18 @@ func SharedRPC() (*RPCClient, error) {
|
||||
sharedRPC = client
|
||||
return sharedRPC, nil
|
||||
}
|
||||
|
||||
// IsConnError 是否为连接错误
|
||||
func IsConnError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否为连接错误
|
||||
statusErr, ok := status.FromError(err)
|
||||
if ok {
|
||||
return statusErr.Code() == codes.Unavailable
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/monitor"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/trackers"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
@@ -81,13 +82,23 @@ func (this *HTTPRequestStatManager) Start() {
|
||||
for range loopTicker.C {
|
||||
err := this.Loop()
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST_STAT_MANAGER", err.Error())
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Warn("HTTP_REQUEST_STAT_MANAGER", err.Error())
|
||||
} else {
|
||||
remotelogs.Error("HTTP_REQUEST_STAT_MANAGER", err.Error())
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-uploadTicker.C:
|
||||
var tr = trackers.Begin("UPLOAD_REQUEST_STATS")
|
||||
err := this.Upload()
|
||||
tr.End()
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST_STAT_MANAGER", "upload failed: "+err.Error())
|
||||
if !rpc.IsConnError(err) {
|
||||
remotelogs.Error("HTTP_REQUEST_STAT_MANAGER", "upload failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Warn("HTTP_REQUEST_STAT_MANAGER", "upload failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
default:
|
||||
|
||||
@@ -166,10 +177,10 @@ Loop:
|
||||
if iplibrary.SharedLibrary != nil {
|
||||
result, err := iplibrary.SharedLibrary.Lookup(ip)
|
||||
if err == nil && result != nil {
|
||||
this.cityMap[serverId+"@"+result.Country+"@"+result.Province+"@"+result.City] ++
|
||||
this.cityMap[serverId+"@"+result.Country+"@"+result.Province+"@"+result.City]++
|
||||
|
||||
if len(result.ISP) > 0 {
|
||||
this.providerMap[serverId+"@"+result.ISP] ++
|
||||
this.providerMap[serverId+"@"+result.ISP]++
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -197,7 +208,7 @@ Loop:
|
||||
if dotIndex > -1 {
|
||||
browserVersion = browserVersion[:dotIndex]
|
||||
}
|
||||
this.browserMap[serverId+"@"+browser+"@"+browserVersion] ++
|
||||
this.browserMap[serverId+"@"+browser+"@"+browserVersion]++
|
||||
}
|
||||
case firewallRuleGroupString := <-this.firewallRuleGroupChan:
|
||||
this.dailyFirewallRuleGroupMap[firewallRuleGroupString]++
|
||||
|
||||
@@ -20,12 +20,14 @@ import (
|
||||
var SharedTrafficStatManager = NewTrafficStatManager()
|
||||
|
||||
type TrafficItem struct {
|
||||
Bytes int64
|
||||
CachedBytes int64
|
||||
CountRequests int64
|
||||
CountCachedRequests int64
|
||||
CountAttackRequests int64
|
||||
AttackBytes int64
|
||||
Bytes int64
|
||||
CachedBytes int64
|
||||
CountRequests int64
|
||||
CountCachedRequests int64
|
||||
CountAttackRequests int64
|
||||
AttackBytes int64
|
||||
PlanId int64
|
||||
CheckingTrafficLimit bool
|
||||
}
|
||||
|
||||
// TrafficStatManager 区域流量统计
|
||||
@@ -80,13 +82,17 @@ func (this *TrafficStatManager) Start(configFunc func() *nodeconfigs.NodeConfig)
|
||||
for range ticker.C {
|
||||
err := this.Upload()
|
||||
if err != nil {
|
||||
remotelogs.Error("TRAFFIC_STAT_MANAGER", "upload stats failed: "+err.Error())
|
||||
if !rpc.IsConnError(err) {
|
||||
remotelogs.Error("TRAFFIC_STAT_MANAGER", "upload stats failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Warn("TRAFFIC_STAT_MANAGER", "upload stats failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add 添加流量
|
||||
func (this *TrafficStatManager) Add(serverId int64, domain string, bytes int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64) {
|
||||
func (this *TrafficStatManager) Add(serverId int64, domain string, bytes int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64, checkingTrafficLimit bool, planId int64) {
|
||||
if bytes == 0 && countRequests == 0 {
|
||||
return
|
||||
}
|
||||
@@ -110,6 +116,8 @@ func (this *TrafficStatManager) Add(serverId int64, domain string, bytes int64,
|
||||
item.CountCachedRequests += countCachedRequests
|
||||
item.CountAttackRequests += countAttacks
|
||||
item.AttackBytes += attackBytes
|
||||
item.CheckingTrafficLimit = checkingTrafficLimit
|
||||
item.PlanId = planId
|
||||
|
||||
// 单个域名流量
|
||||
var domainKey = strconv.FormatInt(timestamp, 10) + "@" + strconv.FormatInt(serverId, 10) + "@" + domain
|
||||
@@ -160,15 +168,17 @@ func (this *TrafficStatManager) Upload() error {
|
||||
}
|
||||
|
||||
pbServerStats = append(pbServerStats, &pb.ServerDailyStat{
|
||||
ServerId: serverId,
|
||||
RegionId: config.RegionId,
|
||||
Bytes: item.Bytes,
|
||||
CachedBytes: item.CachedBytes,
|
||||
CountRequests: item.CountRequests,
|
||||
CountCachedRequests: item.CountCachedRequests,
|
||||
CountAttackRequests: item.CountAttackRequests,
|
||||
AttackBytes: item.AttackBytes,
|
||||
CreatedAt: timestamp,
|
||||
ServerId: serverId,
|
||||
RegionId: config.RegionId,
|
||||
Bytes: item.Bytes,
|
||||
CachedBytes: item.CachedBytes,
|
||||
CountRequests: item.CountRequests,
|
||||
CountCachedRequests: item.CountCachedRequests,
|
||||
CountAttackRequests: item.CountAttackRequests,
|
||||
AttackBytes: item.AttackBytes,
|
||||
CheckTrafficLimiting: item.CheckingTrafficLimit,
|
||||
PlanId: item.PlanId,
|
||||
CreatedAt: timestamp,
|
||||
})
|
||||
}
|
||||
if len(pbServerStats) == 0 {
|
||||
|
||||
28
internal/trackers/label.go
Normal file
28
internal/trackers/label.go
Normal file
@@ -0,0 +1,28 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package trackers
|
||||
|
||||
import "time"
|
||||
|
||||
type tracker struct {
|
||||
label string
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
func Begin(label string) *tracker {
|
||||
return &tracker{label: label, startTime: time.Now()}
|
||||
}
|
||||
|
||||
func Run(label string, f func()) {
|
||||
var tr = Begin(label)
|
||||
f()
|
||||
tr.End()
|
||||
}
|
||||
|
||||
func (this *tracker) End() {
|
||||
SharedManager.Add(this.label, time.Since(this.startTime).Seconds()*1000)
|
||||
}
|
||||
|
||||
func (this *tracker) Begin(subLabel string) *tracker {
|
||||
return Begin(this.label + ":" + subLabel)
|
||||
}
|
||||
47
internal/trackers/manager.go
Normal file
47
internal/trackers/manager.go
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package trackers
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
var SharedManager = NewManager()
|
||||
|
||||
type Manager struct {
|
||||
m map[string][]float64 // label => time costs ms
|
||||
locker sync.Mutex
|
||||
}
|
||||
|
||||
func NewManager() *Manager {
|
||||
return &Manager{m: map[string][]float64{}}
|
||||
}
|
||||
|
||||
func (this *Manager) Add(label string, costMs float64) {
|
||||
this.locker.Lock()
|
||||
costs, ok := this.m[label]
|
||||
if ok {
|
||||
costs = append(costs, costMs)
|
||||
if len(costs) > 5 { // 只取最近的N条
|
||||
costs = costs[1:]
|
||||
}
|
||||
this.m[label] = costs
|
||||
} else {
|
||||
this.m[label] = []float64{costMs}
|
||||
}
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
func (this *Manager) Labels() map[string]float64 {
|
||||
var result = map[string]float64{}
|
||||
this.locker.Lock()
|
||||
for label, costs := range this.m {
|
||||
var sum float64
|
||||
for _, cost := range costs {
|
||||
sum += cost
|
||||
}
|
||||
result[label] = sum / float64(len(costs))
|
||||
}
|
||||
this.locker.Unlock()
|
||||
return result
|
||||
}
|
||||
47
internal/trackers/manager_test.go
Normal file
47
internal/trackers/manager_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package trackers
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
{
|
||||
var tr = Begin("a")
|
||||
tr.End()
|
||||
}
|
||||
{
|
||||
var tr = Begin("a")
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
tr.End()
|
||||
}
|
||||
{
|
||||
var tr = Begin("a")
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
tr.End()
|
||||
}
|
||||
{
|
||||
var tr = Begin("a")
|
||||
time.Sleep(3 * time.Millisecond)
|
||||
tr.End()
|
||||
}
|
||||
{
|
||||
var tr = Begin("a")
|
||||
time.Sleep(4 * time.Millisecond)
|
||||
tr.End()
|
||||
}
|
||||
{
|
||||
var tr = Begin("a")
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
tr.End()
|
||||
}
|
||||
{
|
||||
var tr = Begin("b")
|
||||
tr.End()
|
||||
}
|
||||
|
||||
logs.PrintAsJSON(SharedManager.Labels(), t)
|
||||
}
|
||||
172
internal/utils/free_hours_manager.go
Normal file
172
internal/utils/free_hours_manager.go
Normal file
@@ -0,0 +1,172 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var SharedFreeHoursManager = NewFreeHoursManager()
|
||||
|
||||
func init() {
|
||||
events.On(events.EventLoaded, func() {
|
||||
go SharedFreeHoursManager.Start()
|
||||
})
|
||||
}
|
||||
|
||||
// FreeHoursManager 计算节点空闲时间
|
||||
// 以便于我们在空闲时间执行高强度的任务,如果清理缓存等
|
||||
type FreeHoursManager struct {
|
||||
dayTrafficMap map[int][24]uint64 // day => [ traffic bytes ]
|
||||
lastBytes uint64
|
||||
|
||||
freeHours []int
|
||||
count int
|
||||
|
||||
locker sync.Mutex
|
||||
}
|
||||
|
||||
func NewFreeHoursManager() *FreeHoursManager {
|
||||
return &FreeHoursManager{dayTrafficMap: map[int][24]uint64{}, count: 3}
|
||||
}
|
||||
|
||||
func (this *FreeHoursManager) Start() {
|
||||
var ticker = time.NewTicker(30 * time.Minute)
|
||||
for range ticker.C {
|
||||
this.Update(atomic.LoadUint64(&teaconst.InTrafficBytes))
|
||||
}
|
||||
}
|
||||
|
||||
func (this *FreeHoursManager) Update(bytes uint64) {
|
||||
if this.count <= 0 {
|
||||
this.count = 3
|
||||
}
|
||||
|
||||
if this.lastBytes == 0 {
|
||||
this.lastBytes = bytes
|
||||
} else {
|
||||
// 记录流量
|
||||
var deltaBytes = bytes - this.lastBytes
|
||||
var now = time.Now()
|
||||
var day = now.Day()
|
||||
var hour = now.Hour()
|
||||
traffic, ok := this.dayTrafficMap[day]
|
||||
if ok {
|
||||
traffic[hour] += deltaBytes
|
||||
} else {
|
||||
var traffic = [24]uint64{}
|
||||
traffic[hour] += deltaBytes
|
||||
this.dayTrafficMap[day] = traffic
|
||||
}
|
||||
|
||||
this.lastBytes = bytes
|
||||
|
||||
// 计算空闲时间
|
||||
var result = [24]uint64{}
|
||||
var hasData = false
|
||||
for trafficDay, trafficArray := range this.dayTrafficMap {
|
||||
// 当天的不算
|
||||
if trafficDay == day {
|
||||
continue
|
||||
}
|
||||
|
||||
// 查看最近5天的
|
||||
if (day > trafficDay && day-trafficDay <= 5) || (day < trafficDay && trafficDay-day >= 26) {
|
||||
var weights = this.sortUintArrayWeights(trafficArray)
|
||||
for k, v := range weights {
|
||||
result[k] += v
|
||||
}
|
||||
hasData = true
|
||||
}
|
||||
}
|
||||
if hasData {
|
||||
var freeHours = this.sortUintArrayIndexes(result)
|
||||
this.locker.Lock()
|
||||
this.freeHours = freeHours[:this.count] // 取前N个小时作为空闲时间
|
||||
this.locker.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *FreeHoursManager) IsFreeHour() bool {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
if len(this.freeHours) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
var hour = time.Now().Hour()
|
||||
for _, h := range this.freeHours {
|
||||
if h == hour {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 对数组进行排序,并返回权重
|
||||
func (this *FreeHoursManager) sortUintArrayWeights(arr [24]uint64) [24]uint64 {
|
||||
var l = []map[string]interface{}{}
|
||||
for k, v := range arr {
|
||||
l = append(l, map[string]interface{}{
|
||||
"k": k,
|
||||
"v": v,
|
||||
})
|
||||
}
|
||||
sort.Slice(l, func(i, j int) bool {
|
||||
var m1 = l[i]
|
||||
var v1 = m1["v"].(uint64)
|
||||
|
||||
var m2 = l[j]
|
||||
var v2 = m2["v"].(uint64)
|
||||
|
||||
return v1 < v2
|
||||
})
|
||||
|
||||
var result = [24]uint64{}
|
||||
for k, v := range l {
|
||||
if k < this.count {
|
||||
k = 0
|
||||
} else {
|
||||
k = 1
|
||||
}
|
||||
result[v["k"].(int)] = v["v"].(uint64)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// 对数组进行排序,并返回索引
|
||||
func (this *FreeHoursManager) sortUintArrayIndexes(arr [24]uint64) [24]int {
|
||||
var l = []map[string]interface{}{}
|
||||
for k, v := range arr {
|
||||
l = append(l, map[string]interface{}{
|
||||
"k": k,
|
||||
"v": v,
|
||||
})
|
||||
}
|
||||
sort.Slice(l, func(i, j int) bool {
|
||||
var m1 = l[i]
|
||||
var v1 = m1["v"].(uint64)
|
||||
|
||||
var m2 = l[j]
|
||||
var v2 = m2["v"].(uint64)
|
||||
|
||||
return v1 < v2
|
||||
})
|
||||
|
||||
var result = [24]int{}
|
||||
var i = 0
|
||||
for _, v := range l {
|
||||
result[i] = v["k"].(int)
|
||||
i++
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
49
internal/utils/free_hours_manager_test.go
Normal file
49
internal/utils/free_hours_manager_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestFreeHoursManager_Update(t *testing.T) {
|
||||
var manager = NewFreeHoursManager()
|
||||
manager.Update(111)
|
||||
|
||||
manager.dayTrafficMap[1] = [24]uint64{1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1}
|
||||
manager.dayTrafficMap[2] = [24]uint64{0, 0, 1, 0, 1, 1, 1, 0, 0}
|
||||
manager.dayTrafficMap[3] = [24]uint64{0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1}
|
||||
manager.dayTrafficMap[4] = [24]uint64{0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1}
|
||||
manager.dayTrafficMap[5] = [24]uint64{0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1}
|
||||
manager.dayTrafficMap[6] = [24]uint64{0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1}
|
||||
manager.dayTrafficMap[7] = [24]uint64{}
|
||||
manager.dayTrafficMap[8] = [24]uint64{}
|
||||
manager.dayTrafficMap[9] = [24]uint64{}
|
||||
manager.dayTrafficMap[10] = [24]uint64{1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
|
||||
manager.dayTrafficMap[11] = [24]uint64{1}
|
||||
manager.dayTrafficMap[12] = [24]uint64{1}
|
||||
manager.dayTrafficMap[13] = [24]uint64{1}
|
||||
manager.dayTrafficMap[14] = [24]uint64{}
|
||||
manager.dayTrafficMap[15] = [24]uint64{}
|
||||
manager.dayTrafficMap[16] = [24]uint64{}
|
||||
manager.dayTrafficMap[25] = [24]uint64{}
|
||||
manager.dayTrafficMap[26] = [24]uint64{}
|
||||
manager.dayTrafficMap[27] = [24]uint64{}
|
||||
manager.dayTrafficMap[28] = [24]uint64{}
|
||||
manager.dayTrafficMap[29] = [24]uint64{}
|
||||
manager.dayTrafficMap[30] = [24]uint64{}
|
||||
manager.dayTrafficMap[31] = [24]uint64{}
|
||||
|
||||
var before = time.Now()
|
||||
manager.Update(222)
|
||||
t.Log(manager.freeHours)
|
||||
t.Log(manager.IsFreeHour())
|
||||
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||
}
|
||||
|
||||
func TestFreeHoursManager_SortArray(t *testing.T) {
|
||||
var manager = NewFreeHoursManager()
|
||||
t.Log(manager.sortUintArrayWeights([24]uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 109, 10, 11, 12, 130, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}))
|
||||
t.Log(manager.sortUintArrayIndexes([24]uint64{1, 2, 3, 5, 4, 0, 100}))
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
package utils
|
||||
|
||||
// 清理Path中的多余的字符
|
||||
// CleanPath 清理Path中的多余的字符
|
||||
func CleanPath(path string) string {
|
||||
l := len(path)
|
||||
if l == 0 {
|
||||
@@ -9,6 +9,10 @@ func CleanPath(path string) string {
|
||||
result := []byte{'/'}
|
||||
isSlash := true
|
||||
for i := 0; i < l; i++ {
|
||||
if path[i] == '?' {
|
||||
result = append(result, path[i:]...)
|
||||
break
|
||||
}
|
||||
if path[i] == '\\' || path[i] == '/' {
|
||||
if !isSlash {
|
||||
isSlash = true
|
||||
@@ -21,4 +25,3 @@ func CleanPath(path string) string {
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,11 @@ func TestCleanPath(t *testing.T) {
|
||||
a.IsTrue(CleanPath("/hello////world") == "/hello/world")
|
||||
}
|
||||
|
||||
func TestCleanPath_Args(t *testing.T) {
|
||||
a := assert.NewAssertion(t)
|
||||
a.IsTrue(CleanPath("/hello/world?base=///////") == "/hello/world?base=///////")
|
||||
}
|
||||
|
||||
func BenchmarkCleanPath(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = CleanPath("/hello///world/very/long/very//long")
|
||||
|
||||
@@ -6,6 +6,20 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRawTicker(t *testing.T) {
|
||||
var ticker = time.NewTicker(2 * time.Second)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
t.Log("tick")
|
||||
}
|
||||
t.Log("stop")
|
||||
}()
|
||||
|
||||
time.Sleep(6 * time.Second)
|
||||
ticker.Stop()
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
func TestTicker(t *testing.T) {
|
||||
ticker := NewTicker(3 * time.Second)
|
||||
go func() {
|
||||
|
||||
@@ -63,7 +63,9 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
|
||||
if timeout <= 0 {
|
||||
timeout = 60 // 默认封锁60秒
|
||||
}
|
||||
SharedIPBlackList.Add(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(timeout))
|
||||
|
||||
|
||||
SharedIPBlackList.RecordIP(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(timeout), waf.Id, group.Id, set.Id)
|
||||
|
||||
if writer != nil {
|
||||
// close the connection
|
||||
|
||||
@@ -5,15 +5,13 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var captchaSalt = stringutil.Rand(32)
|
||||
|
||||
const (
|
||||
CaptchaSeconds = 600 // 10 minutes
|
||||
CaptchaPath = "/WAF/VERIFY/CAPTCHA"
|
||||
@@ -44,7 +42,7 @@ func (this *CaptchaAction) WillChange() bool {
|
||||
|
||||
func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
// 是否在白名单中
|
||||
if SharedIPWhiteList.Contains("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
|
||||
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -65,6 +63,8 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
|
||||
"action": this,
|
||||
"timestamp": time.Now().Unix(),
|
||||
"url": refURL,
|
||||
"policyId": waf.Id,
|
||||
"groupId": group.Id,
|
||||
"setId": set.Id,
|
||||
}
|
||||
info, err := utils.SimpleEncryptMap(captchaConfig)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
@@ -47,7 +48,7 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
|
||||
}
|
||||
|
||||
// 是否已经在白名单中
|
||||
if SharedIPWhiteList.Contains("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
|
||||
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -56,6 +57,8 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
|
||||
"timestamp": time.Now().Unix(),
|
||||
"life": this.Life,
|
||||
"scope": this.Scope,
|
||||
"policyId": waf.Id,
|
||||
"groupId": group.Id,
|
||||
"setId": set.Id,
|
||||
}
|
||||
info, err := utils.SimpleEncryptMap(m)
|
||||
|
||||
@@ -3,6 +3,7 @@ package waf
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -27,7 +28,7 @@ func (this *GoGroupAction) WillChange() bool {
|
||||
}
|
||||
|
||||
func (this *GoGroupAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
nextGroup := waf.FindRuleGroup(this.GroupId)
|
||||
nextGroup := waf.FindRuleGroup(types.Int64(this.GroupId))
|
||||
if nextGroup == nil || !nextGroup.IsOn {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package waf
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -28,11 +29,11 @@ func (this *GoSetAction) WillChange() bool {
|
||||
}
|
||||
|
||||
func (this *GoSetAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
nextGroup := waf.FindRuleGroup(this.GroupId)
|
||||
nextGroup := waf.FindRuleGroup(types.Int64(this.GroupId))
|
||||
if nextGroup == nil || !nextGroup.IsOn {
|
||||
return true
|
||||
}
|
||||
nextSet := nextGroup.FindRuleSet(this.SetId)
|
||||
nextSet := nextGroup.FindRuleSet(types.Int64(this.SetId))
|
||||
if nextSet == nil || !nextSet.IsOn {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
@@ -41,7 +42,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
|
||||
}
|
||||
|
||||
// 是否已经在白名单中
|
||||
if SharedIPWhiteList.Contains("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
|
||||
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -55,7 +56,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
|
||||
life = 600 // 默认10分钟
|
||||
}
|
||||
var setId = m.GetString("setId")
|
||||
SharedIPWhiteList.Add("set:"+setId, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life)
|
||||
SharedIPWhiteList.RecordIP("set:"+setId, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life, m.GetInt64("policyId"), m.GetInt64("groupId"), m.GetInt64("setId"))
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -64,6 +65,8 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
|
||||
"timestamp": time.Now().Unix(),
|
||||
"life": this.Life,
|
||||
"scope": this.Scope,
|
||||
"policyId": waf.Id,
|
||||
"groupId": group.Id,
|
||||
"setId": set.Id,
|
||||
"remoteIP": request.WAFRemoteIP(),
|
||||
}
|
||||
|
||||
@@ -2,10 +2,13 @@ package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -16,6 +19,12 @@ type recordIPTask struct {
|
||||
listId int64
|
||||
expiredAt int64
|
||||
level string
|
||||
serverId int64
|
||||
|
||||
sourceServerId int64
|
||||
sourceHTTPFirewallPolicyId int64
|
||||
sourceHTTPFirewallRuleGroupId int64
|
||||
sourceHTTPFirewallRuleSetId int64
|
||||
}
|
||||
|
||||
var recordIPTaskChan = make(chan *recordIPTask, 1024)
|
||||
@@ -35,13 +44,19 @@ func init() {
|
||||
ipType = "ipv6"
|
||||
}
|
||||
_, err = rpcClient.IPItemRPC().CreateIPItem(rpcClient.Context(), &pb.CreateIPItemRequest{
|
||||
IpListId: task.listId,
|
||||
IpFrom: task.ip,
|
||||
IpTo: "",
|
||||
ExpiredAt: task.expiredAt,
|
||||
Reason: "触发WAF规则自动加入",
|
||||
Type: ipType,
|
||||
EventLevel: task.level,
|
||||
IpListId: task.listId,
|
||||
IpFrom: task.ip,
|
||||
IpTo: "",
|
||||
ExpiredAt: task.expiredAt,
|
||||
Reason: "触发WAF规则自动加入",
|
||||
Type: ipType,
|
||||
EventLevel: task.level,
|
||||
ServerId: task.serverId,
|
||||
SourceNodeId: teaconst.NodeId,
|
||||
SourceServerId: task.sourceServerId,
|
||||
SourceHTTPFirewallPolicyId: task.sourceHTTPFirewallPolicyId,
|
||||
SourceHTTPFirewallRuleGroupId: task.sourceHTTPFirewallRuleGroupId,
|
||||
SourceHTTPFirewallRuleSetId: task.sourceHTTPFirewallRuleSetId,
|
||||
})
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error())
|
||||
@@ -79,7 +94,7 @@ func (this *RecordIPAction) WillChange() bool {
|
||||
|
||||
func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
// 是否在本地白名单中
|
||||
if SharedIPWhiteList.Contains("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
|
||||
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -98,17 +113,27 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
|
||||
SharedIPBlackList.Add(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), expiredAt)
|
||||
} else {
|
||||
// 加入本地白名单
|
||||
SharedIPWhiteList.Add("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), expiredAt)
|
||||
SharedIPWhiteList.Add("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP(), expiredAt)
|
||||
}
|
||||
|
||||
// 上报
|
||||
if this.IPListId > 0 {
|
||||
var serverId int64
|
||||
if this.Scope == firewallconfigs.FirewallScopeService {
|
||||
serverId = request.WAFServerId()
|
||||
}
|
||||
|
||||
select {
|
||||
case recordIPTaskChan <- &recordIPTask{
|
||||
ip: request.WAFRemoteIP(),
|
||||
listId: this.IPListId,
|
||||
expiredAt: expiredAt,
|
||||
level: this.Level,
|
||||
ip: request.WAFRemoteIP(),
|
||||
listId: this.IPListId,
|
||||
expiredAt: expiredAt,
|
||||
level: this.Level,
|
||||
serverId: serverId,
|
||||
sourceServerId: request.WAFServerId(),
|
||||
sourceHTTPFirewallPolicyId: waf.Id,
|
||||
sourceHTTPFirewallRuleGroupId: group.Id,
|
||||
sourceHTTPFirewallRuleSetId: set.Id,
|
||||
}:
|
||||
default:
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ func (this *CaptchaValidator) Run(request requests.Request, writer http.Response
|
||||
var originURL = m.GetString("url")
|
||||
|
||||
if request.WAFRaw().Method == http.MethodPost && len(request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")) > 0 {
|
||||
this.validate(actionConfig, setId, originURL, request, writer)
|
||||
this.validate(actionConfig, m.GetInt64("policyId"), m.GetInt64("groupId"), setId, originURL, request, writer)
|
||||
} else {
|
||||
this.show(actionConfig, request, writer)
|
||||
}
|
||||
@@ -132,7 +132,7 @@ func (this *CaptchaValidator) show(actionConfig *CaptchaAction, request requests
|
||||
</html>`))
|
||||
}
|
||||
|
||||
func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, setId int64, originURL string, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, policyId int64, groupId int64, setId int64, originURL string, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
captchaId := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")
|
||||
if len(captchaId) > 0 {
|
||||
captchaCode := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_CODE")
|
||||
@@ -143,7 +143,7 @@ func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, setId int64,
|
||||
}
|
||||
|
||||
// 加入到白名单
|
||||
SharedIPWhiteList.Add("set:"+strconv.FormatInt(setId, 10), actionConfig.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(life))
|
||||
SharedIPWhiteList.RecordIP("set:"+strconv.FormatInt(setId, 10), actionConfig.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(life), policyId, groupId, setId)
|
||||
|
||||
http.Redirect(writer, request.WAFRaw(), originURL, http.StatusSeeOther)
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ func (this *Get302Validator) Run(request requests.Request, writer http.ResponseW
|
||||
life = 600 // 默认10分钟
|
||||
}
|
||||
setId := m.GetString("setId")
|
||||
SharedIPWhiteList.Add("set:"+setId, m.GetString("scope"), request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life)
|
||||
SharedIPWhiteList.RecordIP("set:"+setId, m.GetString("scope"), request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life, m.GetInt64("policyId"), m.GetInt64("groupId"), m.GetInt64("setId"))
|
||||
|
||||
// 返回原始URL
|
||||
var url = m.GetString("url")
|
||||
|
||||
@@ -10,8 +10,15 @@ import (
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var SharedIPWhiteList = NewIPList()
|
||||
var SharedIPBlackList = NewIPList()
|
||||
var SharedIPWhiteList = NewIPList(IPListTypeAllow)
|
||||
var SharedIPBlackList = NewIPList(IPListTypeDeny)
|
||||
|
||||
type IPListType = string
|
||||
|
||||
const (
|
||||
IPListTypeAllow IPListType = "allow"
|
||||
IPListTypeDeny IPListType = "deny"
|
||||
)
|
||||
|
||||
const IPTypeAll = "*"
|
||||
|
||||
@@ -20,16 +27,18 @@ type IPList struct {
|
||||
expireList *expires.List
|
||||
ipMap map[string]int64 // ip => id
|
||||
idMap map[int64]string // id => ip
|
||||
listType IPListType
|
||||
|
||||
id int64
|
||||
locker sync.RWMutex
|
||||
}
|
||||
|
||||
// NewIPList 获取新对象
|
||||
func NewIPList() *IPList {
|
||||
func NewIPList(listType IPListType) *IPList {
|
||||
var list = &IPList{
|
||||
ipMap: map[string]int64{},
|
||||
idMap: map[int64]string{},
|
||||
ipMap: map[string]int64{},
|
||||
idMap: map[int64]string{},
|
||||
listType: listType,
|
||||
}
|
||||
|
||||
e := expires.NewList()
|
||||
@@ -63,6 +72,29 @@ func (this *IPList) Add(ipType string, scope firewallconfigs.FirewallScope, serv
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
// RecordIP 记录IP
|
||||
func (this *IPList) RecordIP(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string, expiresAt int64, policyId int64, groupId int64, setId int64) {
|
||||
this.Add(ipType, scope, serverId, ip, expiresAt)
|
||||
|
||||
if this.listType == IPListTypeDeny {
|
||||
select {
|
||||
case recordIPTaskChan <- &recordIPTask{
|
||||
ip: ip,
|
||||
listId: firewallconfigs.GlobalListId,
|
||||
expiredAt: expiresAt,
|
||||
level: firewallconfigs.DefaultEventLevel,
|
||||
serverId: serverId,
|
||||
sourceServerId: serverId,
|
||||
sourceHTTPFirewallPolicyId: policyId,
|
||||
sourceHTTPFirewallRuleGroupId: groupId,
|
||||
sourceHTTPFirewallRuleSetId: setId,
|
||||
}:
|
||||
default:
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Contains 判断是否有某个IP
|
||||
func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string) bool {
|
||||
switch scope {
|
||||
@@ -80,6 +112,16 @@ func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope,
|
||||
return ok
|
||||
}
|
||||
|
||||
// RemoveIP 删除IP
|
||||
func (this *IPList) RemoveIP(ip string, serverId int64) {
|
||||
this.locker.Lock()
|
||||
delete(this.ipMap, "*@"+ip+"@"+IPTypeAll)
|
||||
if serverId > 0 {
|
||||
delete(this.ipMap, types.String(serverId)+"@"+ip+"@"+IPTypeAll)
|
||||
}
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
func (this *IPList) remove(id int64) {
|
||||
this.locker.Lock()
|
||||
ip, ok := this.idMap[id]
|
||||
|
||||
@@ -39,6 +39,7 @@ func TestIPList_Contains(t *testing.T) {
|
||||
for i := 0; i < 1_0000; i++ {
|
||||
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
|
||||
}
|
||||
//list.RemoveIP("192.168.1.100")
|
||||
a.IsTrue(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100"))
|
||||
a.IsFalse(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.2.100"))
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
|
||||
// rule group
|
||||
type RuleGroup struct {
|
||||
Id string `yaml:"id" json:"id"`
|
||||
Id int64 `yaml:"id" json:"id"`
|
||||
IsOn bool `yaml:"isOn" json:"isOn"`
|
||||
Name string `yaml:"name" json:"name"` // such as SQL Injection
|
||||
Description string `yaml:"description" json:"description"`
|
||||
@@ -41,10 +41,7 @@ func (this *RuleGroup) AddRuleSet(ruleSet *RuleSet) {
|
||||
this.RuleSets = append(this.RuleSets, ruleSet)
|
||||
}
|
||||
|
||||
func (this *RuleGroup) FindRuleSet(id string) *RuleSet {
|
||||
if len(id) == 0 {
|
||||
return nil
|
||||
}
|
||||
func (this *RuleGroup) FindRuleSet(id int64) *RuleSet {
|
||||
for _, ruleSet := range this.RuleSets {
|
||||
if ruleSet.Id == id {
|
||||
return ruleSet
|
||||
@@ -65,10 +62,7 @@ func (this *RuleGroup) FindRuleSetWithCode(code string) *RuleSet {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *RuleGroup) RemoveRuleSet(id string) {
|
||||
if len(id) == 0 {
|
||||
return
|
||||
}
|
||||
func (this *RuleGroup) RemoveRuleSet(id int64) {
|
||||
result := []*RuleSet{}
|
||||
for _, ruleSet := range this.RuleSets {
|
||||
if ruleSet.Id == id {
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/utils/string"
|
||||
"net/http"
|
||||
"sort"
|
||||
)
|
||||
@@ -19,7 +18,7 @@ const (
|
||||
)
|
||||
|
||||
type RuleSet struct {
|
||||
Id string `yaml:"id" json:"id"`
|
||||
Id int64 `yaml:"id" json:"id"`
|
||||
Code string `yaml:"code" json:"code"`
|
||||
IsOn bool `yaml:"isOn" json:"isOn"`
|
||||
Name string `yaml:"name" json:"name"`
|
||||
@@ -36,7 +35,6 @@ type RuleSet struct {
|
||||
|
||||
func NewRuleSet() *RuleSet {
|
||||
return &RuleSet{
|
||||
Id: stringutil.Rand(16),
|
||||
IsOn: true,
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user