Compare commits
55 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
20bee16d28 | ||
|
|
b1cd971a21 | ||
|
|
d976a39711 | ||
|
|
accd0236ea | ||
|
|
fc401a1426 | ||
|
|
7e8c09a684 | ||
|
|
37ddff86f1 | ||
|
|
4dc25fb71e | ||
|
|
42883fbe22 | ||
|
|
a88d9a07be | ||
|
|
544f1e482a | ||
|
|
f53d4c8951 | ||
|
|
70d8aa5b33 | ||
|
|
1aa4be9000 | ||
|
|
a7c7c73f70 | ||
|
|
0b441021d8 | ||
|
|
7db0c8cf62 | ||
|
|
6da9cb6dcf | ||
|
|
0af580eb26 | ||
|
|
52085bdc1c | ||
|
|
72f1eea721 | ||
|
|
6d52b022b2 | ||
|
|
ea41c9b0b3 | ||
|
|
ed6127c2bb | ||
|
|
b6d95a84fc | ||
|
|
c71e68bdea | ||
|
|
c44583f249 | ||
|
|
c53773c2db | ||
|
|
793994a3fe | ||
|
|
4c3deb1156 | ||
|
|
24ca5a5ace | ||
|
|
8bbbf57827 | ||
|
|
888df02d0c | ||
|
|
8988765cef | ||
|
|
f675b88761 | ||
|
|
9bd4975478 | ||
|
|
95abb7bfae | ||
|
|
d9fa3dcc3b | ||
|
|
964524816f | ||
|
|
d124c9be18 | ||
|
|
1a05402076 | ||
|
|
c4b1790102 | ||
|
|
613acbff95 | ||
|
|
e6ab98ad11 | ||
|
|
1121869f14 | ||
|
|
91efe57e1b | ||
|
|
95f2573263 | ||
|
|
09aa85f51c | ||
|
|
c6279a1076 | ||
|
|
47ccb64cfb | ||
|
|
5c218567e1 | ||
|
|
c161d84fdf | ||
|
|
495b553285 | ||
|
|
21b770ba8b | ||
|
|
e9f94e0767 |
@@ -1 +1,2 @@
|
||||
* `global.yaml` - 全局配置
|
||||
* `api.template.yaml` - API相关配置模板
|
||||
* `cluster.template.yaml` - 通过集群自动接入节点模板
|
||||
@@ -1,7 +1,7 @@
|
||||
package caches
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -36,7 +36,7 @@ type Item struct {
|
||||
}
|
||||
|
||||
func (this *Item) IsExpired() bool {
|
||||
return this.ExpiredAt < utils.UnixTime()
|
||||
return this.ExpiredAt < fasttime.Now().Unix()
|
||||
}
|
||||
|
||||
func (this *Item) TotalSize() int64 {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/dbs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
timeutil "github.com/iwind/TeaGo/utils/time"
|
||||
@@ -225,6 +226,20 @@ func (this *FileListDB) Init() error {
|
||||
err := this.hashMap.Load(this)
|
||||
if err != nil {
|
||||
remotelogs.Error("LIST_FILE_DB", "load hash map failed: "+err.Error()+"(file: "+this.dbPath+")")
|
||||
|
||||
// 自动修复错误
|
||||
// TODO 将来希望能尽可能恢复以往数据库中的内容
|
||||
if strings.Contains(err.Error(), "database is closed") || strings.Contains(err.Error(), "database disk image is malformed") {
|
||||
_ = this.Close()
|
||||
this.deleteDB()
|
||||
remotelogs.Println("LIST_FILE_DB", "recreating the database (file:"+this.dbPath+") ...")
|
||||
err = this.Open(this.dbPath)
|
||||
if err != nil {
|
||||
remotelogs.Error("LIST_FILE_DB", "recreate the database failed: "+err.Error()+" (file:"+this.dbPath+")")
|
||||
} else {
|
||||
_ = this.Init()
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -246,7 +261,7 @@ func (this *FileListDB) AddAsync(hash string, item *Item) error {
|
||||
item.StaleAt = item.ExpiredAt
|
||||
}
|
||||
|
||||
this.writeBatch.Add(this.insertSQL, hash, item.Key, item.HeaderSize, item.BodySize, item.MetaSize, item.ExpiredAt, item.StaleAt, item.Host, item.ServerId, utils.UnixTime(), timeutil.Format("YW"))
|
||||
this.writeBatch.Add(this.insertSQL, hash, item.Key, item.HeaderSize, item.BodySize, item.MetaSize, item.ExpiredAt, item.StaleAt, item.Host, item.ServerId, fasttime.Now().Unix(), timeutil.Format("YW"))
|
||||
return nil
|
||||
|
||||
}
|
||||
@@ -258,7 +273,7 @@ func (this *FileListDB) AddSync(hash string, item *Item) error {
|
||||
item.StaleAt = item.ExpiredAt
|
||||
}
|
||||
|
||||
_, err := this.insertStmt.Exec(hash, item.Key, item.HeaderSize, item.BodySize, item.MetaSize, item.ExpiredAt, item.StaleAt, item.Host, item.ServerId, utils.UnixTime(), timeutil.Format("YW"))
|
||||
_, err := this.insertStmt.Exec(hash, item.Key, item.HeaderSize, item.BodySize, item.MetaSize, item.ExpiredAt, item.StaleAt, item.Host, item.ServerId, fasttime.Now().Unix(), timeutil.Format("YW"))
|
||||
if err != nil {
|
||||
return this.WrapError(err)
|
||||
}
|
||||
@@ -377,8 +392,8 @@ func (this *FileListDB) CleanPrefix(prefix string) error {
|
||||
return nil
|
||||
}
|
||||
var count = int64(10000)
|
||||
var staleLife = 600 // TODO 需要可以设置
|
||||
var unixTime = utils.UnixTime() // 只删除当前的,不删除新的
|
||||
var staleLife = 600 // TODO 需要可以设置
|
||||
var unixTime = fasttime.Now().Unix() // 只删除当前的,不删除新的
|
||||
for {
|
||||
result, err := this.writeDB.Exec(`UPDATE "`+this.itemsTableName+`" SET expiredAt=0,staleAt=? WHERE id IN (SELECT id FROM "`+this.itemsTableName+`" WHERE expiredAt>0 AND createdAt<=? AND INSTR("key", ?)=1 LIMIT `+types.String(count)+`)`, unixTime+int64(staleLife), unixTime, prefix)
|
||||
if err != nil {
|
||||
@@ -424,8 +439,8 @@ func (this *FileListDB) CleanMatchKey(key string) error {
|
||||
queryKey = strings.Replace(queryKey, "*", "%", 1)
|
||||
|
||||
// TODO 检查大批量数据下的操作性能
|
||||
var staleLife = 600 // TODO 需要可以设置
|
||||
var unixTime = utils.UnixTime() // 只删除当前的,不删除新的
|
||||
var staleLife = 600 // TODO 需要可以设置
|
||||
var unixTime = fasttime.Now().Unix() // 只删除当前的,不删除新的
|
||||
|
||||
_, err = this.writeDB.Exec(`UPDATE "`+this.itemsTableName+`" SET "expiredAt"=0, "staleAt"=? WHERE "host" GLOB ? AND "host" NOT GLOB ? AND "key" LIKE ? ESCAPE '\'`, unixTime+int64(staleLife), host, "*."+host, queryKey)
|
||||
if err != nil {
|
||||
@@ -466,8 +481,8 @@ func (this *FileListDB) CleanMatchPrefix(prefix string) error {
|
||||
queryPrefix += "%"
|
||||
|
||||
// TODO 检查大批量数据下的操作性能
|
||||
var staleLife = 600 // TODO 需要可以设置
|
||||
var unixTime = utils.UnixTime() // 只删除当前的,不删除新的
|
||||
var staleLife = 600 // TODO 需要可以设置
|
||||
var unixTime = fasttime.Now().Unix() // 只删除当前的,不删除新的
|
||||
|
||||
_, err = this.writeDB.Exec(`UPDATE "`+this.itemsTableName+`" SET "expiredAt"=0, "staleAt"=? WHERE "host" GLOB ? AND "host" NOT GLOB ? AND "key" LIKE ? ESCAPE '\'`, unixTime+int64(staleLife), host, "*."+host, queryPrefix)
|
||||
return err
|
||||
@@ -682,3 +697,10 @@ func (this *FileListDB) shouldRecover() bool {
|
||||
_ = result.Close()
|
||||
return shouldRecover
|
||||
}
|
||||
|
||||
// 删除数据库文件
|
||||
func (this *FileListDB) deleteDB() {
|
||||
_ = os.Remove(this.dbPath)
|
||||
_ = os.Remove(this.dbPath + "-shm")
|
||||
_ = os.Remove(this.dbPath + "-wal")
|
||||
}
|
||||
|
||||
@@ -177,10 +177,15 @@ func (this *Manager) TotalDiskSize() int64 {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
|
||||
total := int64(0)
|
||||
var total = int64(0)
|
||||
for _, storage := range this.storageMap {
|
||||
total += storage.TotalDiskSize()
|
||||
}
|
||||
|
||||
if total < 0 {
|
||||
total = 0
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
package caches
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/linkedlist"
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ func NewOpenFilePool(filename string) *OpenFilePool {
|
||||
var pool = &OpenFilePool{
|
||||
filename: filename,
|
||||
c: make(chan *OpenFile, 1024),
|
||||
version: utils.UnixTimeMilli(),
|
||||
version: fasttime.Now().UnixMilli(),
|
||||
}
|
||||
pool.linkItem = linkedlist.NewItem(pool)
|
||||
return pool
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/trackers"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
setutils "github.com/TeaOSLab/EdgeNode/internal/utils/sets"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/sizes"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/zero"
|
||||
@@ -32,7 +33,7 @@ type MemoryItem struct {
|
||||
}
|
||||
|
||||
func (this *MemoryItem) IsExpired() bool {
|
||||
return this.ExpiresAt < utils.UnixTime()
|
||||
return this.ExpiresAt < fasttime.Now().Unix()
|
||||
}
|
||||
|
||||
type MemoryStorage struct {
|
||||
@@ -119,7 +120,7 @@ func (this *MemoryStorage) OpenReader(key string, useStale bool, isPartial bool)
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
if useStale || (item.ExpiresAt > utils.UnixTime()) {
|
||||
if useStale || (item.ExpiresAt > fasttime.Now().Unix()) {
|
||||
reader := NewMemoryReader(item)
|
||||
err := reader.Init()
|
||||
if err != nil {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package teaconst
|
||||
|
||||
const (
|
||||
Version = "0.6.4"
|
||||
Version = "1.0.4"
|
||||
|
||||
ProductName = "Edge Node"
|
||||
ProcessName = "edge-node"
|
||||
|
||||
@@ -5,6 +5,7 @@ package teaconst
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -15,7 +16,7 @@ var (
|
||||
|
||||
NodeId int64 = 0
|
||||
NodeIdString = ""
|
||||
IsMain = len(os.Args) == 1 || (len(os.Args) >= 2 && os.Args[1] == "pprof")
|
||||
IsMain = checkMain()
|
||||
|
||||
GlobalProductName = nodeconfigs.DefaultProductName
|
||||
|
||||
@@ -24,3 +25,15 @@ var (
|
||||
|
||||
DiskIsFast = false // 是否为高速硬盘
|
||||
)
|
||||
|
||||
// 检查是否为主程序
|
||||
func checkMain() bool {
|
||||
if len(os.Args) == 1 ||
|
||||
(len(os.Args) >= 2 && os.Args[1] == "pprof") {
|
||||
return true
|
||||
}
|
||||
exe, _ := os.Executable()
|
||||
return strings.HasSuffix(exe, ".test") ||
|
||||
strings.HasSuffix(exe, ".test.exe") ||
|
||||
strings.Contains(exe, "___")
|
||||
}
|
||||
|
||||
@@ -20,8 +20,8 @@ import (
|
||||
"github.com/iwind/TeaGo/types"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -61,6 +61,8 @@ func init() {
|
||||
type DDoSProtectionManager struct {
|
||||
lastAllowIPList []string
|
||||
lastConfig []byte
|
||||
|
||||
locker sync.Mutex
|
||||
}
|
||||
|
||||
// NewDDoSProtectionManager 获取新对象
|
||||
@@ -70,6 +72,12 @@ func NewDDoSProtectionManager() *DDoSProtectionManager {
|
||||
|
||||
// Apply 应用配置
|
||||
func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error {
|
||||
// 加锁防止并发更改
|
||||
if !this.locker.TryLock() {
|
||||
return nil
|
||||
}
|
||||
defer this.locker.Unlock()
|
||||
|
||||
// 同集群节点IP白名单
|
||||
var allowIPListChanged = false
|
||||
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
|
||||
@@ -91,7 +99,7 @@ func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) e
|
||||
}
|
||||
remotelogs.Println("FIREWALL", "change DDoS protection config")
|
||||
|
||||
if len(this.nftExe()) == 0 {
|
||||
if len(nftables.NftExePath()) == 0 {
|
||||
return errors.New("can not find nft command")
|
||||
}
|
||||
|
||||
@@ -157,7 +165,7 @@ func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) e
|
||||
|
||||
// 添加TCP规则
|
||||
func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig) error {
|
||||
var nftExe = this.nftExe()
|
||||
var nftExe = nftables.NftExePath()
|
||||
if len(nftExe) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -546,7 +554,7 @@ func (this *DDoSProtectionManager) updateAllowIPList(allIPList []string) error {
|
||||
_, ok := oldMap[ip]
|
||||
if !ok {
|
||||
// 不存在则添加
|
||||
err = set.AddIPElement(ip, nil)
|
||||
err = set.AddIPElement(ip, nil, false)
|
||||
if err != nil {
|
||||
return errors.New("add ip '" + ip + "' failed: " + err.Error())
|
||||
}
|
||||
@@ -557,8 +565,3 @@ func (this *DDoSProtectionManager) updateAllowIPList(allIPList []string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DDoSProtectionManager) nftExe() string {
|
||||
path, _ := exec.LookPath("nft")
|
||||
return path
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/conns"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
@@ -13,9 +13,9 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -37,8 +37,8 @@ func init() {
|
||||
ticker.Stop()
|
||||
break
|
||||
}
|
||||
_, err := exec.LookPath("nft")
|
||||
if err == nil {
|
||||
var nftExe = nftables.NftExePath()
|
||||
if len(nftExe) > 0 {
|
||||
nftablesFirewall, err := NewNFTablesFirewall()
|
||||
if err != nil {
|
||||
continue
|
||||
@@ -88,11 +88,15 @@ type blockIPItem struct {
|
||||
}
|
||||
|
||||
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
|
||||
conn, err := nftables.NewConn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var firewall = &NFTablesFirewall{
|
||||
conn: nftables.NewConn(),
|
||||
conn: conn,
|
||||
dropIPQueue: make(chan *blockIPItem, 4096),
|
||||
}
|
||||
err := firewall.init()
|
||||
err = firewall.init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -110,8 +114,8 @@ type NFTablesFirewall struct {
|
||||
allowIPv4Set *nftables.Set
|
||||
allowIPv6Set *nftables.Set
|
||||
|
||||
denyIPv4Set *nftables.Set
|
||||
denyIPv6Set *nftables.Set
|
||||
denyIPv4Sets []*nftables.Set
|
||||
denyIPv6Sets []*nftables.Set
|
||||
|
||||
firewalld *Firewalld
|
||||
|
||||
@@ -120,9 +124,9 @@ type NFTablesFirewall struct {
|
||||
|
||||
func (this *NFTablesFirewall) init() error {
|
||||
// check nft
|
||||
nftPath, err := exec.LookPath("nft")
|
||||
if err != nil {
|
||||
return errors.New("nft not found")
|
||||
var nftPath = nftables.NftExePath()
|
||||
if len(nftPath) == 0 {
|
||||
return errors.New("'nft' not found")
|
||||
}
|
||||
this.version = this.readVersion(nftPath)
|
||||
|
||||
@@ -186,7 +190,7 @@ func (this *NFTablesFirewall) init() error {
|
||||
|
||||
// allow set
|
||||
// "allow" should be always first
|
||||
for _, setAction := range []string{"allow", "deny"} {
|
||||
for _, setAction := range []string{"allow", "deny", "deny1", "deny2", "deny3", "deny4"} {
|
||||
var setName = setAction + "_set"
|
||||
|
||||
set, err := table.GetSet(setName)
|
||||
@@ -216,32 +220,42 @@ func (this *NFTablesFirewall) init() error {
|
||||
if setAction == "allow" {
|
||||
this.allowIPv4Set = set
|
||||
} else {
|
||||
this.denyIPv4Set = set
|
||||
this.denyIPv4Sets = append(this.denyIPv4Sets, set)
|
||||
}
|
||||
} else if tableDef.IsIPv6 {
|
||||
if setAction == "allow" {
|
||||
this.allowIPv6Set = set
|
||||
} else {
|
||||
this.denyIPv6Set = set
|
||||
this.denyIPv6Sets = append(this.denyIPv6Sets, set)
|
||||
}
|
||||
}
|
||||
|
||||
// rule
|
||||
var ruleName = []byte(setAction)
|
||||
rule, err := chain.GetRuleWithUserData(ruleName)
|
||||
|
||||
// 将以前的drop规则删掉,替换成后面的reject
|
||||
if err == nil && setAction != "allow" && rule != nil && rule.VerDict() == expr.VerdictDrop {
|
||||
deleteErr := chain.DeleteRule(rule)
|
||||
if deleteErr == nil {
|
||||
err = nftables.ErrRuleNotFound
|
||||
rule = nil
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if nftables.IsNotFound(err) {
|
||||
if tableDef.IsIPv4 {
|
||||
if setAction == "allow" {
|
||||
rule, err = chain.AddAcceptIPv4SetRule(setName, ruleName)
|
||||
} else {
|
||||
rule, err = chain.AddDropIPv4SetRule(setName, ruleName)
|
||||
rule, err = chain.AddRejectIPv4SetRule(setName, ruleName)
|
||||
}
|
||||
} else if tableDef.IsIPv6 {
|
||||
if setAction == "allow" {
|
||||
rule, err = chain.AddAcceptIPv6SetRule(setName, ruleName)
|
||||
} else {
|
||||
rule, err = chain.AddDropIPv6SetRule(setName, ruleName)
|
||||
rule, err = chain.AddRejectIPv6SetRule(setName, ruleName)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
@@ -265,7 +279,7 @@ func (this *NFTablesFirewall) init() error {
|
||||
for ipItem := range this.dropIPQueue {
|
||||
switch ipItem.action {
|
||||
case "drop":
|
||||
err = this.DropSourceIP(ipItem.ip, ipItem.timeoutSeconds, false)
|
||||
err := this.DropSourceIP(ipItem.ip, ipItem.timeoutSeconds, false)
|
||||
if err != nil {
|
||||
remotelogs.Warn("NFTABLES", "drop ip '"+ipItem.ip+"' failed: "+err.Error())
|
||||
}
|
||||
@@ -324,14 +338,14 @@ func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
|
||||
if this.allowIPv6Set == nil {
|
||||
return errors.New("ipv6 ip set is nil")
|
||||
}
|
||||
return this.allowIPv6Set.AddElement(data.To16(), nil)
|
||||
return this.allowIPv6Set.AddElement(data.To16(), nil, false)
|
||||
}
|
||||
|
||||
// ipv4
|
||||
if this.allowIPv4Set == nil {
|
||||
return errors.New("ipv4 ip set is nil")
|
||||
}
|
||||
return this.allowIPv4Set.AddElement(data.To4(), nil)
|
||||
return this.allowIPv4Set.AddElement(data.To4(), nil, false)
|
||||
}
|
||||
|
||||
// RejectSourceIP 拒绝某个源IP连接
|
||||
@@ -371,22 +385,23 @@ func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async
|
||||
// 再次尝试关闭连接
|
||||
defer conns.SharedMap.CloseIPConns(ip)
|
||||
|
||||
var ipLong = configutils.IPString2Long(ip)
|
||||
if strings.Contains(ip, ":") { // ipv6
|
||||
if this.denyIPv6Set == nil {
|
||||
return errors.New("ipv6 ip set is nil")
|
||||
if len(this.denyIPv6Sets) == 0 {
|
||||
return errors.New("ipv6 ip set not found")
|
||||
}
|
||||
return this.denyIPv6Set.AddElement(data.To16(), &nftables.ElementOptions{
|
||||
return this.denyIPv6Sets[ipLong%uint64(len(this.denyIPv6Sets))].AddElement(data.To16(), &nftables.ElementOptions{
|
||||
Timeout: time.Duration(timeoutSeconds) * time.Second,
|
||||
})
|
||||
}, false)
|
||||
}
|
||||
|
||||
// ipv4
|
||||
if this.denyIPv4Set == nil {
|
||||
return errors.New("ipv4 ip set is nil")
|
||||
if len(this.denyIPv4Sets) == 0 {
|
||||
return errors.New("ipv4 ip set not found")
|
||||
}
|
||||
return this.denyIPv4Set.AddElement(data.To4(), &nftables.ElementOptions{
|
||||
return this.denyIPv4Sets[ipLong%uint64(len(this.denyIPv4Sets))].AddElement(data.To4(), &nftables.ElementOptions{
|
||||
Timeout: time.Duration(timeoutSeconds) * time.Second,
|
||||
})
|
||||
}, false)
|
||||
}
|
||||
|
||||
// RemoveSourceIP 删除某个源IP
|
||||
@@ -396,9 +411,10 @@ func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
var ipLong = configutils.IPString2Long(ip)
|
||||
if strings.Contains(ip, ":") { // ipv6
|
||||
if this.denyIPv6Set != nil {
|
||||
err := this.denyIPv6Set.DeleteElement(data.To16())
|
||||
if len(this.denyIPv6Sets) > 0 {
|
||||
err := this.denyIPv6Sets[ipLong%uint64(len(this.denyIPv6Sets))].DeleteElement(data.To16())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -415,13 +431,14 @@ func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
|
||||
}
|
||||
|
||||
// ipv4
|
||||
if this.allowIPv4Set != nil {
|
||||
err := this.denyIPv4Set.DeleteElement(data.To4())
|
||||
if len(this.denyIPv4Sets) > 0 {
|
||||
err := this.denyIPv4Sets[ipLong%uint64(len(this.denyIPv4Sets))].DeleteElement(data.To4())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = this.allowIPv4Set.DeleteElement(data.To4())
|
||||
}
|
||||
if this.allowIPv4Set != nil {
|
||||
err := this.allowIPv4Set.DeleteElement(data.To4())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables_test
|
||||
|
||||
@@ -11,7 +10,10 @@ import (
|
||||
)
|
||||
|
||||
func getIPv4Chain(t *testing.T) *nftables.Chain {
|
||||
var conn = nftables.NewConn()
|
||||
conn, err := nftables.NewConn()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
if err == nftables.ErrTableNotFound {
|
||||
|
||||
@@ -15,10 +15,14 @@ type Conn struct {
|
||||
rawConn *nft.Conn
|
||||
}
|
||||
|
||||
func NewConn() *Conn {
|
||||
return &Conn{
|
||||
rawConn: &nft.Conn{},
|
||||
func NewConn() (*Conn, error) {
|
||||
conn, err := nft.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Conn{
|
||||
rawConn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (this *Conn) Raw() *nft.Conn {
|
||||
|
||||
@@ -4,7 +4,10 @@
|
||||
|
||||
package nftables
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ErrTableNotFound = errors.New("table not found")
|
||||
var ErrChainNotFound = errors.New("chain not found")
|
||||
@@ -15,5 +18,5 @@ func IsNotFound(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return err == ErrTableNotFound || err == ErrChainNotFound || err == ErrSetNotFound || err == ErrRuleNotFound
|
||||
return err == ErrTableNotFound || err == ErrChainNotFound || err == ErrSetNotFound || err == ErrRuleNotFound || strings.Contains(err.Error(), "no such file or directory")
|
||||
}
|
||||
|
||||
65
internal/firewalls/nftables/expration.go
Normal file
65
internal/firewalls/nftables/expration.go
Normal file
@@ -0,0 +1,65 @@
|
||||
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Expiration struct {
|
||||
m map[string]time.Time // key => expires time
|
||||
|
||||
lastGCAt int64
|
||||
|
||||
locker sync.RWMutex
|
||||
}
|
||||
|
||||
func NewExpiration() *Expiration {
|
||||
return &Expiration{
|
||||
m: map[string]time.Time{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Expiration) AddUnsafe(key []byte, expires time.Time) {
|
||||
this.m[string(key)] = expires
|
||||
}
|
||||
|
||||
func (this *Expiration) Add(key []byte, expires time.Time) {
|
||||
this.locker.Lock()
|
||||
this.m[string(key)] = expires
|
||||
this.gc()
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
func (this *Expiration) Remove(key []byte) {
|
||||
this.locker.Lock()
|
||||
delete(this.m, string(key))
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
func (this *Expiration) Contains(key []byte) bool {
|
||||
this.locker.RLock()
|
||||
expires, ok := this.m[string(key)]
|
||||
if ok && expires.Year() > 2000 && time.Now().After(expires) {
|
||||
ok = false
|
||||
}
|
||||
this.locker.RUnlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
func (this *Expiration) gc() {
|
||||
// we won't gc too frequently
|
||||
var currentTime = time.Now().Unix()
|
||||
if this.lastGCAt >= currentTime {
|
||||
return
|
||||
}
|
||||
this.lastGCAt = currentTime
|
||||
|
||||
var now = time.Now().Add(-10 * time.Second) // gc elements expired before 10 seconds ago
|
||||
for key, expires := range this.m {
|
||||
if expires.Year() > 2000 && now.After(expires) {
|
||||
delete(this.m, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
59
internal/firewalls/nftables/expration_test.go
Normal file
59
internal/firewalls/nftables/expration_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nftables_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestExpiration_Add(t *testing.T) {
|
||||
var expiration = nftables.NewExpiration()
|
||||
{
|
||||
expiration.Add([]byte{'a', 'b', 'c'}, time.Now())
|
||||
t.Log(expiration.Contains([]byte{'a', 'b', 'c'}))
|
||||
}
|
||||
{
|
||||
expiration.Add([]byte{'a', 'b', 'c'}, time.Now().Add(1*time.Second))
|
||||
t.Log(expiration.Contains([]byte{'a', 'b', 'c'}))
|
||||
}
|
||||
{
|
||||
expiration.Add([]byte{'a', 'b', 'c'}, time.Time{})
|
||||
t.Log(expiration.Contains([]byte{'a', 'b', 'c'}))
|
||||
}
|
||||
{
|
||||
expiration.Add([]byte{'a', 'b', 'c'}, time.Now().Add(-1*time.Second))
|
||||
t.Log(expiration.Contains([]byte{'a', 'b', 'c'}))
|
||||
}
|
||||
{
|
||||
expiration.Add([]byte{'a', 'b', 'c'}, time.Now().Add(-10*time.Second))
|
||||
t.Log(expiration.Contains([]byte{'a', 'b', 'c'}))
|
||||
}
|
||||
{
|
||||
expiration.Add([]byte{'a', 'b', 'c'}, time.Now().Add(1*time.Second))
|
||||
expiration.Remove([]byte{'a', 'b', 'c'})
|
||||
t.Log(expiration.Contains([]byte{'a', 'b', 'c'}))
|
||||
}
|
||||
{
|
||||
expiration.Add(net.ParseIP("10.254.0.75").To4(), time.Now())
|
||||
t.Log(expiration.Contains(net.ParseIP("10.254.0.75").To4()))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNewExpiration(b *testing.B) {
|
||||
var expiration = nftables.NewExpiration()
|
||||
for i := 0; i < 10_000; i++ {
|
||||
expiration.Add([]byte(types.String(types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255)))), time.Now().Add(3600*time.Second))
|
||||
}
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
expiration.Add([]byte(types.String(types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255)))), time.Now().Add(3600*time.Second))
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build linux
|
||||
|
||||
package nftables
|
||||
|
||||
@@ -38,8 +39,7 @@ func init() {
|
||||
}
|
||||
|
||||
if os.Getgid() == 0 { // root user only
|
||||
_, err := exec.LookPath("nft")
|
||||
if err == nil {
|
||||
if len(NftExePath()) > 0 {
|
||||
return
|
||||
}
|
||||
goman.New(func() {
|
||||
@@ -53,6 +53,25 @@ func init() {
|
||||
})
|
||||
}
|
||||
|
||||
// NftExePath 查找nftables可执行文件路径
|
||||
func NftExePath() string {
|
||||
path, _ := exec.LookPath("nft")
|
||||
if len(path) > 0 {
|
||||
return path
|
||||
}
|
||||
|
||||
for _, possiblePath := range []string{
|
||||
"/usr/sbin/nft",
|
||||
} {
|
||||
_, err := os.Stat(possiblePath)
|
||||
if err == nil {
|
||||
return possiblePath
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
type Installer struct {
|
||||
}
|
||||
|
||||
@@ -67,8 +86,7 @@ func (this *Installer) Install() error {
|
||||
}
|
||||
|
||||
// 检查是否已经存在
|
||||
_, err := exec.LookPath("nft")
|
||||
if err == nil {
|
||||
if len(NftExePath()) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables
|
||||
|
||||
@@ -35,17 +34,25 @@ type Set struct {
|
||||
conn *Conn
|
||||
rawSet *nft.Set
|
||||
batch *SetBatch
|
||||
|
||||
expiration *Expiration
|
||||
}
|
||||
|
||||
func NewSet(conn *Conn, rawSet *nft.Set) *Set {
|
||||
return &Set{
|
||||
conn: conn,
|
||||
rawSet: rawSet,
|
||||
var set = &Set{
|
||||
conn: conn,
|
||||
rawSet: rawSet,
|
||||
expiration: nil,
|
||||
batch: &SetBatch{
|
||||
conn: conn,
|
||||
rawSet: rawSet,
|
||||
},
|
||||
}
|
||||
|
||||
// retrieve set elements to improve "delete" speed
|
||||
set.initElements()
|
||||
|
||||
return set
|
||||
}
|
||||
|
||||
func (this *Set) Raw() *nft.Set {
|
||||
@@ -56,12 +63,22 @@ func (this *Set) Name() string {
|
||||
return this.rawSet.Name
|
||||
}
|
||||
|
||||
func (this *Set) AddElement(key []byte, options *ElementOptions) error {
|
||||
func (this *Set) AddElement(key []byte, options *ElementOptions, overwrite bool) error {
|
||||
// check if already exists
|
||||
if this.expiration != nil && !overwrite && this.expiration.Contains(key) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var expiresTime = time.Time{}
|
||||
var rawElement = nft.SetElement{
|
||||
Key: key,
|
||||
}
|
||||
if options != nil {
|
||||
rawElement.Timeout = options.Timeout
|
||||
|
||||
if options.Timeout > 0 {
|
||||
expiresTime = time.UnixMilli(time.Now().UnixMilli() + options.Timeout.Milliseconds())
|
||||
}
|
||||
}
|
||||
err := this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
|
||||
rawElement,
|
||||
@@ -71,9 +88,19 @@ func (this *Set) AddElement(key []byte, options *ElementOptions) error {
|
||||
}
|
||||
|
||||
err = this.conn.Commit()
|
||||
if err != nil {
|
||||
if err == nil {
|
||||
if this.expiration != nil {
|
||||
this.expiration.Add(key, expiresTime)
|
||||
}
|
||||
} else {
|
||||
var isFileExistsErr = strings.Contains(err.Error(), "file exists")
|
||||
if !overwrite && isFileExistsErr {
|
||||
// ignore file exists error
|
||||
return nil
|
||||
}
|
||||
|
||||
// retry if exists
|
||||
if strings.Contains(err.Error(), "file exists") {
|
||||
if overwrite && isFileExistsErr {
|
||||
deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
|
||||
{
|
||||
Key: key,
|
||||
@@ -85,6 +112,11 @@ func (this *Set) AddElement(key []byte, options *ElementOptions) error {
|
||||
})
|
||||
if err == nil {
|
||||
err = this.conn.Commit()
|
||||
if err == nil {
|
||||
if this.expiration != nil {
|
||||
this.expiration.Add(key, expiresTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -93,20 +125,25 @@ func (this *Set) AddElement(key []byte, options *ElementOptions) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *Set) AddIPElement(ip string, options *ElementOptions) error {
|
||||
func (this *Set) AddIPElement(ip string, options *ElementOptions, overwrite bool) error {
|
||||
var ipObj = net.ParseIP(ip)
|
||||
if ipObj == nil {
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
if utils.IsIPv4(ip) {
|
||||
return this.AddElement(ipObj.To4(), options)
|
||||
return this.AddElement(ipObj.To4(), options, overwrite)
|
||||
} else {
|
||||
return this.AddElement(ipObj.To16(), options)
|
||||
return this.AddElement(ipObj.To16(), options, overwrite)
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Set) DeleteElement(key []byte) error {
|
||||
// if set element does not exist, we return immediately
|
||||
if this.expiration != nil && !this.expiration.Contains(key) {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
|
||||
{
|
||||
Key: key,
|
||||
@@ -116,9 +153,17 @@ func (this *Set) DeleteElement(key []byte) error {
|
||||
return err
|
||||
}
|
||||
err = this.conn.Commit()
|
||||
if err != nil {
|
||||
if err == nil {
|
||||
if this.expiration != nil {
|
||||
this.expiration.Remove(key)
|
||||
}
|
||||
} else {
|
||||
if strings.Contains(err.Error(), "no such file or directory") {
|
||||
err = nil
|
||||
|
||||
if this.expiration != nil {
|
||||
this.expiration.Remove(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables
|
||||
|
||||
|
||||
8
internal/firewalls/nftables/set_ext.go
Normal file
8
internal/firewalls/nftables/set_ext.go
Normal file
@@ -0,0 +1,8 @@
|
||||
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build linux && !plus
|
||||
|
||||
package nftables
|
||||
|
||||
func (this *Set) initElements() {
|
||||
// NOT IMPLEMENTED
|
||||
}
|
||||
@@ -34,7 +34,7 @@ func getIPv4Set(t *testing.T) *nftables.Set {
|
||||
|
||||
func TestSet_AddElement(t *testing.T) {
|
||||
var set = getIPv4Set(t)
|
||||
err := set.AddElement(net.ParseIP("192.168.2.31").To4(), &nftables.ElementOptions{Timeout: 86400 * time.Second})
|
||||
err := set.AddElement(net.ParseIP("192.168.2.31").To4(), &nftables.ElementOptions{Timeout: 86400 * time.Second}, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables_test
|
||||
|
||||
@@ -10,7 +9,10 @@ import (
|
||||
)
|
||||
|
||||
func getIPv4Table(t *testing.T) *nftables.Table {
|
||||
var conn = nftables.NewConn()
|
||||
conn, err := nftables.NewConn()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
if err == nftables.ErrTableNotFound {
|
||||
|
||||
30
internal/firewalls/utils.go
Normal file
30
internal/firewalls/utils.go
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// DropTemporaryTo 使用本地防火墙临时拦截IP数据包
|
||||
func DropTemporaryTo(ip string, expiresAt int64) {
|
||||
// 如果为0,则表示是长期有效
|
||||
if expiresAt <= 0 {
|
||||
expiresAt = time.Now().Unix() + 3600
|
||||
}
|
||||
|
||||
var timeout = expiresAt - time.Now().Unix()
|
||||
if timeout < 1 {
|
||||
return
|
||||
}
|
||||
if timeout > 3600 {
|
||||
timeout = 3600
|
||||
}
|
||||
|
||||
// 使用本地防火墙延长封禁
|
||||
var fw = Firewall()
|
||||
if fw != nil && !fw.IsMock() {
|
||||
// 这里 int(int64) 转换的前提是限制了 timeout <= 3600,否则将有整型溢出的风险
|
||||
_ = fw.DropSourceIP(ip, int(timeout), true)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package iplibrary
|
||||
|
||||
import "github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
)
|
||||
|
||||
type IPItemType = string
|
||||
|
||||
@@ -45,7 +47,7 @@ func (this *IPItem) containsIPv4(ip uint64) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if this.ExpiredAt > 0 && this.ExpiredAt < utils.UnixTime() {
|
||||
if this.ExpiredAt > 0 && this.ExpiredAt < fasttime.Now().Unix() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
@@ -56,7 +58,7 @@ func (this *IPItem) containsIPv6(ip uint64) bool {
|
||||
if this.IPFrom != ip {
|
||||
return false
|
||||
}
|
||||
if this.ExpiredAt > 0 && this.ExpiredAt < utils.UnixTime() {
|
||||
if this.ExpiredAt > 0 && this.ExpiredAt < fasttime.Now().Unix() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
@@ -64,7 +66,7 @@ func (this *IPItem) containsIPv6(ip uint64) bool {
|
||||
|
||||
// 检查是否包所有IP
|
||||
func (this *IPItem) containsAll() bool {
|
||||
if this.ExpiredAt > 0 && this.ExpiredAt < utils.UnixTime() {
|
||||
if this.ExpiredAt > 0 && this.ExpiredAt < fasttime.Now().Unix() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
||||
@@ -3,6 +3,7 @@ package iplibrary
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
@@ -72,6 +73,25 @@ func (this *IPList) Contains(ip uint64) bool {
|
||||
return item != nil
|
||||
}
|
||||
|
||||
// ContainsExpires 判断是否包含某个IP
|
||||
func (this *IPList) ContainsExpires(ip uint64) (expiresAt int64, ok bool) {
|
||||
this.locker.RLock()
|
||||
if len(this.allItemsMap) > 0 {
|
||||
this.locker.RUnlock()
|
||||
return 0, true
|
||||
}
|
||||
|
||||
var item = this.lookupIP(ip)
|
||||
|
||||
this.locker.RUnlock()
|
||||
|
||||
if item == nil {
|
||||
return
|
||||
}
|
||||
|
||||
return item.ExpiredAt, true
|
||||
}
|
||||
|
||||
// ContainsIPStrings 是否包含一组IP中的任意一个,并返回匹配的第一个Item
|
||||
func (this *IPList) ContainsIPStrings(ipStrings []string) (item *IPItem, found bool) {
|
||||
if len(ipStrings) == 0 {
|
||||
@@ -110,7 +130,7 @@ func (this *IPList) addItem(item *IPItem, sortable bool) {
|
||||
return
|
||||
}
|
||||
|
||||
if item.ExpiredAt > 0 && item.ExpiredAt < utils.UnixTime() {
|
||||
if item.ExpiredAt > 0 && item.ExpiredAt < fasttime.Now().Unix() {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -155,7 +175,7 @@ func (this *IPList) addItem(item *IPItem, sortable bool) {
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
// 对列表进行排序
|
||||
// 对列表进行排序
|
||||
func (this *IPList) sortItems() {
|
||||
sort.Slice(this.sortedItems, func(i, j int) bool {
|
||||
var item1 = this.sortedItems[i]
|
||||
|
||||
@@ -17,13 +17,17 @@ import (
|
||||
type IPListDB struct {
|
||||
db *dbs.DB
|
||||
|
||||
itemTableName string
|
||||
itemTableName string
|
||||
versionTableName string
|
||||
|
||||
deleteExpiredItemsStmt *dbs.Stmt
|
||||
deleteItemStmt *dbs.Stmt
|
||||
insertItemStmt *dbs.Stmt
|
||||
selectItemsStmt *dbs.Stmt
|
||||
selectMaxVersionStmt *dbs.Stmt
|
||||
deleteExpiredItemsStmt *dbs.Stmt
|
||||
deleteItemStmt *dbs.Stmt
|
||||
insertItemStmt *dbs.Stmt
|
||||
selectItemsStmt *dbs.Stmt
|
||||
selectMaxItemVersionStmt *dbs.Stmt
|
||||
|
||||
selectVersionStmt *dbs.Stmt
|
||||
updateVersionStmt *dbs.Stmt
|
||||
|
||||
cleanTicker *time.Ticker
|
||||
|
||||
@@ -34,9 +38,10 @@ type IPListDB struct {
|
||||
|
||||
func NewIPListDB() (*IPListDB, error) {
|
||||
var db = &IPListDB{
|
||||
itemTableName: "ipItems",
|
||||
dir: filepath.Clean(Tea.Root + "/data"),
|
||||
cleanTicker: time.NewTicker(24 * time.Hour),
|
||||
itemTableName: "ipItems",
|
||||
versionTableName: "versions",
|
||||
dir: filepath.Clean(Tea.Root + "/data"),
|
||||
cleanTicker: time.NewTicker(24 * time.Hour),
|
||||
}
|
||||
err := db.init()
|
||||
return db, err
|
||||
@@ -108,6 +113,15 @@ ON "` + this.itemTableName + `" (
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.versionTableName + `" (
|
||||
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"version" integer DEFAULT 0
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 初始化SQL语句
|
||||
this.deleteExpiredItemsStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "expiredAt">0 AND "expiredAt"<?`)
|
||||
if err != nil {
|
||||
@@ -129,7 +143,20 @@ ON "` + this.itemTableName + `" (
|
||||
return err
|
||||
}
|
||||
|
||||
this.selectMaxVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.itemTableName + `" ORDER BY "id" DESC LIMIT 1`)
|
||||
this.selectMaxItemVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.itemTableName + `" ORDER BY "id" DESC LIMIT 1`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.selectVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.versionTableName + `" LIMIT 1`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.updateVersionStmt, err = this.db.Prepare(`REPLACE INTO "` + this.versionTableName + `" ("id", "version") VALUES (1, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.db = db
|
||||
|
||||
@@ -172,11 +199,15 @@ func (this *IPListDB) AddItem(item *pb.IPItem) error {
|
||||
|
||||
// 如果是删除,则不再创建新记录
|
||||
if item.IsDeleted {
|
||||
return nil
|
||||
return this.UpdateMaxVersion(item.Version)
|
||||
}
|
||||
|
||||
_, err = this.insertItemStmt.Exec(item.ListId, item.ListType, item.IsGlobal, item.Type, item.Id, item.IpFrom, item.IpTo, item.ExpiredAt, item.EventLevel, item.IsDeleted, item.Version, item.NodeId, item.ServerId)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return this.UpdateMaxVersion(item.Version)
|
||||
}
|
||||
|
||||
func (this *IPListDB) ReadItems(offset int64, size int64) (items []*pb.IPItem, err error) {
|
||||
@@ -210,27 +241,63 @@ func (this *IPListDB) ReadMaxVersion() int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var row = this.selectMaxVersionStmt.QueryRow()
|
||||
if row == nil {
|
||||
return 0
|
||||
// from version table
|
||||
{
|
||||
var row = this.selectVersionStmt.QueryRow()
|
||||
if row == nil {
|
||||
return 0
|
||||
}
|
||||
var version int64
|
||||
err := row.Scan(&version)
|
||||
if err == nil {
|
||||
return version
|
||||
}
|
||||
}
|
||||
var version int64
|
||||
err := row.Scan(&version)
|
||||
if err != nil {
|
||||
return 0
|
||||
|
||||
// from items table
|
||||
{
|
||||
var row = this.selectMaxItemVersionStmt.QueryRow()
|
||||
if row == nil {
|
||||
return 0
|
||||
}
|
||||
var version int64
|
||||
err := row.Scan(&version)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return version
|
||||
}
|
||||
return version
|
||||
}
|
||||
|
||||
// UpdateMaxVersion 修改版本号
|
||||
func (this *IPListDB) UpdateMaxVersion(version int64) error {
|
||||
if this.isClosed {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := this.updateVersionStmt.Exec(version)
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *IPListDB) Close() error {
|
||||
this.isClosed = true
|
||||
|
||||
if this.db != nil {
|
||||
_ = this.deleteExpiredItemsStmt.Close()
|
||||
_ = this.deleteItemStmt.Close()
|
||||
_ = this.insertItemStmt.Close()
|
||||
_ = this.selectItemsStmt.Close()
|
||||
_ = this.selectMaxVersionStmt.Close()
|
||||
for _, stmt := range []*dbs.Stmt{
|
||||
this.deleteExpiredItemsStmt,
|
||||
this.deleteItemStmt,
|
||||
this.insertItemStmt,
|
||||
this.selectItemsStmt,
|
||||
this.selectMaxItemVersionStmt, // ipItems table
|
||||
|
||||
this.selectVersionStmt, // versions table
|
||||
this.updateVersionStmt,
|
||||
} {
|
||||
if stmt != nil {
|
||||
_ = stmt.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return this.db.Close()
|
||||
}
|
||||
|
||||
@@ -79,3 +79,15 @@ func TestIPListDB_ReadMaxVersion(t *testing.T) {
|
||||
}
|
||||
t.Log(db.ReadMaxVersion())
|
||||
}
|
||||
|
||||
func TestIPListDB_UpdateMaxVersion(t *testing.T) {
|
||||
db, err := iplibrary.NewIPListDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.UpdateMaxVersion(1027)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(db.ReadMaxVersion())
|
||||
}
|
||||
|
||||
@@ -10,50 +10,54 @@ import (
|
||||
|
||||
// AllowIP 检查IP是否被允许访问
|
||||
// 如果一个IP不在任何名单中,则允许访问
|
||||
func AllowIP(ip string, serverId int64) (canGoNext bool, inAllowList bool) {
|
||||
func AllowIP(ip string, serverId int64) (canGoNext bool, inAllowList bool, expiresAt int64) {
|
||||
if !Tea.IsTesting() { // 如果在测试环境,我们不加入一些白名单,以便于可以在本地和局域网正常测试
|
||||
// 放行lo
|
||||
if ip == "127.0.0.1" || ip == "::1" {
|
||||
return true, true
|
||||
return true, true, 0
|
||||
}
|
||||
|
||||
// check node
|
||||
nodeConfig, err := nodeconfigs.SharedNodeConfig()
|
||||
if err == nil && nodeConfig.IPIsAutoAllowed(ip) {
|
||||
return true, true
|
||||
return true, true, 0
|
||||
}
|
||||
}
|
||||
|
||||
var ipLong = utils.IP2Long(ip)
|
||||
if ipLong == 0 {
|
||||
return false, false
|
||||
return false, false, 0
|
||||
}
|
||||
|
||||
// check white lists
|
||||
if GlobalWhiteIPList.Contains(ipLong) {
|
||||
return true, true
|
||||
return true, true, 0
|
||||
}
|
||||
|
||||
if serverId > 0 {
|
||||
var list = SharedServerListManager.FindWhiteList(serverId, false)
|
||||
if list != nil && list.Contains(ipLong) {
|
||||
return true, true
|
||||
return true, true, 0
|
||||
}
|
||||
}
|
||||
|
||||
// check black lists
|
||||
if GlobalBlackIPList.Contains(ipLong) {
|
||||
return false, false
|
||||
expiresAt, ok := GlobalBlackIPList.ContainsExpires(ipLong)
|
||||
if ok {
|
||||
return false, false, expiresAt
|
||||
}
|
||||
|
||||
if serverId > 0 {
|
||||
var list = SharedServerListManager.FindBlackList(serverId, false)
|
||||
if list != nil && list.Contains(ipLong) {
|
||||
return false, false
|
||||
if list != nil {
|
||||
expiresAt, ok = list.ContainsExpires(ipLong)
|
||||
if ok {
|
||||
return false, false, expiresAt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true, false
|
||||
return true, false, 0
|
||||
}
|
||||
|
||||
// IsInWhiteList 检查IP是否在白名单中
|
||||
@@ -73,7 +77,7 @@ func AllowIPStrings(ipStrings []string, serverId int64) bool {
|
||||
return true
|
||||
}
|
||||
for _, ip := range ipStrings {
|
||||
isAllowed, _ := AllowIP(ip, serverId)
|
||||
isAllowed, _, _ := AllowIP(ip, serverId)
|
||||
if !isAllowed {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -47,17 +47,20 @@ type IPListManager struct {
|
||||
|
||||
db *IPListDB
|
||||
|
||||
version int64
|
||||
pageSize int64
|
||||
lastVersion int64
|
||||
fetchPageSize int64
|
||||
|
||||
listMap map[int64]*IPList
|
||||
locker sync.Mutex
|
||||
|
||||
isFirstTime bool
|
||||
}
|
||||
|
||||
func NewIPListManager() *IPListManager {
|
||||
return &IPListManager{
|
||||
pageSize: 1000,
|
||||
listMap: map[int64]*IPList{},
|
||||
fetchPageSize: 5_000,
|
||||
listMap: map[int64]*IPList{},
|
||||
isFirstTime: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,11 +120,11 @@ func (this *IPListManager) init() {
|
||||
_ = db.DeleteExpiredItems()
|
||||
|
||||
// 本地数据库中最大版本号
|
||||
this.version = db.ReadMaxVersion()
|
||||
this.lastVersion = db.ReadMaxVersion()
|
||||
|
||||
// 从本地数据库中加载
|
||||
var offset int64 = 0
|
||||
var size int64 = 1000
|
||||
var size int64 = 2_000
|
||||
for {
|
||||
items, err := db.ReadItems(offset, size)
|
||||
var l = len(items)
|
||||
@@ -148,6 +151,11 @@ func (this *IPListManager) loop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 第一次同步则打印信息
|
||||
if this.isFirstTime {
|
||||
remotelogs.Println("IP_LIST_MANAGER", "initializing ip items ...")
|
||||
}
|
||||
|
||||
for {
|
||||
hasNext, err := this.fetch()
|
||||
if err != nil {
|
||||
@@ -159,6 +167,12 @@ func (this *IPListManager) loop() error {
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
// 第一次同步则打印信息
|
||||
if this.isFirstTime {
|
||||
this.isFirstTime = false
|
||||
remotelogs.Println("IP_LIST_MANAGER", "finished initializing ip items")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -168,8 +182,8 @@ func (this *IPListManager) fetch() (hasNext bool, err error) {
|
||||
return false, err
|
||||
}
|
||||
itemsResp, err := rpcClient.IPItemRPC.ListIPItemsAfterVersion(rpcClient.Context(), &pb.ListIPItemsAfterVersionRequest{
|
||||
Version: this.version,
|
||||
Size: this.pageSize,
|
||||
Version: this.lastVersion,
|
||||
Size: this.fetchPageSize,
|
||||
})
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
@@ -211,6 +225,7 @@ func (this *IPListManager) DeleteExpiredItems() {
|
||||
}
|
||||
}
|
||||
|
||||
// 处理IP条目
|
||||
func (this *IPListManager) processItems(items []*pb.IPItem, fromRemote bool) {
|
||||
var changedLists = map[*IPList]zero.Zero{}
|
||||
for _, item := range items {
|
||||
@@ -280,8 +295,8 @@ func (this *IPListManager) processItems(items []*pb.IPItem, fromRemote bool) {
|
||||
|
||||
if fromRemote {
|
||||
var latestVersion = items[len(items)-1].Version
|
||||
if latestVersion > this.version {
|
||||
this.version = latestVersion
|
||||
if latestVersion > this.lastVersion {
|
||||
this.lastVersion = latestVersion
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
connutils "github.com/TeaOSLab/EdgeNode/internal/utils/conns"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
@@ -33,6 +35,7 @@ type ClientConn struct {
|
||||
hasRead bool
|
||||
|
||||
isLO bool // 是否为环路
|
||||
isNoStat bool // 是否不统计带宽
|
||||
isInAllowList bool
|
||||
|
||||
hasResetSYNFlood bool
|
||||
@@ -52,15 +55,15 @@ type ClientConn struct {
|
||||
func NewClientConn(rawConn net.Conn, isHTTP bool, isTLS bool, isInAllowList bool) net.Conn {
|
||||
// 是否为环路
|
||||
var remoteAddr = rawConn.RemoteAddr().String()
|
||||
var isLO = strings.HasPrefix(remoteAddr, "127.0.0.1:") || strings.HasPrefix(remoteAddr, "[::1]:")
|
||||
|
||||
var conn = &ClientConn{
|
||||
BaseClientConn: BaseClientConn{rawConn: rawConn},
|
||||
isTLS: isTLS,
|
||||
isHTTP: isHTTP,
|
||||
isLO: isLO,
|
||||
isLO: strings.HasPrefix(remoteAddr, "127.0.0.1:") || strings.HasPrefix(remoteAddr, "[::1]:"),
|
||||
isNoStat: connutils.IsNoStatConn(rawConn.RemoteAddr().String()),
|
||||
isInAllowList: isInAllowList,
|
||||
createdAt: time.Now().Unix(),
|
||||
createdAt: fasttime.Now().Unix(),
|
||||
}
|
||||
|
||||
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
|
||||
@@ -84,7 +87,7 @@ func NewClientConn(rawConn net.Conn, isHTTP bool, isTLS bool, isInAllowList bool
|
||||
|
||||
func (this *ClientConn) Read(b []byte) (n int, err error) {
|
||||
if this.isDebugging {
|
||||
this.lastReadAt = time.Now().Unix()
|
||||
this.lastReadAt = fasttime.Now().Unix()
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
@@ -150,7 +153,7 @@ func (this *ClientConn) Write(b []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
if this.isDebugging {
|
||||
this.lastWriteAt = time.Now().Unix()
|
||||
this.lastWriteAt = fasttime.Now().Unix()
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
@@ -182,14 +185,15 @@ func (this *ClientConn) Write(b []byte) (n int, err error) {
|
||||
if n > 0 {
|
||||
// 统计当前服务带宽
|
||||
if this.serverId > 0 {
|
||||
if !this.isLO || Tea.IsTesting() { // 环路不统计带宽,避免缓存预热等行为产生带宽
|
||||
// TODO 需要加入在serverId绑定之前的带宽
|
||||
if !this.isNoStat || Tea.IsTesting() { // 环路不统计带宽,避免缓存预热等行为产生带宽
|
||||
atomic.AddUint64(&teaconst.OutTrafficBytes, uint64(n))
|
||||
|
||||
var cost = time.Since(before).Seconds()
|
||||
if cost > 1 {
|
||||
stats.SharedBandwidthStatManager.Add(this.userId, this.serverId, int64(float64(n)/cost), int64(n))
|
||||
stats.SharedBandwidthStatManager.AddBandwidth(this.userId, this.serverId, int64(float64(n)/cost), int64(n))
|
||||
} else {
|
||||
stats.SharedBandwidthStatManager.Add(this.userId, this.serverId, int64(n), int64(n))
|
||||
stats.SharedBandwidthStatManager.AddBandwidth(this.userId, this.serverId, int64(n), int64(n))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -287,7 +291,7 @@ func (this *ClientConn) resetSYNFlood() {
|
||||
func (this *ClientConn) increaseSYNFlood(synFloodConfig *firewallconfigs.SYNFloodConfig) {
|
||||
var ip = this.RawIP()
|
||||
if len(ip) > 0 && !iplibrary.IsInWhiteList(ip) && (!synFloodConfig.IgnoreLocal || !utils.IsLocalIP(ip)) {
|
||||
var timestamp = utils.NextMinuteUnixTime()
|
||||
var timestamp = fasttime.Now().UnixNextMinute()
|
||||
var result = ttlcache.SharedCache.IncreaseInt64("SYN_FLOOD:"+ip, 1, timestamp, true)
|
||||
var minAttempts = synFloodConfig.MinAttempts
|
||||
if minAttempts < 5 {
|
||||
@@ -307,7 +311,7 @@ func (this *ClientConn) increaseSYNFlood(synFloodConfig *firewallconfigs.SYNFloo
|
||||
_ = this.SetLinger(0)
|
||||
_ = this.Close()
|
||||
|
||||
waf.SharedIPBlackList.RecordIP(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip, time.Now().Unix()+int64(timeout), 0, true, 0, 0, "疑似SYN Flood攻击,当前1分钟"+types.String(result)+"次空连接")
|
||||
waf.SharedIPBlackList.RecordIP(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip, fasttime.Now().Unix()+int64(timeout), 0, true, 0, 0, "疑似SYN Flood攻击,当前1分钟"+types.String(result)+"次空连接")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ package nodes
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"net"
|
||||
)
|
||||
|
||||
@@ -48,7 +50,20 @@ func (this *BaseClientConn) Bind(serverId int64, remoteAddr string, maxConnsPerS
|
||||
}
|
||||
|
||||
// SetServerId 设置服务ID
|
||||
func (this *BaseClientConn) SetServerId(serverId int64) {
|
||||
func (this *BaseClientConn) SetServerId(serverId int64) (goNext bool) {
|
||||
goNext = true
|
||||
|
||||
// 检查服务相关IP黑名单
|
||||
if serverId > 0 && len(this.rawIP) > 0 {
|
||||
// 是否在白名单中
|
||||
ok, _, expiresAt := iplibrary.AllowIP(this.rawIP, serverId)
|
||||
if !ok {
|
||||
_ = this.rawConn.Close()
|
||||
firewalls.DropTemporaryTo(this.rawIP, expiresAt)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
this.serverId = serverId
|
||||
|
||||
// 设置包装前连接
|
||||
@@ -61,6 +76,8 @@ func (this *BaseClientConn) SetServerId(serverId int64) {
|
||||
case *ClientConn:
|
||||
conn.SetServerId(serverId)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// ServerId 读取当前连接绑定的服务ID
|
||||
|
||||
@@ -16,7 +16,7 @@ type ClientConnInterface interface {
|
||||
ServerId() int64
|
||||
|
||||
// SetServerId 设置服务ID
|
||||
SetServerId(serverId int64)
|
||||
SetServerId(serverId int64) (goNext bool)
|
||||
|
||||
// SetUserId 设置所属服务的用户ID
|
||||
SetUserId(userId int64)
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClientListener 客户端网络监听
|
||||
@@ -43,25 +42,17 @@ func (this *ClientListener) Accept() (net.Conn, error) {
|
||||
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
var isInAllowList = false
|
||||
if err == nil {
|
||||
canGoNext, inAllowList := iplibrary.AllowIP(ip, 0)
|
||||
canGoNext, inAllowList, expiresAt := iplibrary.AllowIP(ip, 0)
|
||||
isInAllowList = inAllowList
|
||||
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
|
||||
expiresAt, ok := waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
|
||||
if ok {
|
||||
var timeout = expiresAt - time.Now().Unix()
|
||||
if timeout > 0 {
|
||||
if !canGoNext {
|
||||
firewalls.DropTemporaryTo(ip, expiresAt)
|
||||
} else {
|
||||
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
|
||||
var ok = false
|
||||
expiresAt, ok = waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
|
||||
if ok {
|
||||
canGoNext = false
|
||||
|
||||
if timeout > 3600 {
|
||||
timeout = 3600
|
||||
}
|
||||
|
||||
// 使用本地防火墙延长封禁
|
||||
var fw = firewalls.Firewall()
|
||||
if fw != nil && !fw.IsMock() {
|
||||
// 这里 int(int64) 转换的前提是限制了 timeout <= 3600,否则将有整型溢出的风险
|
||||
_ = fw.DropSourceIP(ip, int(timeout), true)
|
||||
}
|
||||
firewalls.DropTemporaryTo(ip, expiresAt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
var sharedHTTPAccessLogQueue = NewHTTPAccessLogQueue()
|
||||
@@ -137,14 +138,23 @@ Loop:
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToValidUTF8 处理访问日志中的非UTF-8字节
|
||||
func (this *HTTPAccessLogQueue) ToValidUTF8(accessLog *pb.HTTPAccessLog) {
|
||||
accessLog.RemoteAddr = utils.ToValidUTF8string(accessLog.RemoteAddr)
|
||||
accessLog.RemoteUser = utils.ToValidUTF8string(accessLog.RemoteUser)
|
||||
accessLog.RequestURI = utils.ToValidUTF8string(accessLog.RequestURI)
|
||||
accessLog.RequestPath = utils.ToValidUTF8string(accessLog.RequestPath)
|
||||
accessLog.RequestFilename = utils.ToValidUTF8string(accessLog.RequestFilename)
|
||||
accessLog.RequestBody = bytes.ToValidUTF8(accessLog.RequestBody, []byte{})
|
||||
accessLog.Host = utils.ToValidUTF8string(accessLog.Host)
|
||||
accessLog.Hostname = utils.ToValidUTF8string(accessLog.Hostname)
|
||||
|
||||
for k, v := range accessLog.SentHeader {
|
||||
if !utf8.ValidString(k) {
|
||||
delete(accessLog.SentHeader, k)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, v := range accessLog.SentHeader {
|
||||
for index, s := range v.Values {
|
||||
v.Values[index] = utils.ToValidUTF8string(s)
|
||||
}
|
||||
@@ -156,15 +166,27 @@ func (this *HTTPAccessLogQueue) ToValidUTF8(accessLog *pb.HTTPAccessLog) {
|
||||
accessLog.ContentType = utils.ToValidUTF8string(accessLog.ContentType)
|
||||
|
||||
for k, c := range accessLog.Cookie {
|
||||
if !utf8.ValidString(k) {
|
||||
delete(accessLog.Cookie, k)
|
||||
continue
|
||||
}
|
||||
accessLog.Cookie[k] = utils.ToValidUTF8string(c)
|
||||
}
|
||||
|
||||
accessLog.Args = utils.ToValidUTF8string(accessLog.Args)
|
||||
accessLog.QueryString = utils.ToValidUTF8string(accessLog.QueryString)
|
||||
|
||||
for _, v := range accessLog.Header {
|
||||
for k, v := range accessLog.Header {
|
||||
if !utf8.ValidString(k) {
|
||||
delete(accessLog.Header, k)
|
||||
continue
|
||||
}
|
||||
for index, s := range v.Values {
|
||||
v.Values[index] = utils.ToValidUTF8string(s)
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range accessLog.Errors {
|
||||
accessLog.Errors[k] = utils.ToValidUTF8string(v)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func TestHTTPAccessLogQueue_Push(t *testing.T) {
|
||||
@@ -135,6 +136,16 @@ func TestHTTPAccessLogQueue_Memory(t *testing.T) {
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
|
||||
func TestUTF8_IsValid(t *testing.T) {
|
||||
t.Log(utf8.ValidString("abc"))
|
||||
|
||||
var noneUTF8Bytes = []byte{}
|
||||
for i := 0; i < 254; i++ {
|
||||
noneUTF8Bytes = append(noneUTF8Bytes, uint8(i))
|
||||
}
|
||||
t.Log(utf8.ValidString(string(noneUTF8Bytes)))
|
||||
}
|
||||
|
||||
func BenchmarkHTTPAccessLogQueue_ToValidUTF8(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
connutils "github.com/TeaOSLab/EdgeNode/internal/utils/conns"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"io"
|
||||
"net"
|
||||
@@ -61,7 +62,12 @@ func NewHTTPCacheTaskManager() *HTTPCacheTaskManager {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.Dial(network, "127.0.0.1:"+port)
|
||||
conn, err := net.Dial(network, "127.0.0.1:"+port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return connutils.NewNoStat(conn), nil
|
||||
},
|
||||
MaxIdleConns: 128,
|
||||
MaxIdleConnsPerHost: 32,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ type HTTPClient struct {
|
||||
func NewHTTPClient(rawClient *http.Client) *HTTPClient {
|
||||
return &HTTPClient{
|
||||
rawClient: rawClient,
|
||||
accessAt: utils.UnixTime(),
|
||||
accessAt: fasttime.Now().Unix(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ func (this *HTTPClient) RawClient() *http.Client {
|
||||
|
||||
// UpdateAccessTime 更新访问时间
|
||||
func (this *HTTPClient) UpdateAccessTime() {
|
||||
this.accessAt = utils.UnixTime()
|
||||
this.accessAt = fasttime.Now().Unix()
|
||||
}
|
||||
|
||||
// AccessTime 获取访问时间
|
||||
|
||||
@@ -11,12 +11,12 @@ func TestHTTPClientPool_Client(t *testing.T) {
|
||||
pool := NewHTTPClientPool()
|
||||
|
||||
{
|
||||
origin := &serverconfigs.OriginConfig{
|
||||
var origin = &serverconfigs.OriginConfig{
|
||||
Id: 1,
|
||||
Version: 2,
|
||||
Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"},
|
||||
}
|
||||
err := origin.Init()
|
||||
err := origin.Init(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func TestHTTPClientPool_cleanClients(t *testing.T) {
|
||||
Version: 2,
|
||||
Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"},
|
||||
}
|
||||
err := origin.Init()
|
||||
err := origin.Init(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -60,17 +60,19 @@ func TestHTTPClientPool_cleanClients(t *testing.T) {
|
||||
func BenchmarkHTTPClientPool_Client(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
origin := &serverconfigs.OriginConfig{
|
||||
var origin = &serverconfigs.OriginConfig{
|
||||
Id: 1,
|
||||
Version: 2,
|
||||
Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"},
|
||||
}
|
||||
err := origin.Init()
|
||||
err := origin.Init(nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
pool := NewHTTPClientPool()
|
||||
b.ResetTimer()
|
||||
|
||||
var pool = NewHTTPClientPool()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = pool.Client(nil, origin, origin.Addr.PickAddress(), nil, false)
|
||||
}
|
||||
|
||||
@@ -403,7 +403,7 @@ func (this *HTTPRequest) doEnd() {
|
||||
attackBytes = this.CalculateSize()
|
||||
}
|
||||
|
||||
stats.SharedTrafficStatManager.Add(this.ReqServer.Id, this.ReqHost, this.writer.SentBodyBytes()+this.writer.SentHeaderBytes(), cachedBytes, 1, countCached, countAttacks, attackBytes, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId())
|
||||
stats.SharedTrafficStatManager.Add(this.ReqServer.UserId, this.ReqServer.Id, this.ReqHost, this.writer.SentBodyBytes()+this.writer.SentHeaderBytes(), cachedBytes, 1, countCached, countAttacks, attackBytes, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId())
|
||||
|
||||
// 指标
|
||||
if metrics.SharedManager.HasHTTPMetrics() {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/compressions"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io"
|
||||
@@ -328,7 +329,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
}
|
||||
|
||||
// 设置cache.age变量
|
||||
var age = strconv.FormatInt(utils.UnixTime()-reader.LastModified(), 10)
|
||||
var age = strconv.FormatInt(fasttime.Now().Unix()-reader.LastModified(), 10)
|
||||
this.varMapping["cache.age"] = age
|
||||
|
||||
if addStatusHeader {
|
||||
|
||||
@@ -21,6 +21,8 @@ func (this *HTTPRequest) doHealthCheck(key string, isHealthCheck *bool) (stop bo
|
||||
}
|
||||
*isHealthCheck = true
|
||||
|
||||
this.web.StatRef = nil
|
||||
|
||||
if !data.GetBool("accessLogIsOn") {
|
||||
this.disableLog = true
|
||||
}
|
||||
|
||||
@@ -25,6 +25,16 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
|
||||
if !u.MatchRequest(this.Format) {
|
||||
continue
|
||||
}
|
||||
|
||||
var status = u.Status
|
||||
if status <= 0 {
|
||||
if searchEngineRegex.MatchString(this.RawReq.UserAgent()) {
|
||||
status = http.StatusMovedPermanently
|
||||
} else {
|
||||
status = http.StatusTemporaryRedirect
|
||||
}
|
||||
}
|
||||
|
||||
if len(u.Type) == 0 || u.Type == serverconfigs.HTTPHostRedirectTypeURL {
|
||||
if u.MatchPrefix { // 匹配前缀
|
||||
if strings.HasPrefix(fullURL, u.BeforeURL) {
|
||||
@@ -38,11 +48,8 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
if u.Status <= 0 {
|
||||
u.Status = http.StatusTemporaryRedirect
|
||||
}
|
||||
this.processResponseHeaders(this.writer.Header(), u.Status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
|
||||
this.processResponseHeaders(this.writer.Header(), status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, status)
|
||||
return true
|
||||
}
|
||||
} else if u.MatchRegexp { // 正则匹配
|
||||
@@ -83,11 +90,8 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
|
||||
}
|
||||
}
|
||||
|
||||
if u.Status <= 0 {
|
||||
u.Status = http.StatusTemporaryRedirect
|
||||
}
|
||||
this.processResponseHeaders(this.writer.Header(), u.Status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
|
||||
this.processResponseHeaders(this.writer.Header(), status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, status)
|
||||
return true
|
||||
} else { // 精准匹配
|
||||
if fullURL == u.RealBeforeURL() {
|
||||
@@ -104,11 +108,8 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
|
||||
}
|
||||
}
|
||||
|
||||
if u.Status <= 0 {
|
||||
u.Status = http.StatusTemporaryRedirect
|
||||
}
|
||||
this.processResponseHeaders(this.writer.Header(), u.Status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
|
||||
this.processResponseHeaders(this.writer.Header(), status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, status)
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -142,10 +143,8 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
|
||||
// 终止匹配
|
||||
return false
|
||||
}
|
||||
if u.Status <= 0 {
|
||||
u.Status = http.StatusTemporaryRedirect
|
||||
}
|
||||
this.processResponseHeaders(this.writer.Header(), u.Status)
|
||||
|
||||
this.processResponseHeaders(this.writer.Header(), status)
|
||||
|
||||
// 参数
|
||||
var qIndex = strings.Index(this.uri, "?")
|
||||
@@ -153,7 +152,7 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
|
||||
afterURL += this.uri[qIndex:]
|
||||
}
|
||||
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, status)
|
||||
return true
|
||||
}
|
||||
} else if u.Type == serverconfigs.HTTPHostRedirectTypePort {
|
||||
@@ -200,11 +199,9 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
|
||||
// 终止匹配
|
||||
return false
|
||||
}
|
||||
if u.Status <= 0 {
|
||||
u.Status = http.StatusTemporaryRedirect
|
||||
}
|
||||
this.processResponseHeaders(this.writer.Header(), u.Status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
|
||||
|
||||
this.processResponseHeaders(this.writer.Header(), status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, status)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
func (this *HTTPRequest) doRequestLimit() (shouldStop bool) {
|
||||
// 是否在全局名单中
|
||||
_, isInAllowedList := iplibrary.AllowIP(this.RemoteAddr(), this.ReqServer.Id)
|
||||
_, isInAllowedList, _ := iplibrary.AllowIP(this.RemoteAddr(), this.ReqServer.Id)
|
||||
if isInAllowedList {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -35,12 +35,13 @@ func (this *HTTPRequest) doMismatch() {
|
||||
if sharedNodeConfig.GlobalServerConfig != nil && sharedNodeConfig.GlobalServerConfig.HTTPAll.MatchDomainStrictly {
|
||||
// 检查cc
|
||||
// TODO 可以在管理端配置是否开启以及最多尝试次数
|
||||
// 要考虑到服务在切换集群时,域名未生效状态时,用户访问的仍然是老集群中的节点,就会产生找不到域名的情况
|
||||
if len(remoteIP) > 0 {
|
||||
const maxAttempts = 100
|
||||
if ttlcache.SharedCache.IncreaseInt64("MISMATCH_DOMAIN:"+remoteIP, int64(1), time.Now().Unix()+60, false) > maxAttempts {
|
||||
// 在加入之前再次检查黑名单
|
||||
if !waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteIP) {
|
||||
waf.SharedIPBlackList.RecordIP(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteIP, time.Now().Unix()+int64(3600), 0, true, 0, 0, "access mismatch domain '"+this.RawReq.Host+"' too frequently")
|
||||
waf.SharedIPBlackList.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteIP, time.Now().Unix()+3600)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func (this *HTTPRequest) doRedirectToHTTPS(redirectToHTTPSConfig *serverconfigs.HTTPRedirectToHTTPSConfig) (shouldBreak bool) {
|
||||
host := this.RawReq.Host
|
||||
var host = this.RawReq.Host
|
||||
|
||||
// 检查域名是否匹配
|
||||
if !redirectToHTTPSConfig.MatchDomain(host) {
|
||||
@@ -22,7 +22,7 @@ func (this *HTTPRequest) doRedirectToHTTPS(redirectToHTTPSConfig *serverconfigs.
|
||||
host = redirectToHTTPSConfig.Host
|
||||
}
|
||||
} else if redirectToHTTPSConfig.Port > 0 {
|
||||
lastIndex := strings.LastIndex(host, ":")
|
||||
var lastIndex = strings.LastIndex(host, ":")
|
||||
if lastIndex > 0 {
|
||||
host = host[:lastIndex]
|
||||
}
|
||||
@@ -30,18 +30,18 @@ func (this *HTTPRequest) doRedirectToHTTPS(redirectToHTTPSConfig *serverconfigs.
|
||||
host = host + ":" + strconv.Itoa(redirectToHTTPSConfig.Port)
|
||||
}
|
||||
} else {
|
||||
lastIndex := strings.LastIndex(host, ":")
|
||||
var lastIndex = strings.LastIndex(host, ":")
|
||||
if lastIndex > 0 {
|
||||
host = host[:lastIndex]
|
||||
}
|
||||
}
|
||||
|
||||
statusCode := http.StatusMovedPermanently
|
||||
var statusCode = http.StatusMovedPermanently
|
||||
if redirectToHTTPSConfig.Status > 0 {
|
||||
statusCode = redirectToHTTPSConfig.Status
|
||||
}
|
||||
|
||||
newURL := "https://" + host + this.RawReq.RequestURI
|
||||
var newURL = "https://" + host + this.RawReq.RequestURI
|
||||
this.processResponseHeaders(this.writer.Header(), statusCode)
|
||||
http.Redirect(this.writer, this.RawReq, newURL, statusCode)
|
||||
|
||||
|
||||
@@ -21,13 +21,15 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
return
|
||||
}
|
||||
|
||||
var isLowVersionHTTP = this.RawReq.ProtoMajor < 1 /** 0.x **/ || (this.RawReq.ProtoMajor == 1 && this.RawReq.ProtoMinor == 0 /** 1.0 **/)
|
||||
|
||||
var retries = 3
|
||||
|
||||
var failedOriginIds []int64
|
||||
var failedLnNodeIds []int64
|
||||
|
||||
for i := 0; i < retries; i++ {
|
||||
originId, lnNodeId, shouldRetry := this.doOriginRequest(failedOriginIds, failedLnNodeIds, i == 0, i == retries-1)
|
||||
originId, lnNodeId, shouldRetry := this.doOriginRequest(failedOriginIds, failedLnNodeIds, i == 0, i == retries-1, isLowVersionHTTP)
|
||||
if !shouldRetry {
|
||||
break
|
||||
}
|
||||
@@ -41,7 +43,7 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
}
|
||||
|
||||
// 请求源站
|
||||
func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeIds []int64, isFirstTry bool, isLastRetry bool) (originId int64, lnNodeId int64, shouldRetry bool) {
|
||||
func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeIds []int64, isFirstTry bool, isLastRetry bool, isLowVersionHTTP bool) (originId int64, lnNodeId int64, shouldRetry bool) {
|
||||
// 对URL的处理
|
||||
var stripPrefix = this.reverseProxy.StripPrefix
|
||||
var requestURI = this.reverseProxy.RequestURI
|
||||
@@ -321,6 +323,16 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
|
||||
return
|
||||
}
|
||||
|
||||
// 是否为1.1以下
|
||||
if isLowVersionHTTP && resp.ContentLength < 0 {
|
||||
this.writer.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = this.writer.WriteString("The content does not support " + this.RawReq.Proto + " request.")
|
||||
if resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 记录相关数据
|
||||
this.originStatus = int32(resp.StatusCode)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
|
||||
// 统计
|
||||
func (this *HTTPRequest) doStat() {
|
||||
if this.ReqServer == nil {
|
||||
if this.ReqServer == nil || this.web == nil || this.web.StatRef == nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io"
|
||||
@@ -15,7 +15,11 @@ import (
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// 其中的每个括号里的内容都在被引用,不能轻易修改
|
||||
// 搜索引擎和爬虫正则
|
||||
var searchEngineRegex = regexp.MustCompile(`(?i)(60spider|adldxbot|adsbot-google|applebot|admantx|alexa|baidu|bingbot|bingpreview|facebookexternalhit|googlebot|proximic|slurp|sogou|twitterbot|yandex)`)
|
||||
var spiderRegexp = regexp.MustCompile(`(?i)(python|pycurl|http-client|httpclient|apachebench|nethttp|http_request|java|perl|ruby|scrapy|php|rust)`)
|
||||
|
||||
// 内容范围正则,其中的每个括号里的内容都在被引用,不能轻易修改
|
||||
var contentRangeRegexp = regexp.MustCompile(`^bytes (\d+)-(\d+)/(\d+|\*)`)
|
||||
|
||||
// 分解Range
|
||||
@@ -180,7 +184,7 @@ var httpRequestTimestamp int64
|
||||
var httpRequestId int32 = 1_000_000
|
||||
|
||||
func httpRequestNextId() string {
|
||||
unixTime, unixTimeString := utils.UnixTimeMilliString()
|
||||
unixTime, unixTimeString := fasttime.Now().UnixMilliString()
|
||||
if unixTime > httpRequestTimestamp {
|
||||
atomic.StoreInt32(&httpRequestId, 1_000_000)
|
||||
httpRequestTimestamp = unixTime
|
||||
@@ -208,3 +212,13 @@ func httpAcceptEncoding(acceptEncodings string, encoding string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 跳转到某个URL
|
||||
func httpRedirect(writer http.ResponseWriter, req *http.Request, url string, code int) {
|
||||
if len(writer.Header().Get("Content-Type")) == 0 {
|
||||
// 设置Content-Type,是为了让页面不输出链接
|
||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
}
|
||||
|
||||
http.Redirect(writer, req, url, code)
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
|
||||
}
|
||||
|
||||
// 是否在全局名单中
|
||||
canGoNext, isInAllowedList := iplibrary.AllowIP(remoteAddr, this.ReqServer.Id)
|
||||
canGoNext, isInAllowedList, _ := iplibrary.AllowIP(remoteAddr, this.ReqServer.Id)
|
||||
if !canGoNext {
|
||||
this.disableLog = true
|
||||
this.Close()
|
||||
@@ -421,3 +421,8 @@ func (this *HTTPRequest) WAFFingerprint() []byte {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisableAccessLog 在当前请求中不使用访问日志
|
||||
func (this *HTTPRequest) DisableAccessLog() {
|
||||
this.disableLog = true
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/readers"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/writers"
|
||||
_ "github.com/biessek/golang-ico"
|
||||
@@ -299,7 +300,7 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
|
||||
}
|
||||
}
|
||||
|
||||
var expiresAt = utils.UnixTime() + life
|
||||
var expiresAt = fasttime.Now().Unix() + life
|
||||
|
||||
if this.req.isLnRequest {
|
||||
// 返回上级节点过期时间
|
||||
|
||||
@@ -23,7 +23,7 @@ func TestBaseListener_FindServer(t *testing.T) {
|
||||
{Name: types.String(i) + ".hello.com"},
|
||||
},
|
||||
}
|
||||
_ = server.Init()
|
||||
_ = server.Init(nil)
|
||||
listener.Group.Add(server)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,9 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"golang.org/x/net/http2"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
@@ -84,13 +82,7 @@ func (this *HTTPListener) Serve() error {
|
||||
if this.isHTTPS {
|
||||
this.httpServer.TLSConfig = this.buildTLSConfig()
|
||||
|
||||
// support http/2
|
||||
err := http2.ConfigureServer(this.httpServer, nil)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_LISTENER", "configure http2 error: "+err.Error())
|
||||
}
|
||||
|
||||
err = this.httpServer.ServeTLS(this.Listener, "", "")
|
||||
err := this.httpServer.ServeTLS(this.Listener, "", "")
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
return err
|
||||
}
|
||||
@@ -114,6 +106,12 @@ func (this *HTTPListener) Reload(group *serverconfigs.ServerAddressGroup) {
|
||||
|
||||
// ServerHTTP 处理HTTP请求
|
||||
func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.Request) {
|
||||
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
|
||||
if globalServerConfig != nil && !globalServerConfig.HTTPAll.SupportsLowVersionHTTP && (rawReq.ProtoMajor < 1 /** 0.x **/ || (rawReq.ProtoMajor == 1 && rawReq.ProtoMinor == 0 /** 1.0 **/)) {
|
||||
http.Error(rawWriter, rawReq.Proto+" request is not supported.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 不支持Connect
|
||||
if rawReq.Method == http.MethodConnect {
|
||||
http.Error(rawWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
@@ -173,7 +171,10 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
|
||||
if requestConn != nil {
|
||||
clientConn, ok := requestConn.(ClientConnInterface)
|
||||
if ok {
|
||||
clientConn.SetServerId(server.Id)
|
||||
var goNext = clientConn.SetServerId(server.Id)
|
||||
if !goNext {
|
||||
return
|
||||
}
|
||||
clientConn.SetUserId(server.UserId)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,12 +86,6 @@ func (this *ListenerManager) Start(node *nodeconfigs.NodeConfig) error {
|
||||
}**/
|
||||
this.lastConfig = node
|
||||
|
||||
// 初始化
|
||||
err, _ := node.Init()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 所有的新地址
|
||||
groupAddrs := []string{}
|
||||
availableServerGroups := node.AvailableGroups()
|
||||
@@ -124,7 +118,7 @@ func (this *ListenerManager) Start(node *nodeconfigs.NodeConfig) error {
|
||||
addr := group.FullAddr()
|
||||
listener, ok := this.listenersMap[addr]
|
||||
if ok {
|
||||
remotelogs.Println("LISTENER_MANAGER", "reload '"+this.prettyAddress(addr)+"'")
|
||||
// 不需要打印reload信息,防止日志数量过多
|
||||
listener.Reload(group)
|
||||
} else {
|
||||
remotelogs.Println("LISTENER_MANAGER", "listen '"+this.prettyAddress(addr)+"'")
|
||||
|
||||
@@ -75,7 +75,10 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
// 绑定连接和服务
|
||||
clientConn, ok := conn.(ClientConnInterface)
|
||||
if ok {
|
||||
clientConn.SetServerId(server.Id)
|
||||
var goNext = clientConn.SetServerId(server.Id)
|
||||
if !goNext {
|
||||
return nil
|
||||
}
|
||||
clientConn.SetUserId(server.UserId)
|
||||
} else {
|
||||
tlsConn, ok := conn.(*tls.Conn)
|
||||
@@ -84,7 +87,10 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
if internalConn != nil {
|
||||
clientConn, ok = internalConn.(ClientConnInterface)
|
||||
if ok {
|
||||
clientConn.SetServerId(server.Id)
|
||||
var goNext = clientConn.SetServerId(server.Id)
|
||||
if !goNext {
|
||||
return nil
|
||||
}
|
||||
clientConn.SetUserId(server.UserId)
|
||||
}
|
||||
}
|
||||
@@ -114,14 +120,14 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
serverName = tlsConn.ConnectionState().ServerName
|
||||
if len(serverName) > 0 {
|
||||
// 统计
|
||||
stats.SharedTrafficStatManager.Add(server.Id, serverName, 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, serverName, 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
recordStat = true
|
||||
}
|
||||
}
|
||||
|
||||
// 统计
|
||||
if !recordStat {
|
||||
stats.SharedTrafficStatManager.Add(server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
}
|
||||
|
||||
originConn, err := this.connectOrigin(server.Id, serverName, server.ReverseProxy, conn.RemoteAddr().String())
|
||||
@@ -176,7 +182,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
|
||||
// 记录流量
|
||||
if server != nil {
|
||||
stats.SharedTrafficStatManager.Add(server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
|
||||
@@ -370,7 +370,7 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyListener
|
||||
|
||||
// 统计
|
||||
if server != nil {
|
||||
stats.SharedTrafficStatManager.Add(server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
}
|
||||
|
||||
// 处理ControlMessage
|
||||
@@ -401,10 +401,10 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyListener
|
||||
// 记录流量和带宽
|
||||
if server != nil {
|
||||
// 流量
|
||||
stats.SharedTrafficStatManager.Add(server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
|
||||
// 带宽
|
||||
stats.SharedBandwidthStatManager.Add(server.UserId, server.Id, int64(n), int64(n))
|
||||
stats.SharedBandwidthStatManager.AddBandwidth(server.UserId, server.Id, int64(n), int64(n))
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs"
|
||||
@@ -44,6 +43,7 @@ import (
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -55,6 +55,7 @@ var nodeConfigChangedNotify = make(chan bool, 8)
|
||||
var nodeConfigUpdatedAt int64
|
||||
var DaemonIsOn = false
|
||||
var DaemonPid = 0
|
||||
var nodeInstance *Node
|
||||
|
||||
// Node 节点
|
||||
type Node struct {
|
||||
@@ -75,16 +76,18 @@ type Node struct {
|
||||
lastAPINodeVersion int64
|
||||
lastAPINodeAddrs []string // 以前的API节点地址
|
||||
|
||||
lastTaskVersion int64
|
||||
lastTaskVersion int64
|
||||
lastUpdatingServerListId int64
|
||||
}
|
||||
|
||||
func NewNode() *Node {
|
||||
return &Node{
|
||||
nodeInstance = &Node{
|
||||
sock: gosock.NewTmpSock(teaconst.ProcessName),
|
||||
oldMaxThreads: -1,
|
||||
oldMaxCPU: -1,
|
||||
updatingServerMap: map[int64]*serverconfigs.ServerConfig{},
|
||||
}
|
||||
return nodeInstance
|
||||
}
|
||||
|
||||
// Test 检查配置
|
||||
@@ -135,6 +138,9 @@ func (this *Node) Start() {
|
||||
remotelogs.Error("NODE", "initialize ip library failed: "+err.Error())
|
||||
}
|
||||
|
||||
// 调整系统参数
|
||||
this.checkSystem()
|
||||
|
||||
// 检查硬盘类型
|
||||
this.checkDisk()
|
||||
|
||||
@@ -191,7 +197,7 @@ func (this *Node) Start() {
|
||||
}
|
||||
teaconst.NodeId = nodeConfig.Id
|
||||
teaconst.NodeIdString = types.String(teaconst.NodeId)
|
||||
err, serverErrors := nodeConfig.Init()
|
||||
err, serverErrors := nodeConfig.Init(nil)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "init node config failed: "+err.Error())
|
||||
return
|
||||
@@ -304,208 +310,6 @@ func (this *Node) InstallSystemService() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 循环
|
||||
func (this *Node) loop() error {
|
||||
var tr = trackers.Begin("CHECK_NODE_CONFIG_CHANGES")
|
||||
defer tr.End()
|
||||
|
||||
// 检查api.yaml是否存在
|
||||
var apiConfigFile = Tea.ConfigFile("api.yaml")
|
||||
_, err := os.Stat(apiConfigFile)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return errors.New("create rpc client failed: " + err.Error())
|
||||
}
|
||||
|
||||
tasksResp, err := rpcClient.NodeTaskRPC.FindNodeTasks(rpcClient.Context(), &pb.FindNodeTasksRequest{
|
||||
Version: this.lastTaskVersion,
|
||||
})
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) && !Tea.IsTesting() {
|
||||
return nil
|
||||
}
|
||||
return errors.New("read node tasks failed: " + err.Error())
|
||||
}
|
||||
for _, task := range tasksResp.NodeTasks {
|
||||
err := this.execTask(rpcClient, task)
|
||||
if !this.finishTask(task.Id, task.Version, err) {
|
||||
// 防止失败的任务无法重试
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 执行任务
|
||||
func (this *Node) execTask(rpcClient *rpc.RPCClient, task *pb.NodeTask) error {
|
||||
switch task.Type {
|
||||
case "ipItemChanged":
|
||||
// 防止阻塞
|
||||
select {
|
||||
case iplibrary.IPListUpdateNotify <- true:
|
||||
default:
|
||||
|
||||
}
|
||||
case "configChanged":
|
||||
if task.ServerId > 0 {
|
||||
return this.syncServerConfig(task.ServerId)
|
||||
}
|
||||
if !task.IsPrimary {
|
||||
// 我们等等主节点配置准备完毕
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
return this.syncConfig(task.Version)
|
||||
case "nodeVersionChanged":
|
||||
if !sharedUpgradeManager.IsInstalling() {
|
||||
goman.New(func() {
|
||||
sharedUpgradeManager.Start()
|
||||
})
|
||||
}
|
||||
case "scriptsChanged":
|
||||
err := this.reloadCommonScripts()
|
||||
if err != nil {
|
||||
return errors.New("reload common scripts failed: " + err.Error())
|
||||
}
|
||||
case "nodeLevelChanged":
|
||||
levelInfoResp, err := rpcClient.NodeRPC.FindNodeLevelInfo(rpcClient.Context(), &pb.FindNodeLevelInfoRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.Level = levelInfoResp.Level
|
||||
}
|
||||
|
||||
var parentNodes = map[int64][]*nodeconfigs.ParentNodeConfig{}
|
||||
if len(levelInfoResp.ParentNodesMapJSON) > 0 {
|
||||
err = json.Unmarshal(levelInfoResp.ParentNodesMapJSON, &parentNodes)
|
||||
if err != nil {
|
||||
return errors.New("decode level info failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.ParentNodes = parentNodes
|
||||
}
|
||||
case "ddosProtectionChanged":
|
||||
resp, err := rpcClient.NodeRPC.FindNodeDDoSProtection(rpcClient.Context(), &pb.FindNodeDDoSProtectionRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resp.DdosProtectionJSON) == 0 {
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.DDoSProtection = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var ddosProtectionConfig = &ddosconfigs.ProtectionConfig{}
|
||||
err = json.Unmarshal(resp.DdosProtectionJSON, ddosProtectionConfig)
|
||||
if err != nil {
|
||||
return errors.New("decode DDoS protection config failed: " + err.Error())
|
||||
}
|
||||
|
||||
if ddosProtectionConfig != nil && sharedNodeConfig != nil {
|
||||
sharedNodeConfig.DDoSProtection = ddosProtectionConfig
|
||||
}
|
||||
|
||||
err = firewalls.SharedDDoSProtectionManager.Apply(ddosProtectionConfig)
|
||||
if err != nil {
|
||||
// 不阻塞
|
||||
remotelogs.Warn("NODE", "apply DDoS protection failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
case "globalServerConfigChanged":
|
||||
resp, err := rpcClient.NodeRPC.FindNodeGlobalServerConfig(rpcClient.Context(), &pb.FindNodeGlobalServerConfigRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resp.GlobalServerConfigJSON) > 0 {
|
||||
var globalServerConfig = serverconfigs.DefaultGlobalServerConfig()
|
||||
err = json.Unmarshal(resp.GlobalServerConfigJSON, globalServerConfig)
|
||||
if err != nil {
|
||||
return errors.New("decode global server config failed: " + err.Error())
|
||||
}
|
||||
|
||||
if globalServerConfig != nil {
|
||||
err = globalServerConfig.Init()
|
||||
if err != nil {
|
||||
return errors.New("validate global server config failed: " + err.Error())
|
||||
}
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.GlobalServerConfig = globalServerConfig
|
||||
}
|
||||
}
|
||||
}
|
||||
case "userServersStateChanged":
|
||||
if task.UserId > 0 {
|
||||
resp, err := rpcClient.UserRPC.CheckUserServersState(rpcClient.Context(), &pb.CheckUserServersStateRequest{UserId: task.UserId})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
SharedUserManager.UpdateUserServersIsEnabled(task.UserId, resp.IsEnabled)
|
||||
|
||||
if resp.IsEnabled {
|
||||
err = this.syncUserServersConfig(task.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
remotelogs.Error("NODE", "task '"+types.String(task.Id)+"', type '"+task.Type+"' has not been handled")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 标记任务完成
|
||||
func (this *Node) finishTask(taskId int64, taskVersion int64, taskErr error) (success bool) {
|
||||
if taskId <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
remotelogs.Debug("NODE", "create rpc client failed: "+err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
var isOk = taskErr == nil
|
||||
if isOk && taskVersion > this.lastTaskVersion {
|
||||
this.lastTaskVersion = taskVersion
|
||||
}
|
||||
|
||||
var errMsg = ""
|
||||
if taskErr != nil {
|
||||
errMsg = taskErr.Error()
|
||||
}
|
||||
|
||||
_, err = rpcClient.NodeTaskRPC.ReportNodeTaskDone(rpcClient.Context(), &pb.ReportNodeTaskDoneRequest{
|
||||
NodeTaskId: taskId,
|
||||
IsOk: isOk,
|
||||
Error: errMsg,
|
||||
})
|
||||
success = err == nil
|
||||
|
||||
if err != nil {
|
||||
// 连接错误不需要上报到服务中心
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("NODE", "report task done failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("NODE", "report task done failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return success
|
||||
}
|
||||
|
||||
// 读取API配置
|
||||
func (this *Node) syncConfig(taskVersion int64) error {
|
||||
this.locker.Lock()
|
||||
@@ -539,6 +343,7 @@ func (this *Node) syncConfig(taskVersion int64) error {
|
||||
Version: -1, // 更新所有版本
|
||||
Compress: true,
|
||||
NodeTaskVersion: taskVersion,
|
||||
UseDataMap: true,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.New("read config from rpc failed: " + err.Error())
|
||||
@@ -589,7 +394,7 @@ func (this *Node) syncConfig(taskVersion int64) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err, serverErrors := nodeConfig.Init()
|
||||
err, serverErrors := nodeConfig.Init(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -601,9 +406,9 @@ func (this *Node) syncConfig(taskVersion int64) error {
|
||||
|
||||
// 刷新配置
|
||||
if this.isLoaded {
|
||||
remotelogs.Println("NODE", "reloading config ...")
|
||||
remotelogs.Println("NODE", "reloading node config ...")
|
||||
} else {
|
||||
remotelogs.Println("NODE", "loading config ...")
|
||||
remotelogs.Println("NODE", "loading node config ...")
|
||||
}
|
||||
|
||||
this.onReload(nodeConfig, true)
|
||||
@@ -617,6 +422,9 @@ func (this *Node) syncConfig(taskVersion int64) error {
|
||||
|
||||
this.isLoaded = true
|
||||
|
||||
// 整体更新不需要再更新单个服务
|
||||
this.updatingServerMap = map[int64]*serverconfigs.ServerConfig{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -691,7 +499,7 @@ func (this *Node) startSyncTimer() {
|
||||
for {
|
||||
select {
|
||||
case <-taskTicker.C: // 定期执行
|
||||
err := this.loop()
|
||||
err := this.loopTasks()
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "sync config error: "+err.Error())
|
||||
continue
|
||||
@@ -699,7 +507,7 @@ func (this *Node) startSyncTimer() {
|
||||
case <-serverChangeTicker.C: // 服务变化
|
||||
this.reloadServer()
|
||||
case <-nodeTaskNotify: // 有新的更新任务
|
||||
err := this.loop()
|
||||
err := this.loopTasks()
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "sync config error: "+err.Error())
|
||||
continue
|
||||
@@ -1187,6 +995,9 @@ func (this *Node) onReload(config *nodeconfigs.NodeConfig, reloadAll bool) {
|
||||
// API Node地址,这里不限制是否为空,因为在为空时仍然要有对应的处理
|
||||
this.changeAPINodeAddrs(config.APINodeAddrs)
|
||||
}
|
||||
|
||||
// 刷新IP库
|
||||
this.reloadIPLibrary()
|
||||
}
|
||||
|
||||
// reload server config
|
||||
@@ -1194,7 +1005,9 @@ func (this *Node) reloadServer() {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
if len(this.updatingServerMap) > 0 {
|
||||
var countUpdatingServers = len(this.updatingServerMap)
|
||||
const maxPrintServers = 10
|
||||
if countUpdatingServers > 0 {
|
||||
var updatingServerMap = this.updatingServerMap
|
||||
this.updatingServerMap = map[int64]*serverconfigs.ServerConfig{}
|
||||
newNodeConfig, err := nodeconfigs.CloneNodeConfig(sharedNodeConfig)
|
||||
@@ -1204,13 +1017,23 @@ func (this *Node) reloadServer() {
|
||||
}
|
||||
for serverId, serverConfig := range updatingServerMap {
|
||||
if serverConfig != nil {
|
||||
if countUpdatingServers < maxPrintServers {
|
||||
remotelogs.Debug("NODE", "load server '"+types.String(serverId)+"'")
|
||||
}
|
||||
newNodeConfig.AddServer(serverConfig)
|
||||
} else {
|
||||
if countUpdatingServers < maxPrintServers {
|
||||
remotelogs.Debug("NODE", "remove server '"+types.String(serverId)+"'")
|
||||
}
|
||||
newNodeConfig.RemoveServer(serverId)
|
||||
}
|
||||
}
|
||||
|
||||
err, serverErrors := newNodeConfig.Init()
|
||||
if countUpdatingServers >= maxPrintServers {
|
||||
remotelogs.Debug("NODE", "reload "+types.String(countUpdatingServers)+" servers")
|
||||
}
|
||||
|
||||
err, serverErrors := newNodeConfig.Init(nil)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "apply server config error: "+err.Error())
|
||||
return
|
||||
@@ -1230,6 +1053,56 @@ func (this *Node) reloadServer() {
|
||||
}
|
||||
}
|
||||
|
||||
// 检查系统
|
||||
func (this *Node) checkSystem() {
|
||||
if runtime.GOOS != "linux" || os.Getgid() != 0 {
|
||||
return
|
||||
}
|
||||
|
||||
type variable struct {
|
||||
name string
|
||||
minValue int
|
||||
maxValue int
|
||||
}
|
||||
|
||||
const dir = "/proc/sys"
|
||||
|
||||
for _, v := range []variable{
|
||||
{name: "net.core.somaxconn", minValue: 2048},
|
||||
{name: "net.ipv4.tcp_max_syn_backlog", minValue: 2048},
|
||||
{name: "net.core.netdev_max_backlog", minValue: 4096},
|
||||
{name: "net.ipv4.tcp_fin_timeout", maxValue: 10},
|
||||
{name: "net.ipv4.tcp_max_tw_buckets", minValue: 65535},
|
||||
{name: "net.core.rmem_default", minValue: 4 << 20},
|
||||
{name: "net.core.wmem_default", minValue: 4 << 20},
|
||||
{name: "net.core.rmem_max", minValue: 32 << 20},
|
||||
{name: "net.core.wmem_max", minValue: 32 << 20},
|
||||
} {
|
||||
var path = dir + "/" + strings.Replace(v.name, ".", "/", -1)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
data = bytes.TrimSpace(data)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var oldValue = types.Int(string(data))
|
||||
if v.minValue > 0 && oldValue < v.minValue {
|
||||
err = os.WriteFile(path, []byte(types.String(v.minValue)), 0666)
|
||||
if err == nil {
|
||||
remotelogs.Println("NODE", "change kernel parameter '"+v.name+"' from '"+types.String(oldValue)+"' to '"+types.String(v.minValue)+"'")
|
||||
}
|
||||
} else if v.maxValue > 0 && oldValue > v.maxValue {
|
||||
err = os.WriteFile(path, []byte(types.String(v.maxValue)), 0666)
|
||||
if err == nil {
|
||||
remotelogs.Println("NODE", "change kernel parameter '"+v.name+"' from '"+types.String(oldValue)+"' to '"+types.String(v.maxValue)+"'")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查硬盘
|
||||
func (this *Node) checkDisk() {
|
||||
if runtime.GOOS != "linux" {
|
||||
|
||||
@@ -7,3 +7,11 @@ package nodes
|
||||
func (this *Node) reloadCommonScripts() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Node) reloadIPLibrary() {
|
||||
|
||||
}
|
||||
|
||||
func (this *Node) notifyPlusChange() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func (this *NodeStatusExecutor) updateMem(status *nodeconfigs.NodeStatus) {
|
||||
if minFreeMemory > 1<<30 {
|
||||
minFreeMemory = 1 << 30
|
||||
}
|
||||
if stat.Free < minFreeMemory {
|
||||
if stat.Available > 0 && stat.Available < minFreeMemory {
|
||||
runtime.GC()
|
||||
debug.FreeOSMemory()
|
||||
}
|
||||
|
||||
355
internal/nodes/node_tasks.go
Normal file
355
internal/nodes/node_tasks.go
Normal file
@@ -0,0 +1,355 @@
|
||||
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/trackers"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 循环
|
||||
func (this *Node) loopTasks() error {
|
||||
var tr = trackers.Begin("CHECK_NODE_CONFIG_CHANGES")
|
||||
defer tr.End()
|
||||
|
||||
// 检查api.yaml是否存在
|
||||
var apiConfigFile = Tea.ConfigFile("api.yaml")
|
||||
_, err := os.Stat(apiConfigFile)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return errors.New("create rpc client failed: " + err.Error())
|
||||
}
|
||||
|
||||
tasksResp, err := rpcClient.NodeTaskRPC.FindNodeTasks(rpcClient.Context(), &pb.FindNodeTasksRequest{
|
||||
Version: this.lastTaskVersion,
|
||||
})
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) && !Tea.IsTesting() {
|
||||
return nil
|
||||
}
|
||||
return errors.New("read node tasks failed: " + err.Error())
|
||||
}
|
||||
for _, task := range tasksResp.NodeTasks {
|
||||
err := this.execTask(rpcClient, task)
|
||||
if !this.finishTask(task.Id, task.Version, err) {
|
||||
// 防止失败的任务无法重试
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 执行任务
|
||||
func (this *Node) execTask(rpcClient *rpc.RPCClient, task *pb.NodeTask) error {
|
||||
var err error
|
||||
switch task.Type {
|
||||
case "ipItemChanged":
|
||||
err = this.execIPItemChangedTask()
|
||||
case "configChanged":
|
||||
err = this.execConfigChangedTask(task)
|
||||
case "nodeVersionChanged":
|
||||
err = this.execNodeVersionChangedTask()
|
||||
case "scriptsChanged":
|
||||
err = this.execScriptsChangedTask()
|
||||
case "nodeLevelChanged":
|
||||
err = this.execNodeLevelChangedTask(rpcClient)
|
||||
case "ddosProtectionChanged":
|
||||
err = this.execDDoSProtectionChangedTask(rpcClient)
|
||||
case "globalServerConfigChanged":
|
||||
err = this.execGlobalServerConfigChangedTask(rpcClient)
|
||||
case "userServersStateChanged":
|
||||
err = this.execUserServersStateChangedTask(rpcClient, task)
|
||||
case "uamPolicyChanged":
|
||||
err = this.execUAMPolicyChangedTask(rpcClient)
|
||||
case "updatingServers":
|
||||
err = this.execUpdatingServersTask(rpcClient)
|
||||
case "plusChanged":
|
||||
err = this.notifyPlusChange()
|
||||
default:
|
||||
remotelogs.Error("NODE", "task '"+types.String(task.Id)+"', type '"+task.Type+"' has not been handled")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新IP条目变更
|
||||
func (this *Node) execIPItemChangedTask() error {
|
||||
// 防止阻塞
|
||||
select {
|
||||
case iplibrary.IPListUpdateNotify <- true:
|
||||
default:
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 更新节点配置变更
|
||||
func (this *Node) execConfigChangedTask(task *pb.NodeTask) error {
|
||||
if task.ServerId > 0 {
|
||||
return this.syncServerConfig(task.ServerId)
|
||||
}
|
||||
if !task.IsPrimary {
|
||||
// 我们等等主节点配置准备完毕
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
return this.syncConfig(task.Version)
|
||||
}
|
||||
|
||||
// 节点程序版本号变更
|
||||
func (this *Node) execNodeVersionChangedTask() error {
|
||||
if !sharedUpgradeManager.IsInstalling() {
|
||||
goman.New(func() {
|
||||
sharedUpgradeManager.Start()
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 脚本库变更
|
||||
func (this *Node) execScriptsChangedTask() error {
|
||||
err := this.reloadCommonScripts()
|
||||
if err != nil {
|
||||
return errors.New("reload common scripts failed: " + err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 节点级别变更
|
||||
func (this *Node) execNodeLevelChangedTask(rpcClient *rpc.RPCClient) error {
|
||||
levelInfoResp, err := rpcClient.NodeRPC.FindNodeLevelInfo(rpcClient.Context(), &pb.FindNodeLevelInfoRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.Level = levelInfoResp.Level
|
||||
}
|
||||
|
||||
var parentNodes = map[int64][]*nodeconfigs.ParentNodeConfig{}
|
||||
if len(levelInfoResp.ParentNodesMapJSON) > 0 {
|
||||
err = json.Unmarshal(levelInfoResp.ParentNodesMapJSON, &parentNodes)
|
||||
if err != nil {
|
||||
return errors.New("decode level info failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.ParentNodes = parentNodes
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UAM策略变更
|
||||
func (this *Node) execUAMPolicyChangedTask(rpcClient *rpc.RPCClient) error {
|
||||
remotelogs.Println("NODE", "updating uam policies ...")
|
||||
resp, err := rpcClient.NodeRPC.FindNodeUAMPolicies(rpcClient.Context(), &pb.FindNodeUAMPoliciesRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var uamPolicyMap = map[int64]*nodeconfigs.UAMPolicy{}
|
||||
for _, policy := range resp.UamPolicies {
|
||||
if len(policy.UamPolicyJSON) > 0 {
|
||||
var uamPolicy = &nodeconfigs.UAMPolicy{}
|
||||
err = json.Unmarshal(policy.UamPolicyJSON, uamPolicy)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "decode uam policy failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
err = uamPolicy.Init()
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "initialize uam policy failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
uamPolicyMap[policy.NodeClusterId] = uamPolicy
|
||||
}
|
||||
}
|
||||
sharedNodeConfig.UpdateUAMPolicies(uamPolicyMap)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DDoS配置变更
|
||||
func (this *Node) execDDoSProtectionChangedTask(rpcClient *rpc.RPCClient) error {
|
||||
resp, err := rpcClient.NodeRPC.FindNodeDDoSProtection(rpcClient.Context(), &pb.FindNodeDDoSProtectionRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resp.DdosProtectionJSON) == 0 {
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.DDoSProtection = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var ddosProtectionConfig = &ddosconfigs.ProtectionConfig{}
|
||||
err = json.Unmarshal(resp.DdosProtectionJSON, ddosProtectionConfig)
|
||||
if err != nil {
|
||||
return errors.New("decode DDoS protection config failed: " + err.Error())
|
||||
}
|
||||
|
||||
if ddosProtectionConfig != nil && sharedNodeConfig != nil {
|
||||
sharedNodeConfig.DDoSProtection = ddosProtectionConfig
|
||||
}
|
||||
|
||||
go func() {
|
||||
err = firewalls.SharedDDoSProtectionManager.Apply(ddosProtectionConfig)
|
||||
if err != nil {
|
||||
// 不阻塞
|
||||
remotelogs.Warn("NODE", "apply DDoS protection failed: "+err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 服务全局配置变更
|
||||
func (this *Node) execGlobalServerConfigChangedTask(rpcClient *rpc.RPCClient) error {
|
||||
resp, err := rpcClient.NodeRPC.FindNodeGlobalServerConfig(rpcClient.Context(), &pb.FindNodeGlobalServerConfigRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resp.GlobalServerConfigJSON) > 0 {
|
||||
var globalServerConfig = serverconfigs.DefaultGlobalServerConfig()
|
||||
err = json.Unmarshal(resp.GlobalServerConfigJSON, globalServerConfig)
|
||||
if err != nil {
|
||||
return errors.New("decode global server config failed: " + err.Error())
|
||||
}
|
||||
|
||||
if globalServerConfig != nil {
|
||||
err = globalServerConfig.Init()
|
||||
if err != nil {
|
||||
return errors.New("validate global server config failed: " + err.Error())
|
||||
}
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.GlobalServerConfig = globalServerConfig
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 单个用户服务状态变更
|
||||
func (this *Node) execUserServersStateChangedTask(rpcClient *rpc.RPCClient, task *pb.NodeTask) error {
|
||||
if task.UserId > 0 {
|
||||
resp, err := rpcClient.UserRPC.CheckUserServersState(rpcClient.Context(), &pb.CheckUserServersStateRequest{UserId: task.UserId})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
SharedUserManager.UpdateUserServersIsEnabled(task.UserId, resp.IsEnabled)
|
||||
|
||||
if resp.IsEnabled {
|
||||
err = this.syncUserServersConfig(task.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 更新一组服务列表
|
||||
func (this *Node) execUpdatingServersTask(rpcClient *rpc.RPCClient) error {
|
||||
if this.lastUpdatingServerListId <= 0 {
|
||||
this.lastUpdatingServerListId = sharedNodeConfig.UpdatingServerListId
|
||||
}
|
||||
|
||||
resp, err := rpcClient.UpdatingServerListRPC.FindUpdatingServerLists(rpcClient.Context(), &pb.FindUpdatingServerListsRequest{LastId: this.lastUpdatingServerListId})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.MaxId <= 0 || len(resp.ServersJSON) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var serverConfigs = []*serverconfigs.ServerConfig{}
|
||||
err = json.Unmarshal(resp.ServersJSON, &serverConfigs)
|
||||
if err != nil {
|
||||
return errors.New("decode server configs failed: " + err.Error())
|
||||
}
|
||||
|
||||
if resp.MaxId > this.lastUpdatingServerListId {
|
||||
this.lastUpdatingServerListId = resp.MaxId
|
||||
}
|
||||
|
||||
if len(serverConfigs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
for _, serverConfig := range serverConfigs {
|
||||
if serverConfig == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if serverConfig.IsOn {
|
||||
this.updatingServerMap[serverConfig.Id] = serverConfig
|
||||
} else {
|
||||
this.updatingServerMap[serverConfig.Id] = nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 标记任务完成
|
||||
func (this *Node) finishTask(taskId int64, taskVersion int64, taskErr error) (success bool) {
|
||||
if taskId <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
remotelogs.Debug("NODE", "create rpc client failed: "+err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
var isOk = taskErr == nil
|
||||
if isOk && taskVersion > this.lastTaskVersion {
|
||||
this.lastTaskVersion = taskVersion
|
||||
}
|
||||
|
||||
var errMsg = ""
|
||||
if taskErr != nil {
|
||||
errMsg = taskErr.Error()
|
||||
}
|
||||
|
||||
_, err = rpcClient.NodeTaskRPC.ReportNodeTaskDone(rpcClient.Context(), &pb.ReportNodeTaskDoneRequest{
|
||||
NodeTaskId: taskId,
|
||||
IsOk: isOk,
|
||||
Error: errMsg,
|
||||
})
|
||||
success = err == nil
|
||||
|
||||
if err != nil {
|
||||
// 连接错误不需要上报到服务中心
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("NODE", "report task done failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("NODE", "report task done failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return success
|
||||
}
|
||||
@@ -35,7 +35,7 @@ func init() {
|
||||
err := uploadLogs()
|
||||
tr.End()
|
||||
if err != nil {
|
||||
logs.Println("[LOG]" + err.Error())
|
||||
logs.Println("[LOG]upload logs failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -191,7 +191,7 @@ func ServerError(serverId int64, tag string, description string, logType nodecon
|
||||
if len(params) > 0 {
|
||||
p, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
logs.Println("[LOG]" + err.Error())
|
||||
logs.Println("[LOG]ServerError(): json encode failed: " + err.Error())
|
||||
} else {
|
||||
paramsJSON = p
|
||||
}
|
||||
@@ -223,7 +223,7 @@ func ServerSuccess(serverId int64, tag string, description string, logType nodec
|
||||
if len(params) > 0 {
|
||||
p, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
logs.Println("[LOG]" + err.Error())
|
||||
logs.Println("[LOG]ServerSuccess(): json encode failed: " + err.Error())
|
||||
} else {
|
||||
paramsJSON = p
|
||||
}
|
||||
@@ -255,7 +255,7 @@ func ServerLog(serverId int64, tag string, description string, logType nodeconfi
|
||||
if len(params) > 0 {
|
||||
p, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
logs.Println("[LOG]" + err.Error())
|
||||
logs.Println("[LOG]ServerLog(): json encode failed: " + err.Error())
|
||||
} else {
|
||||
paramsJSON = p
|
||||
}
|
||||
|
||||
@@ -51,6 +51,8 @@ type RPCClient struct {
|
||||
ScriptRPC pb.ScriptServiceClient
|
||||
UserRPC pb.UserServiceClient
|
||||
ClientAgentIPRPC pb.ClientAgentIPServiceClient
|
||||
AuthorityKeyRPC pb.AuthorityKeyServiceClient
|
||||
UpdatingServerListRPC pb.UpdatingServerListServiceClient
|
||||
}
|
||||
|
||||
func NewRPCClient(apiConfig *configs.APIConfig) (*RPCClient, error) {
|
||||
@@ -85,6 +87,8 @@ func NewRPCClient(apiConfig *configs.APIConfig) (*RPCClient, error) {
|
||||
client.ScriptRPC = pb.NewScriptServiceClient(client)
|
||||
client.UserRPC = pb.NewUserServiceClient(client)
|
||||
client.ClientAgentIPRPC = pb.NewClientAgentIPServiceClient(client)
|
||||
client.AuthorityKeyRPC = pb.NewAuthorityKeyServiceClient(client)
|
||||
client.UpdatingServerListRPC = pb.NewUpdatingServerListServiceClient(client)
|
||||
|
||||
err := client.init()
|
||||
if err != nil {
|
||||
@@ -231,8 +235,8 @@ func (this *RPCClient) init() error {
|
||||
}
|
||||
var conn *grpc.ClientConn
|
||||
var callOptions = grpc.WithDefaultCallOptions(
|
||||
grpc.MaxCallRecvMsgSize(128*1024*1024),
|
||||
grpc.MaxCallSendMsgSize(128*1024*1024),
|
||||
grpc.MaxCallRecvMsgSize(512<<20),
|
||||
grpc.MaxCallSendMsgSize(512<<20),
|
||||
grpc.UseCompressor(gzip.Name),
|
||||
)
|
||||
if u.Scheme == "http" {
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
@@ -10,9 +11,12 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
timeutil "github.com/iwind/TeaGo/utils/time"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -31,23 +35,38 @@ func init() {
|
||||
SharedBandwidthStatManager.Start()
|
||||
})
|
||||
})
|
||||
|
||||
events.On(events.EventQuit, func() {
|
||||
SharedBandwidthStatManager.Cancel()
|
||||
|
||||
err := SharedBandwidthStatManager.Save()
|
||||
if err != nil {
|
||||
remotelogs.Error("STAT", "save bandwidth stats failed: "+err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type BandwidthStat struct {
|
||||
Day string
|
||||
TimeAt string
|
||||
UserId int64
|
||||
ServerId int64
|
||||
Day string `json:"day"`
|
||||
TimeAt string `json:"timeAt"`
|
||||
UserId int64 `json:"userId"`
|
||||
ServerId int64 `json:"serverId"`
|
||||
|
||||
CurrentBytes int64
|
||||
CurrentTimestamp int64
|
||||
MaxBytes int64
|
||||
TotalBytes int64
|
||||
CurrentBytes int64 `json:"currentBytes"`
|
||||
CurrentTimestamp int64 `json:"currentTimestamp"`
|
||||
MaxBytes int64 `json:"maxBytes"`
|
||||
TotalBytes int64 `json:"totalBytes"`
|
||||
|
||||
CachedBytes int64 `json:"cachedBytes"`
|
||||
AttackBytes int64 `json:"attackBytes"`
|
||||
CountRequests int64 `json:"countRequests"`
|
||||
CountCachedRequests int64 `json:"countCachedRequests"`
|
||||
CountAttackRequests int64 `json:"countAttackRequests"`
|
||||
}
|
||||
|
||||
// BandwidthStatManager 服务带宽统计
|
||||
type BandwidthStatManager struct {
|
||||
m map[string]*BandwidthStat // key => *BandwidthStat
|
||||
m map[string]*BandwidthStat // serverId@day@time => *BandwidthStat
|
||||
|
||||
pbStats []*pb.ServerBandwidthStat
|
||||
|
||||
@@ -55,16 +74,25 @@ type BandwidthStatManager struct {
|
||||
|
||||
ticker *time.Ticker
|
||||
locker sync.Mutex
|
||||
|
||||
cacheFile string // 上一次的缓存文件
|
||||
}
|
||||
|
||||
func NewBandwidthStatManager() *BandwidthStatManager {
|
||||
return &BandwidthStatManager{
|
||||
m: map[string]*BandwidthStat{},
|
||||
ticker: time.NewTicker(1 * time.Minute), // 时间小于1分钟是为了更快速地上传结果
|
||||
m: map[string]*BandwidthStat{},
|
||||
ticker: time.NewTicker(1 * time.Minute), // 时间小于1分钟是为了更快速地上传结果
|
||||
cacheFile: Tea.Root + "/data/bandwidth.dat",
|
||||
}
|
||||
}
|
||||
|
||||
func (this *BandwidthStatManager) Start() {
|
||||
// 从上次数据中恢复
|
||||
this.locker.Lock()
|
||||
this.recover()
|
||||
this.locker.Unlock()
|
||||
|
||||
// 循环上报数据
|
||||
for range this.ticker.C {
|
||||
err := this.Loop()
|
||||
if err != nil && !rpc.IsConnError(err) {
|
||||
@@ -82,7 +110,7 @@ func (this *BandwidthStatManager) Loop() error {
|
||||
|
||||
var now = time.Now()
|
||||
var day = timeutil.Format("Ymd", now)
|
||||
var currentTime = timeutil.FormatTime("Hi", now.Unix()/300*300)
|
||||
var currentTime = timeutil.FormatTime("Hi", now.Unix()/300*300) // 300s = 5 minutes
|
||||
|
||||
if this.lastTime == currentTime {
|
||||
return nil
|
||||
@@ -106,15 +134,28 @@ func (this *BandwidthStatManager) Loop() error {
|
||||
this.locker.Lock()
|
||||
for key, stat := range this.m {
|
||||
if stat.Day < day || stat.TimeAt < currentTime {
|
||||
// 防止数据出现错误
|
||||
if stat.CachedBytes > stat.TotalBytes {
|
||||
stat.CachedBytes = stat.TotalBytes
|
||||
}
|
||||
if stat.AttackBytes > stat.TotalBytes {
|
||||
stat.AttackBytes = stat.TotalBytes
|
||||
}
|
||||
|
||||
pbStats = append(pbStats, &pb.ServerBandwidthStat{
|
||||
Id: 0,
|
||||
UserId: stat.UserId,
|
||||
ServerId: stat.ServerId,
|
||||
Day: stat.Day,
|
||||
TimeAt: stat.TimeAt,
|
||||
Bytes: stat.MaxBytes / bandwidthTimestampDelim,
|
||||
TotalBytes: stat.TotalBytes,
|
||||
NodeRegionId: regionId,
|
||||
Id: 0,
|
||||
UserId: stat.UserId,
|
||||
ServerId: stat.ServerId,
|
||||
Day: stat.Day,
|
||||
TimeAt: stat.TimeAt,
|
||||
Bytes: stat.MaxBytes / bandwidthTimestampDelim,
|
||||
TotalBytes: stat.TotalBytes,
|
||||
CachedBytes: stat.CachedBytes,
|
||||
AttackBytes: stat.AttackBytes,
|
||||
CountRequests: stat.CountRequests,
|
||||
CountCachedRequests: stat.CountCachedRequests,
|
||||
CountAttackRequests: stat.CountAttackRequests,
|
||||
NodeRegionId: regionId,
|
||||
})
|
||||
delete(this.m, key)
|
||||
}
|
||||
@@ -138,16 +179,16 @@ func (this *BandwidthStatManager) Loop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add 添加带宽数据
|
||||
func (this *BandwidthStatManager) Add(userId int64, serverId int64, peekBytes int64, totalBytes int64) {
|
||||
// AddBandwidth 添加带宽数据
|
||||
func (this *BandwidthStatManager) AddBandwidth(userId int64, serverId int64, peekBytes int64, totalBytes int64) {
|
||||
if serverId <= 0 || (peekBytes == 0 && totalBytes == 0) {
|
||||
return
|
||||
}
|
||||
|
||||
var now = time.Now()
|
||||
var now = fasttime.Now()
|
||||
var timestamp = now.Unix() / bandwidthTimestampDelim * bandwidthTimestampDelim // 将时间戳均分成N等份
|
||||
var day = timeutil.Format("Ymd", now)
|
||||
var timeAt = timeutil.FormatTime("Hi", now.Unix()/300*300)
|
||||
var day = now.Ymd()
|
||||
var timeAt = now.Round5Hi()
|
||||
var key = types.String(serverId) + "@" + day + "@" + timeAt
|
||||
|
||||
// 增加TCP Header尺寸,这里默认MTU为1500,且默认为IPv4
|
||||
@@ -188,6 +229,25 @@ func (this *BandwidthStatManager) Add(userId int64, serverId int64, peekBytes in
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
// AddTraffic 添加请求数据
|
||||
func (this *BandwidthStatManager) AddTraffic(serverId int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64) {
|
||||
var now = fasttime.Now()
|
||||
var day = now.Ymd()
|
||||
var timeAt = now.Round5Hi()
|
||||
var key = types.String(serverId) + "@" + day + "@" + timeAt
|
||||
this.locker.Lock()
|
||||
// 只有有记录了才会添加
|
||||
stat, ok := this.m[key]
|
||||
if ok {
|
||||
stat.CachedBytes += cachedBytes
|
||||
stat.CountRequests += countRequests
|
||||
stat.CountCachedRequests += countCachedRequests
|
||||
stat.CountAttackRequests += countAttacks
|
||||
stat.AttackBytes += attackBytes
|
||||
}
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
func (this *BandwidthStatManager) Inspect() {
|
||||
this.locker.Lock()
|
||||
logs.PrintAsJSON(this.m)
|
||||
@@ -205,3 +265,50 @@ func (this *BandwidthStatManager) Map() map[int64]int64 /** serverId => max byte
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Save 保存到本地磁盘
|
||||
func (this *BandwidthStatManager) Save() error {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
data, err := json.Marshal(this.m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_ = os.Remove(this.cacheFile)
|
||||
return os.WriteFile(this.cacheFile, data, 0666)
|
||||
}
|
||||
|
||||
// Cancel 取消上传
|
||||
func (this *BandwidthStatManager) Cancel() {
|
||||
this.ticker.Stop()
|
||||
}
|
||||
|
||||
// 从本地缓存文件中恢复数据
|
||||
func (this *BandwidthStatManager) recover() {
|
||||
cacheData, err := os.ReadFile(this.cacheFile)
|
||||
if err == nil {
|
||||
var m = map[string]*BandwidthStat{}
|
||||
err = json.Unmarshal(cacheData, &m)
|
||||
if err == nil && len(m) > 0 {
|
||||
var lastTime = ""
|
||||
for _, stat := range m {
|
||||
if stat.Day != fasttime.Now().Ymd() {
|
||||
continue
|
||||
}
|
||||
if len(lastTime) == 0 || stat.TimeAt > lastTime {
|
||||
lastTime = stat.TimeAt
|
||||
}
|
||||
}
|
||||
if len(lastTime) > 0 {
|
||||
var availableTime = timeutil.FormatTime("Hi", (time.Now().Unix()-300) /** 只保留5分钟的 **/ /300*300) // 300s = 5 minutes
|
||||
if lastTime >= availableTime {
|
||||
this.m = m
|
||||
this.lastTime = lastTime
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = os.Remove(this.cacheFile)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,31 +3,94 @@
|
||||
package stats_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBandwidthStatManager_Add(t *testing.T) {
|
||||
var manager = stats.NewBandwidthStatManager()
|
||||
manager.Add(1, 1, 10, 10)
|
||||
manager.Add(1, 1, 10, 10)
|
||||
manager.Add(1, 1, 10, 10)
|
||||
manager.AddBandwidth(1, 1, 10, 10)
|
||||
manager.AddBandwidth(1, 1, 10, 10)
|
||||
manager.AddBandwidth(1, 1, 10, 10)
|
||||
time.Sleep(1 * time.Second)
|
||||
manager.Add(1, 1, 85, 85)
|
||||
manager.AddBandwidth(1, 1, 85, 85)
|
||||
time.Sleep(1 * time.Second)
|
||||
manager.Add(1, 1, 25, 25)
|
||||
manager.Add(1, 1, 75, 75)
|
||||
manager.AddBandwidth(1, 1, 25, 25)
|
||||
manager.AddBandwidth(1, 1, 75, 75)
|
||||
manager.Inspect()
|
||||
}
|
||||
|
||||
func TestBandwidthStatManager_Loop(t *testing.T) {
|
||||
var manager = stats.NewBandwidthStatManager()
|
||||
manager.Add(1, 1, 10, 10)
|
||||
manager.Add(1, 1, 10, 10)
|
||||
manager.Add(1, 1, 10, 10)
|
||||
manager.AddBandwidth(1, 1, 10, 10)
|
||||
manager.AddBandwidth(1, 1, 10, 10)
|
||||
manager.AddBandwidth(1, 1, 10, 10)
|
||||
err := manager.Loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBandwidthStatManager_Add(b *testing.B) {
|
||||
var manager = stats.NewBandwidthStatManager()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
var i int
|
||||
for pb.Next() {
|
||||
i++
|
||||
manager.AddBandwidth(1, int64(i%100), 10, 10)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkBandwidthStatManager_Slice(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
var pbStats = []*pb.ServerBandwidthStat{}
|
||||
for j := 0; j < 100; j++ {
|
||||
var stat = &stats.BandwidthStat{}
|
||||
pbStats = append(pbStats, &pb.ServerBandwidthStat{
|
||||
Id: 0,
|
||||
UserId: stat.UserId,
|
||||
ServerId: stat.ServerId,
|
||||
Day: stat.Day,
|
||||
TimeAt: stat.TimeAt,
|
||||
Bytes: stat.MaxBytes / 2,
|
||||
TotalBytes: stat.TotalBytes,
|
||||
CachedBytes: stat.CachedBytes,
|
||||
AttackBytes: stat.AttackBytes,
|
||||
CountRequests: stat.CountRequests,
|
||||
CountCachedRequests: stat.CountCachedRequests,
|
||||
CountAttackRequests: stat.CountAttackRequests,
|
||||
NodeRegionId: 1,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBandwidthStatManager_Slice2(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
var statsSlice = []*stats.BandwidthStat{}
|
||||
for j := 0; j < 100; j++ {
|
||||
var stat = &stats.BandwidthStat{}
|
||||
statsSlice = append(statsSlice, stat)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBandwidthStatManager_Slice3(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
var statsSlice = make([]*stats.BandwidthStat, 2000)
|
||||
for j := 0; j < 100; j++ {
|
||||
var stat = &stats.BandwidthStat{}
|
||||
statsSlice[j] = stat
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +54,9 @@ type HTTPRequestStatManager struct {
|
||||
totalAttackRequests int64
|
||||
|
||||
locker sync.Mutex
|
||||
|
||||
monitorTicker *time.Ticker
|
||||
uploadTicker *time.Ticker
|
||||
}
|
||||
|
||||
// NewHTTPRequestStatManager 获取新对象
|
||||
@@ -77,12 +80,12 @@ func NewHTTPRequestStatManager() *HTTPRequestStatManager {
|
||||
// Start 启动
|
||||
func (this *HTTPRequestStatManager) Start() {
|
||||
// 上传请求总数
|
||||
var monitorTicker = time.NewTicker(1 * time.Minute)
|
||||
this.monitorTicker = time.NewTicker(1 * time.Minute)
|
||||
events.OnKey(events.EventQuit, this, func() {
|
||||
monitorTicker.Stop()
|
||||
this.monitorTicker.Stop()
|
||||
})
|
||||
goman.New(func() {
|
||||
for range monitorTicker.C {
|
||||
for range this.monitorTicker.C {
|
||||
if this.totalAttackRequests > 0 {
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemAttackRequests, maps.Map{"total": this.totalAttackRequests})
|
||||
this.totalAttackRequests = 0
|
||||
@@ -90,19 +93,19 @@ func (this *HTTPRequestStatManager) Start() {
|
||||
}
|
||||
})
|
||||
|
||||
var uploadTicker = time.NewTicker(30 * time.Minute)
|
||||
this.uploadTicker = time.NewTicker(30 * time.Minute)
|
||||
if Tea.IsTesting() {
|
||||
uploadTicker = time.NewTicker(10 * time.Second) // 在测试环境下缩短Ticker时间,以方便我们调试
|
||||
this.uploadTicker = time.NewTicker(10 * time.Second) // 在测试环境下缩短Ticker时间,以方便我们调试
|
||||
}
|
||||
remotelogs.Println("HTTP_REQUEST_STAT_MANAGER", "start ...")
|
||||
events.OnKey(events.EventQuit, this, func() {
|
||||
remotelogs.Println("HTTP_REQUEST_STAT_MANAGER", "quit")
|
||||
uploadTicker.Stop()
|
||||
this.uploadTicker.Stop()
|
||||
})
|
||||
|
||||
// 上传Ticker
|
||||
goman.New(func() {
|
||||
for range uploadTicker.C {
|
||||
for range this.uploadTicker.C {
|
||||
var tr = trackers.Begin("UPLOAD_REQUEST_STATS")
|
||||
err := this.Upload()
|
||||
tr.End()
|
||||
@@ -204,34 +207,38 @@ func (this *HTTPRequestStatManager) Loop() error {
|
||||
if len(pieces) < 4 {
|
||||
return nil
|
||||
}
|
||||
var serverId = pieces[0]
|
||||
var serverIdString = pieces[0]
|
||||
var ip = pieces[1]
|
||||
|
||||
var result = iplib.LookupIP(ip)
|
||||
if result != nil && result.IsOk() {
|
||||
var key = serverId + "@" + types.String(result.CountryId()) + "@" + types.String(result.ProvinceId()) + "@" + types.String(result.CityId())
|
||||
this.locker.Lock()
|
||||
stat, ok := this.cityMap[key]
|
||||
if !ok {
|
||||
// 检查数量
|
||||
if this.serverCityCountMap[key] > 128 { // 限制单个服务的城市数量,防止数量过多
|
||||
this.locker.Unlock()
|
||||
return nil
|
||||
}
|
||||
this.serverCityCountMap[key]++ // 需要放在限制之后,因为使用的是int16
|
||||
if result.CountryId() > 0 {
|
||||
var key = serverIdString + "@" + types.String(result.CountryId()) + "@" + types.String(result.ProvinceId()) + "@" + types.String(result.CityId())
|
||||
stat, ok := this.cityMap[key]
|
||||
if !ok {
|
||||
// 检查数量
|
||||
if this.serverCityCountMap[serverIdString] > 128 { // 限制单个服务的城市数量,防止数量过多
|
||||
this.locker.Unlock()
|
||||
return nil
|
||||
}
|
||||
this.serverCityCountMap[serverIdString]++ // 需要放在限制之后,因为使用的是int16
|
||||
|
||||
stat = &StatItem{}
|
||||
this.cityMap[key] = stat
|
||||
}
|
||||
stat.Bytes += types.Int64(pieces[2])
|
||||
stat.CountRequests++
|
||||
if types.Int8(pieces[3]) == 1 {
|
||||
stat.AttackBytes += types.Int64(pieces[2])
|
||||
stat.CountAttackRequests++
|
||||
stat = &StatItem{}
|
||||
this.cityMap[key] = stat
|
||||
}
|
||||
stat.Bytes += types.Int64(pieces[2])
|
||||
stat.CountRequests++
|
||||
if types.Int8(pieces[3]) == 1 {
|
||||
stat.AttackBytes += types.Int64(pieces[2])
|
||||
stat.CountAttackRequests++
|
||||
}
|
||||
}
|
||||
|
||||
if result.ProviderId() > 0 {
|
||||
this.providerMap[serverId+"@"+types.String(result.ProviderId())]++
|
||||
this.providerMap[serverIdString+"@"+types.String(result.ProviderId())]++
|
||||
} else if utils.IsLocalIP(ip) { // 局域网IP
|
||||
this.providerMap[serverIdString+"@258"]++
|
||||
}
|
||||
this.locker.Unlock()
|
||||
}
|
||||
@@ -240,7 +247,7 @@ func (this *HTTPRequestStatManager) Loop() error {
|
||||
if atIndex < 0 {
|
||||
return nil
|
||||
}
|
||||
var serverId = userAgentString[:atIndex]
|
||||
var serverIdString = userAgentString[:atIndex]
|
||||
var userAgent = userAgentString[atIndex+1:]
|
||||
|
||||
var result = SharedUserAgentParser.Parse(userAgent)
|
||||
@@ -252,11 +259,11 @@ func (this *HTTPRequestStatManager) Loop() error {
|
||||
}
|
||||
this.locker.Lock()
|
||||
|
||||
var systemKey = serverId + "@" + osInfo.Name + "@" + osInfo.Version
|
||||
var systemKey = serverIdString + "@" + osInfo.Name + "@" + osInfo.Version
|
||||
_, ok := this.systemMap[systemKey]
|
||||
if !ok {
|
||||
if this.serverSystemCountMap[serverId] < 128 { // 限制最大数据,防止攻击
|
||||
this.serverSystemCountMap[serverId]++
|
||||
if this.serverSystemCountMap[serverIdString] < 128 { // 限制最大数据,防止攻击
|
||||
this.serverSystemCountMap[serverIdString]++
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
@@ -274,11 +281,11 @@ func (this *HTTPRequestStatManager) Loop() error {
|
||||
}
|
||||
this.locker.Lock()
|
||||
|
||||
var browserKey = serverId + "@" + browser + "@" + browserVersion
|
||||
var browserKey = serverIdString + "@" + browser + "@" + browserVersion
|
||||
_, ok := this.browserMap[browserKey]
|
||||
if !ok {
|
||||
if this.serverBrowserCountMap[serverId] < 256 { // 限制最大数据,防止攻击
|
||||
this.serverBrowserCountMap[serverId]++
|
||||
if this.serverBrowserCountMap[serverIdString] < 256 { // 限制最大数据,防止攻击
|
||||
this.serverBrowserCountMap[serverIdString]++
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
@@ -374,7 +381,7 @@ func (this *HTTPRequestStatManager) Upload() error {
|
||||
sort.Slice(pbCities, func(i, j int) bool {
|
||||
return pbCities[i].CountRequests > pbCities[j].CountRequests
|
||||
})
|
||||
var serverCountMap = map[int64]int16{}
|
||||
var serverCountMap = map[int64]int16{} // serverId => count
|
||||
for _, city := range pbCities {
|
||||
serverCountMap[city.ServerId]++
|
||||
if serverCountMap[city.ServerId] > maxCities {
|
||||
|
||||
@@ -8,10 +8,11 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/monitor"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -21,6 +22,7 @@ import (
|
||||
var SharedTrafficStatManager = NewTrafficStatManager()
|
||||
|
||||
type TrafficItem struct {
|
||||
UserId int64
|
||||
Bytes int64
|
||||
CachedBytes int64
|
||||
CountRequests int64
|
||||
@@ -44,8 +46,8 @@ const trafficStatsMaxLife = 1200 // 最大只保存20分钟内的数据
|
||||
|
||||
// TrafficStatManager 区域流量统计
|
||||
type TrafficStatManager struct {
|
||||
itemMap map[string]*TrafficItem // [timestamp serverId] => *TrafficItem
|
||||
domainsMap map[string]*TrafficItem // timestamp @ serverId @ domain => *TrafficItem
|
||||
itemMap map[string]*TrafficItem // [timestamp serverId] => *TrafficItem
|
||||
domainsMap map[int64]map[string]*TrafficItem // serverIde => { timestamp @ domain => *TrafficItem }
|
||||
|
||||
pbItems []*pb.ServerDailyStat
|
||||
pbDomainItems []*pb.UploadServerDailyStatsRequest_DomainStat
|
||||
@@ -59,7 +61,7 @@ type TrafficStatManager struct {
|
||||
func NewTrafficStatManager() *TrafficStatManager {
|
||||
var manager = &TrafficStatManager{
|
||||
itemMap: map[string]*TrafficItem{},
|
||||
domainsMap: map[string]*TrafficItem{},
|
||||
domainsMap: map[int64]map[string]*TrafficItem{},
|
||||
}
|
||||
|
||||
return manager
|
||||
@@ -106,25 +108,30 @@ func (this *TrafficStatManager) Start() {
|
||||
}
|
||||
|
||||
// Add 添加流量
|
||||
func (this *TrafficStatManager) Add(serverId int64, domain string, bytes int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64, checkingTrafficLimit bool, planId int64) {
|
||||
func (this *TrafficStatManager) Add(userId int64, serverId int64, domain string, bytes int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64, checkingTrafficLimit bool, planId int64) {
|
||||
if serverId == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 添加到带宽
|
||||
SharedBandwidthStatManager.AddTraffic(serverId, cachedBytes, countRequests, countCachedRequests, countAttacks, attackBytes)
|
||||
|
||||
if bytes == 0 && countRequests == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
this.totalRequests++
|
||||
|
||||
var timestamp = utils.FloorUnixTime(300)
|
||||
var timestamp = fasttime.Now().UnixFloor(300)
|
||||
var key = strconv.FormatInt(timestamp, 10) + strconv.FormatInt(serverId, 10)
|
||||
this.locker.Lock()
|
||||
|
||||
// 总的流量
|
||||
item, ok := this.itemMap[key]
|
||||
if !ok {
|
||||
item = &TrafficItem{}
|
||||
item = &TrafficItem{
|
||||
UserId: userId,
|
||||
}
|
||||
this.itemMap[key] = item
|
||||
}
|
||||
item.Bytes += bytes
|
||||
@@ -137,11 +144,17 @@ func (this *TrafficStatManager) Add(serverId int64, domain string, bytes int64,
|
||||
item.PlanId = planId
|
||||
|
||||
// 单个域名流量
|
||||
var domainKey = strconv.FormatInt(timestamp, 10) + "@" + strconv.FormatInt(serverId, 10) + "@" + domain
|
||||
domainItem, ok := this.domainsMap[domainKey]
|
||||
var domainKey = types.String(timestamp) + "@" + domain
|
||||
serverDomainMap, ok := this.domainsMap[serverId]
|
||||
if !ok {
|
||||
serverDomainMap = map[string]*TrafficItem{}
|
||||
this.domainsMap[serverId] = serverDomainMap
|
||||
}
|
||||
|
||||
domainItem, ok := serverDomainMap[domainKey]
|
||||
if !ok {
|
||||
domainItem = &TrafficItem{}
|
||||
this.domainsMap[domainKey] = domainItem
|
||||
serverDomainMap[domainKey] = domainItem
|
||||
}
|
||||
domainItem.Bytes += bytes
|
||||
domainItem.CachedBytes += cachedBytes
|
||||
@@ -173,7 +186,7 @@ func (this *TrafficStatManager) Upload() error {
|
||||
|
||||
// reset
|
||||
this.itemMap = map[string]*TrafficItem{}
|
||||
this.domainsMap = map[string]*TrafficItem{}
|
||||
this.domainsMap = map[int64]map[string]*TrafficItem{}
|
||||
|
||||
this.locker.Unlock()
|
||||
|
||||
@@ -190,6 +203,7 @@ func (this *TrafficStatManager) Upload() error {
|
||||
}
|
||||
|
||||
pbServerStats = append(pbServerStats, &pb.ServerDailyStat{
|
||||
UserId: item.UserId,
|
||||
ServerId: serverId,
|
||||
NodeRegionId: regionId,
|
||||
Bytes: item.Bytes,
|
||||
@@ -205,23 +219,43 @@ func (this *TrafficStatManager) Upload() error {
|
||||
}
|
||||
|
||||
// 域名统计
|
||||
const maxDomainsPerServer = 20
|
||||
var pbDomainStats = []*pb.UploadServerDailyStatsRequest_DomainStat{}
|
||||
for key, item := range domainMap {
|
||||
var pieces = strings.SplitN(key, "@", 3)
|
||||
if len(pieces) != 3 {
|
||||
continue
|
||||
for serverId, serverDomainMap := range domainMap {
|
||||
// 如果超过单个服务最大值,则只取前N个
|
||||
var shouldTrim = len(serverDomainMap) > maxDomainsPerServer
|
||||
var tempItems []*pb.UploadServerDailyStatsRequest_DomainStat
|
||||
|
||||
for key, item := range serverDomainMap {
|
||||
var pieces = strings.SplitN(key, "@", 2)
|
||||
if len(pieces) != 2 {
|
||||
continue
|
||||
}
|
||||
var pbItem = &pb.UploadServerDailyStatsRequest_DomainStat{
|
||||
ServerId: serverId,
|
||||
Domain: pieces[1],
|
||||
Bytes: item.Bytes,
|
||||
CachedBytes: item.CachedBytes,
|
||||
CountRequests: item.CountRequests,
|
||||
CountCachedRequests: item.CountCachedRequests,
|
||||
CountAttackRequests: item.CountAttackRequests,
|
||||
AttackBytes: item.AttackBytes,
|
||||
CreatedAt: types.Int64(pieces[0]),
|
||||
}
|
||||
if !shouldTrim {
|
||||
pbDomainStats = append(pbDomainStats, pbItem)
|
||||
} else {
|
||||
tempItems = append(tempItems, pbItem)
|
||||
}
|
||||
}
|
||||
|
||||
if shouldTrim {
|
||||
sort.Slice(tempItems, func(i, j int) bool {
|
||||
return tempItems[i].CountRequests > tempItems[j].CountRequests
|
||||
})
|
||||
|
||||
pbDomainStats = append(pbDomainStats, tempItems[:maxDomainsPerServer]...)
|
||||
}
|
||||
pbDomainStats = append(pbDomainStats, &pb.UploadServerDailyStatsRequest_DomainStat{
|
||||
ServerId: types.Int64(pieces[1]),
|
||||
Domain: pieces[2],
|
||||
Bytes: item.Bytes,
|
||||
CachedBytes: item.CachedBytes,
|
||||
CountRequests: item.CountRequests,
|
||||
CountCachedRequests: item.CountCachedRequests,
|
||||
CountAttackRequests: item.CountAttackRequests,
|
||||
AttackBytes: item.AttackBytes,
|
||||
CreatedAt: types.Int64(pieces[0]),
|
||||
})
|
||||
}
|
||||
|
||||
// 历史未提交记录
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
@@ -8,7 +11,7 @@ import (
|
||||
func TestTrafficStatManager_Add(t *testing.T) {
|
||||
manager := NewTrafficStatManager()
|
||||
for i := 0; i < 100; i++ {
|
||||
manager.Add(1, "goedge.cn", 1, 0, 0, 0, 0, 0, false, 0)
|
||||
manager.Add(1, 1, "goedge.cn", 1, 0, 0, 0, 0, 0, false, 0)
|
||||
}
|
||||
t.Log(manager.itemMap)
|
||||
}
|
||||
@@ -16,7 +19,7 @@ func TestTrafficStatManager_Add(t *testing.T) {
|
||||
func TestTrafficStatManager_Upload(t *testing.T) {
|
||||
manager := NewTrafficStatManager()
|
||||
for i := 0; i < 100; i++ {
|
||||
manager.Add(1, "goedge.cn", 1, 0, 0, 0, 0, 0, false, 0)
|
||||
manager.Add(1, 1, "goedge.cn"+types.String(rands.Int(0, 10)), 1, 0, 1, 0, 0, 0, false, 0)
|
||||
}
|
||||
err := manager.Upload()
|
||||
if err != nil {
|
||||
@@ -28,8 +31,12 @@ func TestTrafficStatManager_Upload(t *testing.T) {
|
||||
func BenchmarkTrafficStatManager_Add(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
manager := NewTrafficStatManager()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.Add(1, "goedge.cn", 1024, 1, 0, 0, 0, 0, false, 0)
|
||||
}
|
||||
var manager = NewTrafficStatManager()
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
manager.Add(1, 1, "goedge.cn"+types.String(rand.Int63()%10), 1024, 1, 0, 0, 0, 0, false, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package ttlcache
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"time"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
)
|
||||
|
||||
var SharedCache = NewCache()
|
||||
@@ -10,8 +10,10 @@ var SharedCache = NewCache()
|
||||
// Cache TTL缓存
|
||||
// 最大的缓存时间为30 * 86400
|
||||
// Piece数据结构:
|
||||
// Piece1 | Piece2 | Piece3 | ...
|
||||
// [ Item1, Item2, ... ] | ...
|
||||
//
|
||||
// Piece1 | Piece2 | Piece3 | ...
|
||||
// [ Item1, Item2, ... ] | ...
|
||||
//
|
||||
// KeyMap列表数据结构
|
||||
// { timestamp1 => [key1, key2, ...] }, ...
|
||||
type Cache struct {
|
||||
@@ -69,12 +71,12 @@ func NewCache(opt ...OptionInterface) *Cache {
|
||||
return cache
|
||||
}
|
||||
|
||||
func (this *Cache) Write(key string, value interface{}, expiredAt int64) (ok bool) {
|
||||
func (this *Cache) Write(key string, value any, expiredAt int64) (ok bool) {
|
||||
if this.isDestroyed {
|
||||
return
|
||||
}
|
||||
|
||||
var currentTimestamp = utils.UnixTime()
|
||||
var currentTimestamp = fasttime.Now().Unix()
|
||||
if expiredAt <= currentTimestamp {
|
||||
return
|
||||
}
|
||||
@@ -83,8 +85,8 @@ func (this *Cache) Write(key string, value interface{}, expiredAt int64) (ok boo
|
||||
if expiredAt > maxExpiredAt {
|
||||
expiredAt = maxExpiredAt
|
||||
}
|
||||
uint64Key := HashKey([]byte(key))
|
||||
pieceIndex := uint64Key % this.countPieces
|
||||
var uint64Key = HashKey([]byte(key))
|
||||
var pieceIndex = uint64Key % this.countPieces
|
||||
return this.pieces[pieceIndex].Add(uint64Key, &Item{
|
||||
Value: value,
|
||||
expiredAt: expiredAt,
|
||||
@@ -96,22 +98,22 @@ func (this *Cache) IncreaseInt64(key string, delta int64, expiredAt int64, exten
|
||||
return 0
|
||||
}
|
||||
|
||||
currentTimestamp := time.Now().Unix()
|
||||
var currentTimestamp = fasttime.Now().Unix()
|
||||
if expiredAt <= currentTimestamp {
|
||||
return 0
|
||||
}
|
||||
|
||||
maxExpiredAt := currentTimestamp + 30*86400
|
||||
var maxExpiredAt = currentTimestamp + 30*86400
|
||||
if expiredAt > maxExpiredAt {
|
||||
expiredAt = maxExpiredAt
|
||||
}
|
||||
uint64Key := HashKey([]byte(key))
|
||||
pieceIndex := uint64Key % this.countPieces
|
||||
var uint64Key = HashKey([]byte(key))
|
||||
var pieceIndex = uint64Key % this.countPieces
|
||||
return this.pieces[pieceIndex].IncreaseInt64(uint64Key, delta, expiredAt, extend)
|
||||
}
|
||||
|
||||
func (this *Cache) Read(key string) (item *Item) {
|
||||
uint64Key := HashKey([]byte(key))
|
||||
var uint64Key = HashKey([]byte(key))
|
||||
return this.pieces[uint64Key%this.countPieces].Read(uint64Key)
|
||||
}
|
||||
|
||||
@@ -120,7 +122,7 @@ func (this *Cache) readIntKey(key uint64) (value *Item) {
|
||||
}
|
||||
|
||||
func (this *Cache) Delete(key string) {
|
||||
uint64Key := HashKey([]byte(key))
|
||||
var uint64Key = HashKey([]byte(key))
|
||||
this.pieces[uint64Key%this.countPieces].Delete(uint64Key)
|
||||
}
|
||||
|
||||
@@ -137,7 +139,7 @@ func (this *Cache) Count() (count int) {
|
||||
|
||||
func (this *Cache) GC() {
|
||||
this.pieces[this.gcPieceIndex].GC()
|
||||
newIndex := this.gcPieceIndex + 1
|
||||
var newIndex = this.gcPieceIndex + 1
|
||||
if newIndex >= int(this.countPieces) {
|
||||
newIndex = 0
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package ttlcache
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
@@ -195,7 +195,7 @@ func BenchmarkCache_Add(b *testing.B) {
|
||||
|
||||
var cache = NewCache()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Write(strconv.Itoa(i), i, utils.UnixTime()+int64(i%1024))
|
||||
cache.Write(strconv.Itoa(i), i, fasttime.Now().Unix()+int64(i%1024))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -207,7 +207,7 @@ func BenchmarkCache_Add_Parallel(b *testing.B) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
var j = atomic.AddInt64(&i, 1)
|
||||
cache.Write(types.String(j), j, utils.UnixTime()+i%1024)
|
||||
cache.Write(types.String(j), j, fasttime.Now().Unix()+i%1024)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package ttlcache
|
||||
|
||||
type Item struct {
|
||||
Value interface{}
|
||||
Value any
|
||||
expiredAt int64
|
||||
}
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
package ttlcache
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Piece struct {
|
||||
@@ -27,7 +26,7 @@ func NewPiece(maxItems int) *Piece {
|
||||
|
||||
func (this *Piece) Add(key uint64, item *Item) (ok bool) {
|
||||
this.locker.Lock()
|
||||
if len(this.m) >= this.maxItems {
|
||||
if this.maxItems > 0 && len(this.m) >= this.maxItems {
|
||||
this.locker.Unlock()
|
||||
return
|
||||
}
|
||||
@@ -42,7 +41,7 @@ func (this *Piece) Add(key uint64, item *Item) (ok bool) {
|
||||
func (this *Piece) IncreaseInt64(key uint64, delta int64, expiredAt int64, extend bool) (result int64) {
|
||||
this.locker.Lock()
|
||||
item, ok := this.m[key]
|
||||
if ok && item.expiredAt > time.Now().Unix() {
|
||||
if ok && item.expiredAt > fasttime.Now().Unix() {
|
||||
result = types.Int64(item.Value) + delta
|
||||
item.Value = result
|
||||
if extend {
|
||||
@@ -75,7 +74,7 @@ func (this *Piece) Delete(key uint64) {
|
||||
func (this *Piece) Read(key uint64) (item *Item) {
|
||||
this.locker.RLock()
|
||||
item = this.m[key]
|
||||
if item != nil && item.expiredAt < utils.UnixTime() {
|
||||
if item != nil && item.expiredAt < fasttime.Now().Unix() {
|
||||
item = nil
|
||||
}
|
||||
this.locker.RUnlock()
|
||||
@@ -91,7 +90,7 @@ func (this *Piece) Count() (count int) {
|
||||
}
|
||||
|
||||
func (this *Piece) GC() {
|
||||
var currentTime = time.Now().Unix()
|
||||
var currentTime = fasttime.Now().Unix()
|
||||
if this.lastGCTime == 0 {
|
||||
this.lastGCTime = currentTime - 3600
|
||||
}
|
||||
|
||||
72
internal/utils/conns/conn_no_stat.go
Normal file
72
internal/utils/conns/conn_no_stat.go
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package connutils
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/zero"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 记录不需要带宽统计的连接
|
||||
// 比如本地的清理和预热
|
||||
var noStatAddrMap = map[string]zero.Zero{} // addr => Zero
|
||||
var noStatLocker = &sync.RWMutex{}
|
||||
|
||||
// IsNoStatConn 检查是否为不统计连接
|
||||
func IsNoStatConn(addr string) bool {
|
||||
noStatLocker.RLock()
|
||||
_, ok := noStatAddrMap[addr]
|
||||
noStatLocker.RUnlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
type NoStatConn struct {
|
||||
rawConn net.Conn
|
||||
}
|
||||
|
||||
func NewNoStat(rawConn net.Conn) net.Conn {
|
||||
noStatLocker.Lock()
|
||||
noStatAddrMap[rawConn.LocalAddr().String()] = zero.New()
|
||||
noStatLocker.Unlock()
|
||||
return &NoStatConn{rawConn: rawConn}
|
||||
}
|
||||
|
||||
func (this *NoStatConn) Read(b []byte) (n int, err error) {
|
||||
return this.rawConn.Read(b)
|
||||
}
|
||||
|
||||
func (this *NoStatConn) Write(b []byte) (n int, err error) {
|
||||
return this.rawConn.Write(b)
|
||||
}
|
||||
|
||||
func (this *NoStatConn) Close() error {
|
||||
err := this.rawConn.Close()
|
||||
|
||||
noStatLocker.Lock()
|
||||
delete(noStatAddrMap, this.rawConn.LocalAddr().String())
|
||||
noStatLocker.Unlock()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *NoStatConn) LocalAddr() net.Addr {
|
||||
return this.rawConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (this *NoStatConn) RemoteAddr() net.Addr {
|
||||
return this.rawConn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (this *NoStatConn) SetDeadline(t time.Time) error {
|
||||
return this.rawConn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (this *NoStatConn) SetReadDeadline(t time.Time) error {
|
||||
return this.rawConn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (this *NoStatConn) SetWriteDeadline(t time.Time) error {
|
||||
return this.rawConn.SetWriteDeadline(t)
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package expires
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
timeutil "github.com/iwind/TeaGo/utils/time"
|
||||
@@ -197,7 +197,7 @@ func BenchmarkList_GC(b *testing.B) {
|
||||
for m := 0; m < 1_000; m++ {
|
||||
var list = NewList()
|
||||
for j := 0; j < 10_000; j++ {
|
||||
list.Add(uint64(j), utils.UnixTime()+100)
|
||||
list.Add(uint64(j), fasttime.Now().Unix()+100)
|
||||
}
|
||||
lists = append(lists, list)
|
||||
}
|
||||
|
||||
93
internal/utils/fasttime/time_fast.go
Normal file
93
internal/utils/fasttime/time_fast.go
Normal file
@@ -0,0 +1,93 @@
|
||||
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package fasttime
|
||||
|
||||
import (
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
timeutil "github.com/iwind/TeaGo/utils/time"
|
||||
"time"
|
||||
)
|
||||
|
||||
var sharedFastTime = NewFastTime()
|
||||
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
return
|
||||
}
|
||||
|
||||
var ticker = time.NewTicker(200 * time.Millisecond)
|
||||
goman.New(func() {
|
||||
for range ticker.C {
|
||||
sharedFastTime = NewFastTime()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Now() *FastTime {
|
||||
return sharedFastTime
|
||||
}
|
||||
|
||||
type FastTime struct {
|
||||
rawTime time.Time
|
||||
unixTime int64
|
||||
unixTimeMilli int64
|
||||
unixTimeMilliString string
|
||||
ymd string
|
||||
round5Hi string
|
||||
}
|
||||
|
||||
func NewFastTime() *FastTime {
|
||||
var rawTime = time.Now()
|
||||
|
||||
return &FastTime{
|
||||
rawTime: rawTime,
|
||||
unixTime: rawTime.Unix(),
|
||||
unixTimeMilli: rawTime.UnixMilli(),
|
||||
unixTimeMilliString: types.String(rawTime.UnixMilli()),
|
||||
ymd: timeutil.Format("Ymd", rawTime),
|
||||
round5Hi: timeutil.FormatTime("Hi", rawTime.Unix()/300*300),
|
||||
}
|
||||
}
|
||||
|
||||
// Unix 最快获取时间戳的方式,通常用在不需要特别精确时间戳的场景
|
||||
func (this *FastTime) Unix() int64 {
|
||||
return this.unixTime
|
||||
}
|
||||
|
||||
// UnixFloor 取整
|
||||
func (this *FastTime) UnixFloor(seconds int) int64 {
|
||||
return this.unixTime / int64(seconds) * int64(seconds)
|
||||
}
|
||||
|
||||
// UnixCell 取整并加1
|
||||
func (this *FastTime) UnixCell(seconds int) int64 {
|
||||
return this.unixTime/int64(seconds)*int64(seconds) + int64(seconds)
|
||||
}
|
||||
|
||||
// UnixNextMinute 获取下一分钟开始的时间戳
|
||||
func (this *FastTime) UnixNextMinute() int64 {
|
||||
return this.UnixCell(60)
|
||||
}
|
||||
|
||||
// UnixMilli 获取时间戳,精确到毫秒
|
||||
func (this *FastTime) UnixMilli() int64 {
|
||||
return this.unixTimeMilli
|
||||
}
|
||||
|
||||
func (this *FastTime) UnixMilliString() (int64, string) {
|
||||
return this.unixTimeMilli, this.unixTimeMilliString
|
||||
}
|
||||
|
||||
func (this *FastTime) Ymd() string {
|
||||
return this.ymd
|
||||
}
|
||||
|
||||
func (this *FastTime) Round5Hi() string {
|
||||
return this.round5Hi
|
||||
}
|
||||
|
||||
func (this *FastTime) Format(layout string) string {
|
||||
return timeutil.Format(layout, this.rawTime)
|
||||
}
|
||||
57
internal/utils/fasttime/time_fast_test.go
Normal file
57
internal/utils/fasttime/time_fast_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package fasttime_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
timeutil "github.com/iwind/TeaGo/utils/time"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestFastTime_Unix(t *testing.T) {
|
||||
for i := 0; i < 5; i++ {
|
||||
var now = fasttime.Now()
|
||||
t.Log(now.Unix(), now.UnixMilli(), "real:", time.Now().Unix())
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFastTime_UnixMilli(t *testing.T) {
|
||||
t.Log(fasttime.Now().UnixMilliString())
|
||||
}
|
||||
|
||||
func TestFastTime_UnixFloor(t *testing.T) {
|
||||
var now = fasttime.Now()
|
||||
|
||||
var timestamp = time.Now().Unix()
|
||||
t.Log("floor 60:", timestamp, now.UnixFloor(60), timeutil.FormatTime("Y-m-d H:i:s", now.UnixFloor(60)))
|
||||
t.Log("ceil 60:", timestamp, now.UnixCell(60), timeutil.FormatTime("Y-m-d H:i:s", now.UnixCell(60)))
|
||||
t.Log("floor 300:", timestamp, now.UnixFloor(300), timeutil.FormatTime("Y-m-d H:i:s", now.UnixFloor(300)))
|
||||
t.Log("next minute:", now.UnixNextMinute(), timeutil.FormatTime("Y-m-d H:i:s", now.UnixNextMinute()))
|
||||
t.Log("day:", now.Ymd())
|
||||
t.Log("round 5 minute:", now.Round5Hi())
|
||||
}
|
||||
|
||||
func TestFastTime_Format(t *testing.T) {
|
||||
var now = fasttime.Now()
|
||||
t.Log(now.Format("Y-m-d H:i:s"))
|
||||
}
|
||||
|
||||
func BenchmarkNewFastTime(b *testing.B) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
var now = fasttime.Now()
|
||||
_ = now.Ymd()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkNewFastTime_Raw(b *testing.B) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
var now = time.Now()
|
||||
_ = timeutil.Format("Ymd", now)
|
||||
}
|
||||
})
|
||||
}
|
||||
78
internal/utils/maps/map_fixed.go
Normal file
78
internal/utils/maps/map_fixed.go
Normal file
@@ -0,0 +1,78 @@
|
||||
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package maputils
|
||||
|
||||
import "sync"
|
||||
|
||||
type KeyType interface {
|
||||
string | int | int64 | int32 | uint64 | uint32
|
||||
}
|
||||
|
||||
type ValueType interface {
|
||||
any
|
||||
}
|
||||
|
||||
// FixedMap
|
||||
// TODO 解决已存在元素不能按顺序弹出的问题
|
||||
type FixedMap[KeyT KeyType, ValueT ValueType] struct {
|
||||
m map[KeyT]ValueT
|
||||
keys []KeyT
|
||||
|
||||
maxSize int
|
||||
locker sync.RWMutex
|
||||
}
|
||||
|
||||
func NewFixedMap[KeyT KeyType, ValueT ValueType](maxSize int) *FixedMap[KeyT, ValueT] {
|
||||
return &FixedMap[KeyT, ValueT]{
|
||||
maxSize: maxSize,
|
||||
m: map[KeyT]ValueT{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *FixedMap[KeyT, ValueT]) Put(key KeyT, value ValueT) {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
if this.maxSize <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
_, exists := this.m[key]
|
||||
this.m[key] = value
|
||||
|
||||
if !exists {
|
||||
this.keys = append(this.keys, key)
|
||||
|
||||
if len(this.keys) > this.maxSize {
|
||||
var firstKey = this.keys[0]
|
||||
this.keys = this.keys[1:]
|
||||
delete(this.m, firstKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *FixedMap[KeyT, ValueT]) Get(key KeyT) (value ValueT, ok bool) {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
value, ok = this.m[key]
|
||||
return
|
||||
}
|
||||
|
||||
func (this *FixedMap[KeyT, ValueT]) Has(key KeyT) bool {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
_, ok := this.m[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (this *FixedMap[KeyT, ValueT]) Keys() []KeyT {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
return this.keys
|
||||
}
|
||||
|
||||
func (this *FixedMap[KeyT, ValueT]) RawMap() map[KeyT]ValueT {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
return this.m
|
||||
}
|
||||
26
internal/utils/maps/map_fixed_test.go
Normal file
26
internal/utils/maps/map_fixed_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package maputils_test
|
||||
|
||||
import (
|
||||
maputils "github.com/TeaOSLab/EdgeNode/internal/utils/maps"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewFixedMap(t *testing.T) {
|
||||
var m = maputils.NewFixedMap[string, int](3)
|
||||
m.Put("a", 1)
|
||||
t.Log(m.RawMap())
|
||||
t.Log(m.Get("a"))
|
||||
t.Log(m.Get("b"))
|
||||
|
||||
m.Put("b", 2)
|
||||
m.Put("c", 3)
|
||||
t.Log(m.RawMap(), m.Keys())
|
||||
|
||||
m.Put("d", 4)
|
||||
t.Log(m.RawMap(), m.Keys())
|
||||
|
||||
m.Put("b", 200)
|
||||
t.Log(m.RawMap(), m.Keys())
|
||||
}
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// FixedSet
|
||||
// TODO 解决已存在元素不能按顺序弹出的问题
|
||||
type FixedSet struct {
|
||||
maxSize int
|
||||
locker sync.RWMutex
|
||||
|
||||
@@ -1,60 +1,9 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"time"
|
||||
)
|
||||
|
||||
var unixTime = time.Now().Unix()
|
||||
var unixTimeMilli = time.Now().UnixMilli()
|
||||
var unixTimeMilliString = types.String(unixTimeMilli)
|
||||
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
return
|
||||
}
|
||||
|
||||
var ticker = time.NewTicker(200 * time.Millisecond)
|
||||
goman.New(func() {
|
||||
for range ticker.C {
|
||||
unixTime = time.Now().Unix()
|
||||
unixTimeMilli = time.Now().UnixMilli()
|
||||
unixTimeMilliString = types.String(unixTimeMilli)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// UnixTime 最快获取时间戳的方式,通常用在不需要特别精确时间戳的场景
|
||||
func UnixTime() int64 {
|
||||
return unixTime
|
||||
}
|
||||
|
||||
// FloorUnixTime 取整
|
||||
func FloorUnixTime(seconds int) int64 {
|
||||
return UnixTime() / int64(seconds) * int64(seconds)
|
||||
}
|
||||
|
||||
// CeilUnixTime 取整并加1
|
||||
func CeilUnixTime(seconds int) int64 {
|
||||
return UnixTime()/int64(seconds)*int64(seconds) + int64(seconds)
|
||||
}
|
||||
|
||||
// NextMinuteUnixTime 获取下一分钟开始的时间戳
|
||||
func NextMinuteUnixTime() int64 {
|
||||
return CeilUnixTime(60)
|
||||
}
|
||||
|
||||
// UnixTimeMilli 获取时间戳,精确到毫秒
|
||||
func UnixTimeMilli() int64 {
|
||||
return unixTimeMilli
|
||||
}
|
||||
|
||||
func UnixTimeMilliString() (int64, string) {
|
||||
return unixTimeMilli, unixTimeMilliString
|
||||
}
|
||||
|
||||
// GMTUnixTime 计算GMT时间戳
|
||||
func GMTUnixTime(timestamp int64) int64 {
|
||||
_, offset := time.Now().Zone()
|
||||
|
||||
@@ -1,30 +1,15 @@
|
||||
package utils
|
||||
package utils_test
|
||||
|
||||
import (
|
||||
timeutil "github.com/iwind/TeaGo/utils/time"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestUnixTime(t *testing.T) {
|
||||
for i := 0; i < 5; i++ {
|
||||
t.Log(UnixTime(), "real:", time.Now().Unix())
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGMTUnixTime(t *testing.T) {
|
||||
t.Log(GMTUnixTime(time.Now().Unix()))
|
||||
t.Log(utils.GMTUnixTime(time.Now().Unix()))
|
||||
}
|
||||
|
||||
func TestGMTTime(t *testing.T) {
|
||||
t.Log(GMTTime(time.Now()))
|
||||
}
|
||||
|
||||
func TestFloorUnixTime(t *testing.T) {
|
||||
var timestamp = time.Now().Unix()
|
||||
t.Log("floor 60:", timestamp, FloorUnixTime(60), timeutil.FormatTime("Y-m-d H:i:s", FloorUnixTime(60)))
|
||||
t.Log("ceil 60:", timestamp, CeilUnixTime(60), timeutil.FormatTime("Y-m-d H:i:s", CeilUnixTime(60)))
|
||||
t.Log("floor 300:", timestamp, FloorUnixTime(300), timeutil.FormatTime("Y-m-d H:i:s", FloorUnixTime(300)))
|
||||
t.Log("next minute:", NextMinuteUnixTime())
|
||||
t.Log(utils.GMTTime(time.Now()))
|
||||
}
|
||||
|
||||
@@ -119,13 +119,7 @@ func (this *JSCookieAction) increaseFails(req requests.Request, policyId int64,
|
||||
|
||||
var countFails = ttlcache.SharedCache.IncreaseInt64(key, 1, time.Now().Unix()+300, true)
|
||||
if int(countFails) >= maxFails {
|
||||
var useLocalFirewall = false
|
||||
|
||||
if this.Scope == firewallconfigs.FirewallScopeGlobal {
|
||||
useLocalFirewall = true
|
||||
}
|
||||
|
||||
SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, useLocalFirewall, groupId, setId, "JS_COOKIE验证连续失败超过"+types.String(maxFails)+"次")
|
||||
SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, true, groupId, setId, "JS_COOKIE验证连续失败超过"+types.String(maxFails)+"次")
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ type recordIPTask struct {
|
||||
sourceHTTPFirewallRuleSetId int64
|
||||
}
|
||||
|
||||
var recordIPTaskChan = make(chan *recordIPTask, 1024)
|
||||
var recordIPTaskChan = make(chan *recordIPTask, 2048)
|
||||
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
@@ -45,32 +45,60 @@ func init() {
|
||||
return
|
||||
}
|
||||
|
||||
for task := range recordIPTaskChan {
|
||||
ipType := "ipv4"
|
||||
if strings.Contains(task.ip, ":") {
|
||||
ipType = "ipv6"
|
||||
}
|
||||
var reason = task.reason
|
||||
if len(reason) == 0 {
|
||||
reason = "触发WAF规则自动加入"
|
||||
}
|
||||
_, err = rpcClient.IPItemRPC.CreateIPItem(rpcClient.Context(), &pb.CreateIPItemRequest{
|
||||
IpListId: task.listId,
|
||||
IpFrom: task.ip,
|
||||
IpTo: "",
|
||||
ExpiredAt: task.expiresAt,
|
||||
Reason: reason,
|
||||
Type: ipType,
|
||||
EventLevel: task.level,
|
||||
ServerId: task.serverId,
|
||||
SourceNodeId: teaconst.NodeId,
|
||||
SourceServerId: task.sourceServerId,
|
||||
SourceHTTPFirewallPolicyId: task.sourceHTTPFirewallPolicyId,
|
||||
SourceHTTPFirewallRuleGroupId: task.sourceHTTPFirewallRuleGroupId,
|
||||
SourceHTTPFirewallRuleSetId: task.sourceHTTPFirewallRuleSetId,
|
||||
})
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error())
|
||||
const maxItems = 512 // 每次上传的最大IP数
|
||||
|
||||
for {
|
||||
var pbItemMap = map[string]*pb.CreateIPItemsRequest_IPItem{} // ip => IPItem
|
||||
|
||||
func() {
|
||||
for {
|
||||
select {
|
||||
case task := <-recordIPTaskChan:
|
||||
var ipType = "ipv4"
|
||||
if strings.Contains(task.ip, ":") {
|
||||
ipType = "ipv6"
|
||||
}
|
||||
var reason = task.reason
|
||||
if len(reason) == 0 {
|
||||
reason = "触发WAF规则自动加入"
|
||||
}
|
||||
|
||||
pbItemMap[task.ip] = &pb.CreateIPItemsRequest_IPItem{
|
||||
IpListId: task.listId,
|
||||
IpFrom: task.ip,
|
||||
IpTo: "",
|
||||
ExpiredAt: task.expiresAt,
|
||||
Reason: reason,
|
||||
Type: ipType,
|
||||
EventLevel: task.level,
|
||||
ServerId: task.serverId,
|
||||
SourceNodeId: teaconst.NodeId,
|
||||
SourceServerId: task.sourceServerId,
|
||||
SourceHTTPFirewallPolicyId: task.sourceHTTPFirewallPolicyId,
|
||||
SourceHTTPFirewallRuleGroupId: task.sourceHTTPFirewallRuleGroupId,
|
||||
SourceHTTPFirewallRuleSetId: task.sourceHTTPFirewallRuleSetId,
|
||||
}
|
||||
|
||||
if len(pbItemMap) >= maxItems {
|
||||
return
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if len(pbItemMap) > 0 {
|
||||
var pbItems = []*pb.CreateIPItemsRequest_IPItem{}
|
||||
for _, pbItem := range pbItemMap {
|
||||
pbItems = append(pbItems, pbItem)
|
||||
}
|
||||
_, err = rpcClient.IPItemRPC.CreateIPItems(rpcClient.Context(), &pb.CreateIPItemsRequest{IpItems: pbItems})
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error())
|
||||
}
|
||||
} else {
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -29,13 +29,7 @@ func CaptchaIncreaseFails(req requests.Request, actionConfig *CaptchaAction, pol
|
||||
}
|
||||
var countFails = ttlcache.SharedCache.IncreaseInt64(CaptchaCacheKey(req, pageCode), 1, time.Now().Unix()+300, true)
|
||||
if int(countFails) >= maxFails {
|
||||
var useLocalFirewall = false
|
||||
|
||||
if actionConfig.FailBlockScopeAll {
|
||||
useLocalFirewall = true
|
||||
}
|
||||
|
||||
SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, useLocalFirewall, groupId, setId, "CAPTCHA验证连续失败超过"+types.String(maxFails)+"次")
|
||||
SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, true, groupId, setId, "CAPTCHA验证连续失败超过"+types.String(maxFails)+"次")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/conns"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var SharedIPWhiteList = NewIPList(IPListTypeAllow)
|
||||
@@ -34,6 +34,9 @@ type IPList struct {
|
||||
|
||||
id uint64
|
||||
locker sync.RWMutex
|
||||
|
||||
lastIP string // 加入到 recordIPTaskChan 之前尽可能去重
|
||||
lastTime int64
|
||||
}
|
||||
|
||||
// NewIPList 获取新对象
|
||||
@@ -95,33 +98,35 @@ func (this *IPList) RecordIP(ipType string,
|
||||
this.Add(ipType, scope, serverId, ip, expiresAt)
|
||||
|
||||
if this.listType == IPListTypeDeny {
|
||||
// 加入队列等待上传
|
||||
select {
|
||||
case recordIPTaskChan <- &recordIPTask{
|
||||
ip: ip,
|
||||
listId: firewallconfigs.GlobalListId,
|
||||
expiresAt: expiresAt,
|
||||
level: firewallconfigs.DefaultEventLevel,
|
||||
serverId: serverId,
|
||||
sourceServerId: serverId,
|
||||
sourceHTTPFirewallPolicyId: policyId,
|
||||
sourceHTTPFirewallRuleGroupId: groupId,
|
||||
sourceHTTPFirewallRuleSetId: setId,
|
||||
reason: reason,
|
||||
}:
|
||||
default:
|
||||
|
||||
// 作用域
|
||||
var scopeServerId int64
|
||||
if scope == firewallconfigs.FirewallScopeService {
|
||||
scopeServerId = serverId
|
||||
}
|
||||
|
||||
// 使用本地防火墙
|
||||
if useLocalFirewall {
|
||||
var seconds = expiresAt - time.Now().Unix()
|
||||
if seconds > 0 {
|
||||
// 最大3600,防止误封时间过长
|
||||
if seconds > 3600 {
|
||||
seconds = 3600
|
||||
}
|
||||
_ = firewalls.Firewall().DropSourceIP(ip, int(seconds), true)
|
||||
// 加入队列等待上传
|
||||
if this.lastIP != ip || fasttime.Now().Unix()-this.lastTime > 3 /** 3秒外才允许重复添加 **/ {
|
||||
select {
|
||||
case recordIPTaskChan <- &recordIPTask{
|
||||
ip: ip,
|
||||
listId: firewallconfigs.GlobalListId,
|
||||
expiresAt: expiresAt,
|
||||
level: firewallconfigs.DefaultEventLevel,
|
||||
serverId: scopeServerId,
|
||||
sourceServerId: serverId,
|
||||
sourceHTTPFirewallPolicyId: policyId,
|
||||
sourceHTTPFirewallRuleGroupId: groupId,
|
||||
sourceHTTPFirewallRuleSetId: setId,
|
||||
reason: reason,
|
||||
}:
|
||||
this.lastIP = ip
|
||||
this.lastTime = fasttime.Now().Unix()
|
||||
default:
|
||||
}
|
||||
|
||||
// 使用本地防火墙
|
||||
if useLocalFirewall {
|
||||
firewalls.DropTemporaryTo(ip, expiresAt)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -37,4 +37,7 @@ type Request interface {
|
||||
|
||||
// Format 格式化变量
|
||||
Format(string) string
|
||||
|
||||
// DisableAccessLog 在当前请求中不使用访问日志
|
||||
DisableAccessLog()
|
||||
}
|
||||
|
||||
@@ -76,3 +76,11 @@ func (this *TestRequest) Format(s string) string {
|
||||
func (this *TestRequest) WAFOnAction(action interface{}) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *TestRequest) WAFFingerprint() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *TestRequest) DisableAccessLog() {
|
||||
|
||||
}
|
||||
|
||||
@@ -250,12 +250,14 @@ func (this *WAF) MatchRequest(req requests.Request, writer http.ResponseWriter)
|
||||
// validate captcha
|
||||
var rawPath = req.WAFRaw().URL.Path
|
||||
if rawPath == CaptchaPath {
|
||||
req.DisableAccessLog()
|
||||
captchaValidator.Run(req, writer)
|
||||
return
|
||||
}
|
||||
|
||||
// Get 302验证
|
||||
if rawPath == Get302Path {
|
||||
req.DisableAccessLog()
|
||||
get302Validator.Run(req, writer)
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user