Compare commits
74 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aa60092c20 | ||
|
|
54fc265d24 | ||
|
|
a5ac900784 | ||
|
|
4053f1da32 | ||
|
|
0374ccd8a8 | ||
|
|
1d46c446cf | ||
|
|
54b66805f9 | ||
|
|
f7afcbde92 | ||
|
|
8bec1cf68e | ||
|
|
2cd1bb7f95 | ||
|
|
19e6329a2b | ||
|
|
fce2879567 | ||
|
|
0973765919 | ||
|
|
827679721e | ||
|
|
735279bc7a | ||
|
|
3eb2ed9897 | ||
|
|
3a913d98c7 | ||
|
|
9bfcd79e36 | ||
|
|
a81d610302 | ||
|
|
64b1753c4d | ||
|
|
afcb5c2957 | ||
|
|
7d0b9208a3 | ||
|
|
c0f0ec43bb | ||
|
|
30bd66958c | ||
|
|
fafac1a038 | ||
|
|
f979f9503e | ||
|
|
b233c3cc7a | ||
|
|
597ac936f7 | ||
|
|
a590254eb3 | ||
|
|
0498bcf30c | ||
|
|
59f9b5c724 | ||
|
|
80729935b6 | ||
|
|
4ca57fb99c | ||
|
|
9b35902ad4 | ||
|
|
3b8bd09190 | ||
|
|
71a5bc0652 | ||
|
|
ac6a8c4e85 | ||
|
|
f58a808c3a | ||
|
|
51037be772 | ||
|
|
443ff9aff7 | ||
|
|
57cb00edf0 | ||
|
|
3fb39b479a | ||
|
|
4a1daff143 | ||
|
|
dd1dbd424e | ||
|
|
305cb4b46e | ||
|
|
ef90dce29b | ||
|
|
3cb69f4c71 | ||
|
|
af4cd05df2 | ||
|
|
64e0ae80b7 | ||
|
|
8bba228745 | ||
|
|
8cc06e6707 | ||
|
|
52fdee2eeb | ||
|
|
b5f52dd136 | ||
|
|
abda886de5 | ||
|
|
18f08525b9 | ||
|
|
b2a9a31fe5 | ||
|
|
f578114aeb | ||
|
|
bf4f47fc35 | ||
|
|
0be951742a | ||
|
|
59d3d6ae4b | ||
|
|
d061876f7e | ||
|
|
ddaec82415 | ||
|
|
8afd00f00d | ||
|
|
0b306f0a22 | ||
|
|
b36a36172b | ||
|
|
770278bbbc | ||
|
|
ca24818571 | ||
|
|
cd0af22655 | ||
|
|
96cb8d8af7 | ||
|
|
91face15bf | ||
|
|
6f52df63a5 | ||
|
|
509d81dc66 | ||
|
|
817f2a6f91 | ||
|
|
d74e10c7a8 |
@@ -52,7 +52,6 @@ function build() {
|
||||
cp "$ROOT"/configs/api.template.yaml "$DIST"/configs
|
||||
cp -R "$ROOT"/www "$DIST"/
|
||||
cp -R "$ROOT"/pages "$DIST"/
|
||||
cp -R "$ROOT"/resources "$DIST"/
|
||||
|
||||
# we support TOA on linux/amd64 only
|
||||
if [ "$OS" == "linux" -a "$ARCH" == "amd64" ]
|
||||
|
||||
Binary file not shown.
@@ -318,6 +318,21 @@ func main() {
|
||||
}
|
||||
}
|
||||
})
|
||||
app.On("bandwidth", func() {
|
||||
var sock = gosock.NewTmpSock(teaconst.ProcessName)
|
||||
reply, err := sock.Send(&gosock.Command{Code: "bandwidth"})
|
||||
if err != nil {
|
||||
fmt.Println("[ERROR]" + err.Error())
|
||||
return
|
||||
}
|
||||
var statsMap = maps.NewMap(reply.Params).Get("stats")
|
||||
statsJSON, err := json.MarshalIndent(statsMap, "", " ")
|
||||
if err != nil {
|
||||
fmt.Println("[ERROR]" + err.Error())
|
||||
return
|
||||
}
|
||||
fmt.Println(string(statsJSON))
|
||||
})
|
||||
app.Run(func() {
|
||||
var node = nodes.NewNode()
|
||||
node.Start()
|
||||
|
||||
@@ -41,7 +41,7 @@ func (this *LogWriter) Init() {
|
||||
this.c = make(chan string, 1024)
|
||||
|
||||
// 异步写入文件
|
||||
var maxFileSize = 2 * sizes.G // 文件最大尺寸,超出此尺寸则清空
|
||||
var maxFileSize = 128 * sizes.M // 文件最大尺寸,超出此尺寸则清空
|
||||
if fp != nil {
|
||||
goman.New(func() {
|
||||
var totalSize int64 = 0
|
||||
|
||||
@@ -96,7 +96,7 @@ func (this *FileList) Reset() error {
|
||||
}
|
||||
|
||||
func (this *FileList) Add(hash string, item *Item) error {
|
||||
var db = this.getDB(hash)
|
||||
var db = this.GetDB(hash)
|
||||
|
||||
if !db.IsReady() {
|
||||
return nil
|
||||
@@ -120,12 +120,17 @@ func (this *FileList) Add(hash string, item *Item) error {
|
||||
}
|
||||
|
||||
func (this *FileList) Exist(hash string) (bool, error) {
|
||||
var db = this.getDB(hash)
|
||||
var db = this.GetDB(hash)
|
||||
|
||||
if !db.IsReady() {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 如果Hash列表里不存在,那么必然不存在
|
||||
if !db.hashMap.Exist(hash) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
var item = this.memoryCache.Read(hash)
|
||||
if item != nil {
|
||||
return true, nil
|
||||
@@ -225,7 +230,7 @@ func (this *FileList) PurgeLFU(count int, callback func(hash string) error) erro
|
||||
return err
|
||||
}
|
||||
if notFound {
|
||||
_, err = db.deleteHitByHashStmt.Exec(hash)
|
||||
err = db.DeleteHitAsync(hash)
|
||||
if err != nil {
|
||||
return db.WrapError(err)
|
||||
}
|
||||
@@ -291,13 +296,13 @@ func (this *FileList) Count() (int64, error) {
|
||||
|
||||
// IncreaseHit 增加点击量
|
||||
func (this *FileList) IncreaseHit(hash string) error {
|
||||
var db = this.getDB(hash)
|
||||
var db = this.GetDB(hash)
|
||||
|
||||
if !db.IsReady() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return db.IncreaseHit(hash)
|
||||
return db.IncreaseHitAsync(hash)
|
||||
}
|
||||
|
||||
// OnAdd 添加事件
|
||||
@@ -326,17 +331,23 @@ func (this *FileList) GetDBIndex(hash string) uint64 {
|
||||
return fnv.HashString(hash) % CountFileDB
|
||||
}
|
||||
|
||||
func (this *FileList) getDB(hash string) *FileListDB {
|
||||
func (this *FileList) GetDB(hash string) *FileListDB {
|
||||
return this.dbList[fnv.HashString(hash)%CountFileDB]
|
||||
}
|
||||
|
||||
func (this *FileList) remove(hash string) (notFound bool, err error) {
|
||||
var db = this.getDB(hash)
|
||||
var db = this.GetDB(hash)
|
||||
|
||||
if !db.IsReady() {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// HashMap中不存在,则确定不存在
|
||||
if !db.hashMap.Exist(hash) {
|
||||
return true, nil
|
||||
}
|
||||
defer db.hashMap.Delete(hash)
|
||||
|
||||
// 从缓存中删除
|
||||
this.memoryCache.Delete(hash)
|
||||
|
||||
@@ -364,7 +375,7 @@ func (this *FileList) remove(hash string) (notFound bool, err error) {
|
||||
|
||||
atomic.AddInt64(&this.total, -1)
|
||||
|
||||
_, err = db.deleteHitByHashStmt.Exec(hash)
|
||||
err = db.DeleteHitAsync(hash)
|
||||
if err != nil {
|
||||
return false, db.WrapError(err)
|
||||
}
|
||||
|
||||
@@ -10,8 +10,10 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/dbs"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
timeutil "github.com/iwind/TeaGo/utils/time"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -25,6 +27,8 @@ type FileListDB struct {
|
||||
|
||||
writeBatch *dbs.Batch
|
||||
|
||||
hashMap *FileListHashMap
|
||||
|
||||
itemsTableName string
|
||||
hitsTableName string
|
||||
|
||||
@@ -41,44 +45,67 @@ type FileListDB struct {
|
||||
|
||||
selectByHashStmt *dbs.Stmt // 使用hash查询数据
|
||||
|
||||
selectHashListStmt *dbs.Stmt
|
||||
|
||||
deleteByHashStmt *dbs.Stmt // 根据hash删除数据
|
||||
deleteByHashSQL string
|
||||
|
||||
statStmt *dbs.Stmt // 统计
|
||||
purgeStmt *dbs.Stmt // 清理
|
||||
deleteAllStmt *dbs.Stmt // 删除所有数据
|
||||
listOlderItemsStmt *dbs.Stmt // 读取较早存储的缓存
|
||||
statStmt *dbs.Stmt // 统计
|
||||
purgeStmt *dbs.Stmt // 清理
|
||||
deleteAllStmt *dbs.Stmt // 删除所有数据
|
||||
listOlderItemsStmt *dbs.Stmt // 读取较早存储的缓存
|
||||
updateAccessWeekSQL string // 修改访问日期
|
||||
|
||||
// hits
|
||||
insertHitStmt *dbs.Stmt // 写入数据
|
||||
increaseHitStmt *dbs.Stmt // 增加点击量
|
||||
deleteHitByHashStmt *dbs.Stmt // 根据hash删除数据
|
||||
lfuHitsStmt *dbs.Stmt // 读取老的数据
|
||||
insertHitSQL string // 写入数据
|
||||
increaseHitSQL string // 增加点击量
|
||||
deleteHitByHashSQL string // 根据hash删除数据
|
||||
}
|
||||
|
||||
func NewFileListDB() *FileListDB {
|
||||
return &FileListDB{}
|
||||
return &FileListDB{
|
||||
hashMap: NewFileListHashMap(),
|
||||
}
|
||||
}
|
||||
|
||||
func (this *FileListDB) Open(dbPath string) error {
|
||||
this.dbPath = dbPath
|
||||
|
||||
// 动态调整Cache值
|
||||
var cacheSize = 32000
|
||||
var memoryGB = utils.SystemMemoryGB()
|
||||
if memoryGB >= 8 {
|
||||
cacheSize += 32000 * memoryGB / 8
|
||||
}
|
||||
|
||||
// write db
|
||||
writeDB, err := sql.Open("sqlite3", "file:"+dbPath+"?cache=private&mode=rwc&_journal_mode=WAL&_sync=OFF&_cache_size=32000&_secure_delete=FAST")
|
||||
writeDB, err := sql.Open("sqlite3", "file:"+dbPath+"?cache=private&mode=rwc&_journal_mode=WAL&_sync=OFF&_cache_size="+types.String(cacheSize)+"&_secure_delete=FAST")
|
||||
if err != nil {
|
||||
return errors.New("open write database failed: " + err.Error())
|
||||
}
|
||||
|
||||
writeDB.SetMaxOpenConns(1)
|
||||
|
||||
this.writeDB = dbs.NewDB(writeDB)
|
||||
|
||||
// TODO 耗时过长,暂时不整理数据库
|
||||
// TODO 需要根据行数来判断是否VACUUM
|
||||
// TODO 注意VACUUM反而可能让数据库文件变大
|
||||
/**_, err = db.Exec("VACUUM")
|
||||
if err != nil {
|
||||
return err
|
||||
}**/
|
||||
|
||||
this.writeDB = dbs.NewDB(writeDB)
|
||||
// 检查是否损坏
|
||||
// TODO 暂时屏蔽,因为用时过长
|
||||
|
||||
var recoverEnv, _ = os.LookupEnv("EdgeRecover")
|
||||
if len(recoverEnv) > 0 && this.shouldRecover() {
|
||||
for _, indexName := range []string{"staleAt", "hash"} {
|
||||
_, _ = this.writeDB.Exec(`REINDEX "` + indexName + `"`)
|
||||
}
|
||||
}
|
||||
|
||||
this.writeBatch = dbs.NewBatch(writeDB, 4)
|
||||
this.writeBatch.OnFail(func(err error) {
|
||||
remotelogs.Warn("LIST_FILE_DB", "run batch failed: "+err.Error())
|
||||
@@ -94,7 +121,7 @@ func (this *FileListDB) Open(dbPath string) error {
|
||||
}
|
||||
|
||||
// read db
|
||||
readDB, err := sql.Open("sqlite3", "file:"+dbPath+"?cache=private&mode=ro&_journal_mode=WAL&_sync=OFF&_cache_size=32000")
|
||||
readDB, err := sql.Open("sqlite3", "file:"+dbPath+"?cache=private&mode=ro&_journal_mode=WAL&_sync=OFF&_cache_size="+types.String(cacheSize))
|
||||
if err != nil {
|
||||
return errors.New("open read database failed: " + err.Error())
|
||||
}
|
||||
@@ -138,7 +165,7 @@ func (this *FileListDB) Init() error {
|
||||
return err
|
||||
}
|
||||
|
||||
this.insertSQL = `INSERT INTO "` + this.itemsTableName + `" ("hash", "key", "headerSize", "bodySize", "metaSize", "expiredAt", "staleAt", "host", "serverId", "createdAt") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
|
||||
this.insertSQL = `INSERT INTO "` + this.itemsTableName + `" ("hash", "key", "headerSize", "bodySize", "metaSize", "expiredAt", "staleAt", "host", "serverId", "createdAt", "accessWeek") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
|
||||
this.insertStmt, err = this.writeDB.Prepare(this.insertSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -149,6 +176,8 @@ func (this *FileListDB) Init() error {
|
||||
return err
|
||||
}
|
||||
|
||||
this.selectHashListStmt, err = this.readDB.Prepare(`SELECT "id", "hash" FROM "` + this.itemsTableName + `" WHERE id>:id ORDER BY id ASC LIMIT 2000`)
|
||||
|
||||
this.deleteByHashSQL = `DELETE FROM "` + this.itemsTableName + `" WHERE "hash"=?`
|
||||
this.deleteByHashStmt, err = this.writeDB.Prepare(this.deleteByHashSQL)
|
||||
if err != nil {
|
||||
@@ -170,27 +199,29 @@ func (this *FileListDB) Init() error {
|
||||
return err
|
||||
}
|
||||
|
||||
this.listOlderItemsStmt, err = this.readDB.Prepare(`SELECT "hash" FROM "` + this.itemsTableName + `" ORDER BY "id" ASC LIMIT ?`)
|
||||
|
||||
this.insertHitStmt, err = this.writeDB.Prepare(`INSERT INTO "` + this.hitsTableName + `" ("hash", "week2Hits", "week") VALUES (?, 1, ?)`)
|
||||
|
||||
this.increaseHitStmt, err = this.writeDB.Prepare(`INSERT INTO "` + this.hitsTableName + `" ("hash", "week2Hits", "week") VALUES (?, 1, ?) ON CONFLICT("hash") DO UPDATE SET "week1Hits"=IIF("week"=?, "week1Hits", "week2Hits"), "week2Hits"=IIF("week"=?, "week2Hits"+1, 1), "week"=?`)
|
||||
this.listOlderItemsStmt, err = this.readDB.Prepare(`SELECT "hash" FROM "` + this.itemsTableName + `" ORDER BY "accessWeek" ASC, "id" ASC LIMIT ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.deleteHitByHashStmt, err = this.writeDB.Prepare(`DELETE FROM "` + this.hitsTableName + `" WHERE "hash"=?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.updateAccessWeekSQL = `UPDATE "` + this.itemsTableName + `" SET "accessWeek"=? WHERE "hash"=?`
|
||||
|
||||
this.lfuHitsStmt, err = this.readDB.Prepare(`SELECT "hash" FROM "` + this.hitsTableName + `" ORDER BY "week" ASC, "week1Hits"+"week2Hits" ASC LIMIT ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.insertHitSQL = `INSERT INTO "` + this.hitsTableName + `" ("hash", "week2Hits", "week") VALUES (?, 1, ?)`
|
||||
|
||||
this.increaseHitSQL = `INSERT INTO "` + this.hitsTableName + `" ("hash", "week2Hits", "week") VALUES (?, 1, ?) ON CONFLICT("hash") DO UPDATE SET "week1Hits"=IIF("week"=?, "week1Hits", "week2Hits"), "week2Hits"=IIF("week"=?, "week2Hits"+1, 1), "week"=?`
|
||||
|
||||
this.deleteHitByHashSQL = `DELETE FROM "` + this.hitsTableName + `" WHERE "hash"=?`
|
||||
|
||||
this.isReady = true
|
||||
|
||||
// 加载HashMap
|
||||
go func() {
|
||||
err := this.hashMap.Load(this)
|
||||
if err != nil {
|
||||
remotelogs.Error("LIST_FILE_DB", "load hash map failed: "+err.Error()+"(file: "+this.dbPath+")")
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -203,21 +234,25 @@ func (this *FileListDB) Total() int64 {
|
||||
}
|
||||
|
||||
func (this *FileListDB) AddAsync(hash string, item *Item) error {
|
||||
this.hashMap.Add(hash)
|
||||
|
||||
if item.StaleAt == 0 {
|
||||
item.StaleAt = item.ExpiredAt
|
||||
}
|
||||
|
||||
this.writeBatch.Add(this.insertSQL, hash, item.Key, item.HeaderSize, item.BodySize, item.MetaSize, item.ExpiredAt, item.StaleAt, item.Host, item.ServerId, utils.UnixTime())
|
||||
this.writeBatch.Add(this.insertSQL, hash, item.Key, item.HeaderSize, item.BodySize, item.MetaSize, item.ExpiredAt, item.StaleAt, item.Host, item.ServerId, utils.UnixTime(), timeutil.Format("YW"))
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (this *FileListDB) AddSync(hash string, item *Item) error {
|
||||
this.hashMap.Add(hash)
|
||||
|
||||
if item.StaleAt == 0 {
|
||||
item.StaleAt = item.ExpiredAt
|
||||
}
|
||||
|
||||
_, err := this.insertStmt.Exec(hash, item.Key, item.HeaderSize, item.BodySize, item.MetaSize, item.ExpiredAt, item.StaleAt, item.Host, item.ServerId, utils.UnixTime())
|
||||
_, err := this.insertStmt.Exec(hash, item.Key, item.HeaderSize, item.BodySize, item.MetaSize, item.ExpiredAt, item.StaleAt, item.Host, item.ServerId, utils.UnixTime(), timeutil.Format("YW"))
|
||||
if err != nil {
|
||||
return this.WrapError(err)
|
||||
}
|
||||
@@ -226,11 +261,15 @@ func (this *FileListDB) AddSync(hash string, item *Item) error {
|
||||
}
|
||||
|
||||
func (this *FileListDB) DeleteAsync(hash string) error {
|
||||
this.hashMap.Delete(hash)
|
||||
|
||||
this.writeBatch.Add(this.deleteByHashSQL, hash)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *FileListDB) DeleteSync(hash string) error {
|
||||
this.hashMap.Delete(hash)
|
||||
|
||||
_, err := this.deleteByHashStmt.Exec(hash)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -275,28 +314,56 @@ func (this *FileListDB) ListLFUItems(count int) (hashList []string, err error) {
|
||||
count = 100
|
||||
}
|
||||
|
||||
hashList, err = this.listLFUItems(count)
|
||||
// 先找过期的
|
||||
hashList, err = this.ListExpiredItems(count)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var l = len(hashList)
|
||||
|
||||
if len(hashList) > count/2 {
|
||||
return
|
||||
// 从旧缓存中补充
|
||||
if l < count {
|
||||
oldHashList, err := this.listOlderItems(count - l)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hashList = append(hashList, oldHashList...)
|
||||
}
|
||||
|
||||
// 不足补齐
|
||||
olderHashList, err := this.listOlderItems(count - len(hashList))
|
||||
return hashList, nil
|
||||
}
|
||||
|
||||
func (this *FileListDB) ListHashes(lastId int64) (hashList []string, maxId int64, err error) {
|
||||
rows, err := this.selectHashListStmt.Query(lastId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
hashList = append(hashList, olderHashList...)
|
||||
var id int64
|
||||
var hash string
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&id, &hash)
|
||||
if err != nil {
|
||||
_ = rows.Close()
|
||||
return
|
||||
}
|
||||
maxId = id
|
||||
hashList = append(hashList, hash)
|
||||
}
|
||||
|
||||
_ = rows.Close()
|
||||
return
|
||||
}
|
||||
|
||||
func (this *FileListDB) IncreaseHit(hash string) error {
|
||||
func (this *FileListDB) IncreaseHitAsync(hash string) error {
|
||||
var week = timeutil.Format("YW")
|
||||
_, err := this.increaseHitStmt.Exec(hash, week, week, week, week)
|
||||
return this.WrapError(err)
|
||||
this.writeBatch.Add(this.increaseHitSQL, hash, week, week, week, week)
|
||||
this.writeBatch.Add(this.updateAccessWeekSQL, week, hash)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *FileListDB) DeleteHitAsync(hash string) error {
|
||||
this.writeBatch.Add(this.deleteHitByHashSQL, hash)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *FileListDB) CleanPrefix(prefix string) error {
|
||||
@@ -331,6 +398,8 @@ func (this *FileListDB) CleanAll() error {
|
||||
return this.WrapError(err)
|
||||
}
|
||||
|
||||
this.hashMap.Clean()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -347,6 +416,9 @@ func (this *FileListDB) Close() error {
|
||||
if this.selectByHashStmt != nil {
|
||||
_ = this.selectByHashStmt.Close()
|
||||
}
|
||||
if this.selectHashListStmt != nil {
|
||||
_ = this.selectHashListStmt.Close()
|
||||
}
|
||||
if this.deleteByHashStmt != nil {
|
||||
_ = this.deleteByHashStmt.Close()
|
||||
}
|
||||
@@ -363,17 +435,8 @@ func (this *FileListDB) Close() error {
|
||||
_ = this.listOlderItemsStmt.Close()
|
||||
}
|
||||
|
||||
if this.insertHitStmt != nil {
|
||||
_ = this.insertHitStmt.Close()
|
||||
}
|
||||
if this.increaseHitStmt != nil {
|
||||
_ = this.increaseHitStmt.Close()
|
||||
}
|
||||
if this.deleteHitByHashStmt != nil {
|
||||
_ = this.deleteHitByHashStmt.Close()
|
||||
}
|
||||
if this.lfuHitsStmt != nil {
|
||||
_ = this.lfuHitsStmt.Close()
|
||||
if this.writeBatch != nil {
|
||||
this.writeBatch.Close()
|
||||
}
|
||||
|
||||
var errStrings []string
|
||||
@@ -392,11 +455,6 @@ func (this *FileListDB) Close() error {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if this.writeBatch != nil {
|
||||
this.writeBatch.Close()
|
||||
}
|
||||
|
||||
if len(errStrings) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -427,7 +485,8 @@ func (this *FileListDB) initTables(times int) error {
|
||||
"staleAt" integer DEFAULT 0,
|
||||
"createdAt" integer DEFAULT 0,
|
||||
"host" varchar(128),
|
||||
"serverId" integer
|
||||
"serverId" integer,
|
||||
"accessWeek" varchar(6)
|
||||
);
|
||||
|
||||
DROP INDEX IF EXISTS "createdAt";
|
||||
@@ -443,19 +502,28 @@ CREATE UNIQUE INDEX IF NOT EXISTS "hash"
|
||||
ON "` + this.itemsTableName + `" (
|
||||
"hash" ASC
|
||||
);
|
||||
|
||||
ALTER TABLE "cacheItems" ADD "accessWeek" varchar(6);
|
||||
`)
|
||||
|
||||
if err != nil {
|
||||
// 尝试删除重建
|
||||
if times < 3 {
|
||||
_, dropErr := this.writeDB.Exec(`DROP TABLE "` + this.itemsTableName + `"`)
|
||||
if dropErr == nil {
|
||||
return this.initTables(times + 1)
|
||||
}
|
||||
return this.WrapError(err)
|
||||
// 忽略可以预期的错误
|
||||
if strings.Contains(err.Error(), "duplicate column name") {
|
||||
err = nil
|
||||
}
|
||||
|
||||
return this.WrapError(err)
|
||||
// 尝试删除重建
|
||||
if err != nil {
|
||||
if times < 3 {
|
||||
_, dropErr := this.writeDB.Exec(`DROP TABLE "` + this.itemsTableName + `"`)
|
||||
if dropErr == nil {
|
||||
return this.initTables(times + 1)
|
||||
}
|
||||
return this.WrapError(err)
|
||||
}
|
||||
|
||||
return this.WrapError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -490,27 +558,6 @@ ON "` + this.hitsTableName + `" (
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *FileListDB) listLFUItems(count int) (hashList []string, err error) {
|
||||
rows, err := this.lfuHitsStmt.Query(count)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
var hash string
|
||||
err = rows.Scan(&hash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hashList = append(hashList, hash)
|
||||
}
|
||||
|
||||
return hashList, nil
|
||||
}
|
||||
|
||||
func (this *FileListDB) listOlderItems(count int) (hashList []string, err error) {
|
||||
rows, err := this.listOlderItemsStmt.Query(count)
|
||||
if err != nil {
|
||||
@@ -531,3 +578,21 @@ func (this *FileListDB) listOlderItems(count int) (hashList []string, err error)
|
||||
|
||||
return hashList, nil
|
||||
}
|
||||
|
||||
func (this *FileListDB) shouldRecover() bool {
|
||||
result, err := this.writeDB.Query("pragma integrity_check;")
|
||||
if err != nil {
|
||||
logs.Println(result)
|
||||
}
|
||||
var errString = ""
|
||||
var shouldRecover = false
|
||||
for result.Next() {
|
||||
err = result.Scan(&errString)
|
||||
if strings.TrimSpace(errString) != "ok" {
|
||||
shouldRecover = true
|
||||
}
|
||||
break
|
||||
}
|
||||
_ = result.Close()
|
||||
return shouldRecover
|
||||
}
|
||||
|
||||
49
internal/caches/list_file_db_test.go
Normal file
49
internal/caches/list_file_db_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package caches_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestFileListDB_ListLFUItems(t *testing.T) {
|
||||
var db = caches.NewFileListDB()
|
||||
err := db.Open(Tea.Root + "/data/cache-db-large.db")
|
||||
//err := db.Open(Tea.Root + "/data/cache-index/p1/db-0.db")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
|
||||
hashList, err := db.ListLFUItems(100)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("[", len(hashList), "]", hashList)
|
||||
}
|
||||
|
||||
func TestFileListDB_IncreaseHitAsync(t *testing.T) {
|
||||
var db = caches.NewFileListDB()
|
||||
err := db.Open(Tea.Root + "/data/cache-db-large.db")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.Init()
|
||||
err = db.IncreaseHitAsync("4598e5231ba47d6ec7aa9ea640ff2eaf")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// wait transaction
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
120
internal/caches/list_file_hash_map.go
Normal file
120
internal/caches/list_file_hash_map.go
Normal file
@@ -0,0 +1,120 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package caches
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/zero"
|
||||
"math/big"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// FileListHashMap 文件Hash列表
|
||||
type FileListHashMap struct {
|
||||
m map[uint64]zero.Zero
|
||||
|
||||
locker sync.RWMutex
|
||||
isAvailable bool
|
||||
isReady bool
|
||||
}
|
||||
|
||||
func NewFileListHashMap() *FileListHashMap {
|
||||
return &FileListHashMap{
|
||||
m: map[uint64]zero.Zero{},
|
||||
isAvailable: false,
|
||||
isReady: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *FileListHashMap) Load(db *FileListDB) error {
|
||||
// 如果系统内存过小,我们不缓存
|
||||
if utils.SystemMemoryGB() < 3 {
|
||||
return nil
|
||||
}
|
||||
|
||||
this.isAvailable = true
|
||||
|
||||
var lastId int64
|
||||
for {
|
||||
hashList, maxId, err := db.ListHashes(lastId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(hashList) == 0 {
|
||||
break
|
||||
}
|
||||
this.AddHashes(hashList)
|
||||
lastId = maxId
|
||||
}
|
||||
|
||||
this.isReady = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *FileListHashMap) Add(hash string) {
|
||||
if !this.isAvailable {
|
||||
return
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
this.m[this.bigInt(hash)] = zero.New()
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
func (this *FileListHashMap) AddHashes(hashes []string) {
|
||||
if !this.isAvailable {
|
||||
return
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
for _, hash := range hashes {
|
||||
this.m[this.bigInt(hash)] = zero.New()
|
||||
}
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
func (this *FileListHashMap) Delete(hash string) {
|
||||
if !this.isAvailable {
|
||||
return
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
delete(this.m, this.bigInt(hash))
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
func (this *FileListHashMap) Exist(hash string) bool {
|
||||
if !this.isAvailable {
|
||||
return true
|
||||
}
|
||||
if !this.isReady {
|
||||
// 只有完全Ready时才能判断是否为false
|
||||
return true
|
||||
}
|
||||
this.locker.RLock()
|
||||
_, ok := this.m[this.bigInt(hash)]
|
||||
this.locker.RUnlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
func (this *FileListHashMap) Clean() {
|
||||
this.locker.Lock()
|
||||
this.m = map[uint64]zero.Zero{}
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
func (this *FileListHashMap) IsReady() bool {
|
||||
return this.isReady
|
||||
}
|
||||
|
||||
func (this *FileListHashMap) Len() int {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
return len(this.m)
|
||||
}
|
||||
|
||||
func (this *FileListHashMap) bigInt(hash string) uint64 {
|
||||
var bigInt = big.NewInt(0)
|
||||
bigInt.SetString(hash, 16)
|
||||
return bigInt.Uint64()
|
||||
}
|
||||
96
internal/caches/list_file_hash_map_test.go
Normal file
96
internal/caches/list_file_hash_map_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package caches_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/zero"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"math/big"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestFileListHashMap_Memory(t *testing.T) {
|
||||
var stat1 = &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stat1)
|
||||
|
||||
var m = caches.NewFileListHashMap()
|
||||
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
m.Add(stringutil.Md5(types.String(i)))
|
||||
}
|
||||
|
||||
var stat2 = &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stat2)
|
||||
|
||||
t.Log("ready", (stat2.Alloc-stat1.Alloc)/1024/1024, "M")
|
||||
}
|
||||
|
||||
func TestFileListHashMap_Memory2(t *testing.T) {
|
||||
var stat1 = &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stat1)
|
||||
|
||||
var m = map[uint64]zero.Zero{}
|
||||
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
m[uint64(i)] = zero.New()
|
||||
}
|
||||
|
||||
var stat2 = &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stat2)
|
||||
|
||||
t.Log("ready", (stat2.Alloc-stat1.Alloc)/1024/1024, "M")
|
||||
}
|
||||
|
||||
func TestFileListHashMap_BigInt(t *testing.T) {
|
||||
for _, s := range []string{"1", "2", "3", "123", "123456"} {
|
||||
var hash = stringutil.Md5(s)
|
||||
|
||||
var bigInt = big.NewInt(0)
|
||||
bigInt.SetString(hash, 16)
|
||||
t.Log(s, "=>", bigInt.Uint64(), "hash:", hash, "format:", strconv.FormatUint(bigInt.Uint64(), 16))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileListHashMap_Load(t *testing.T) {
|
||||
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1").(*caches.FileList)
|
||||
err := list.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = list.Close()
|
||||
}()
|
||||
|
||||
var m = caches.NewFileListHashMap()
|
||||
var before = time.Now()
|
||||
var db = list.GetDB("abc")
|
||||
err = m.Load(db)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||
t.Log("count:", m.Len())
|
||||
m.Add("abc")
|
||||
|
||||
for _, hash := range []string{"33347bb4441265405347816cad36a0f8", "a", "abc", "123"} {
|
||||
t.Log(hash, "=>", m.Exist(hash))
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_BigInt(b *testing.B) {
|
||||
var hash = stringutil.Md5("123456")
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
var bigInt = big.NewInt(0)
|
||||
bigInt.SetString(hash, 16)
|
||||
_ = bigInt.Uint64()
|
||||
}
|
||||
}
|
||||
@@ -361,19 +361,6 @@ func TestFileList_UpgradeV3(t *testing.T) {
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestFileList_HashList(t *testing.T) {
|
||||
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
|
||||
err := list.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
prefixes, err := list.(*caches.FileList).FindAllPrefixes()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(len(prefixes))
|
||||
}
|
||||
|
||||
func BenchmarkFileList_Exist(b *testing.B) {
|
||||
var list = caches.NewFileList(Tea.Root + "/data/cache-index/p1")
|
||||
err := list.Init()
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"strconv"
|
||||
"sync"
|
||||
@@ -16,7 +15,7 @@ var SharedManager = NewManager()
|
||||
|
||||
func init() {
|
||||
events.On(events.EventQuit, func() {
|
||||
logs.Println("CACHE", "quiting cache manager")
|
||||
remotelogs.Println("CACHE", "quiting cache manager")
|
||||
SharedManager.UpdatePolicies([]*serverconfigs.HTTPCachePolicy{})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -179,7 +179,7 @@ func (this *FileStorage) UpdatePolicy(newPolicy *serverconfigs.HTTPCachePolicy)
|
||||
// open cache
|
||||
oldOpenFileCacheJSON, _ := json.Marshal(oldOpenFileCache)
|
||||
newOpenFileCacheJSON, _ := json.Marshal(this.options.OpenFileCache)
|
||||
if bytes.Compare(oldOpenFileCacheJSON, newOpenFileCacheJSON) != 0 {
|
||||
if !bytes.Equal(oldOpenFileCacheJSON, newOpenFileCacheJSON) {
|
||||
this.initOpenFileCache()
|
||||
}
|
||||
|
||||
|
||||
@@ -520,8 +520,6 @@ func (this *MemoryStorage) flushItem(key string) {
|
||||
|
||||
// 从内存中移除
|
||||
_ = this.Delete(key)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (this *MemoryStorage) memoryCapacityBytes() int64 {
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
package configs
|
||||
|
||||
import "sync"
|
||||
|
||||
var sharedLocker = &sync.RWMutex{}
|
||||
117
internal/conns/map.go
Normal file
117
internal/conns/map.go
Normal file
@@ -0,0 +1,117 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package conns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var SharedMap = NewMap()
|
||||
|
||||
type Map struct {
|
||||
m map[string]map[int]net.Conn // ip => { port => Conn }
|
||||
|
||||
locker sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMap() *Map {
|
||||
return &Map{
|
||||
m: map[string]map[int]net.Conn{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Map) Add(conn net.Conn) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var ip = tcpAddr.IP.String()
|
||||
var port = tcpAddr.Port
|
||||
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
connMap, ok := this.m[ip]
|
||||
if !ok {
|
||||
this.m[ip] = map[int]net.Conn{
|
||||
port: conn,
|
||||
}
|
||||
} else {
|
||||
connMap[port] = conn
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Map) Remove(conn net.Conn) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var ip = tcpAddr.IP.String()
|
||||
var port = tcpAddr.Port
|
||||
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
connMap, ok := this.m[ip]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(connMap, port)
|
||||
|
||||
if len(connMap) == 0 {
|
||||
delete(this.m, ip)
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Map) CountIPConns(ip string) int {
|
||||
this.locker.RLock()
|
||||
var l = len(this.m[ip])
|
||||
this.locker.RUnlock()
|
||||
return l
|
||||
}
|
||||
|
||||
func (this *Map) CloseIPConns(ip string) {
|
||||
var conns = []net.Conn{}
|
||||
|
||||
this.locker.RLock()
|
||||
connMap, ok := this.m[ip]
|
||||
|
||||
// 复制,防止在Close时产生并发冲突
|
||||
if ok {
|
||||
for _, conn := range connMap {
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
}
|
||||
|
||||
// 需要在Close之前结束,防止死循环
|
||||
this.locker.RUnlock()
|
||||
|
||||
if ok {
|
||||
for _, conn := range conns {
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
// 这里不需要从 m 中删除,因为关闭时会自然触发回调
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Map) AllConns() []net.Conn {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
|
||||
var result = []net.Conn{}
|
||||
for _, m := range this.m {
|
||||
for _, conn := range m {
|
||||
result = append(result, conn)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package teaconst
|
||||
|
||||
const (
|
||||
Version = "0.5.2"
|
||||
Version = "0.5.4"
|
||||
|
||||
ProductName = "Edge Node"
|
||||
ProcessName = "edge-node"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package firewalls
|
||||
|
||||
@@ -14,6 +13,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/zero"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
|
||||
@@ -53,19 +54,13 @@ func init() {
|
||||
|
||||
// DDoSProtectionManager DDoS防护
|
||||
type DDoSProtectionManager struct {
|
||||
nftPath string
|
||||
|
||||
lastAllowIPList []string
|
||||
lastConfig []byte
|
||||
}
|
||||
|
||||
// NewDDoSProtectionManager 获取新对象
|
||||
func NewDDoSProtectionManager() *DDoSProtectionManager {
|
||||
nftPath, _ := exec.LookPath("nft")
|
||||
|
||||
return &DDoSProtectionManager{
|
||||
nftPath: nftPath,
|
||||
}
|
||||
return &DDoSProtectionManager{}
|
||||
}
|
||||
|
||||
// Apply 应用配置
|
||||
@@ -91,7 +86,7 @@ func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) e
|
||||
}
|
||||
remotelogs.Println("FIREWALL", "change DDoS protection config")
|
||||
|
||||
if len(this.nftPath) == 0 {
|
||||
if len(this.nftExe()) == 0 {
|
||||
return errors.New("can not find nft command")
|
||||
}
|
||||
|
||||
@@ -154,6 +149,11 @@ func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) e
|
||||
|
||||
// 添加TCP规则
|
||||
func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig) error {
|
||||
var nftExe = this.nftExe()
|
||||
if len(nftExe) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查nft版本不能小于0.9
|
||||
if len(nftablesInstance.version) > 0 && stringutil.VersionCompare("0.9", nftablesInstance.version) > 0 {
|
||||
return nil
|
||||
@@ -195,14 +195,31 @@ func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig)
|
||||
}
|
||||
}
|
||||
|
||||
// new connections rate
|
||||
var newConnectionsRate = tcpConfig.NewConnectionsRate
|
||||
if newConnectionsRate <= 0 {
|
||||
newConnectionsRate = nodeconfigs.DefaultTCPNewConnectionsRate
|
||||
if newConnectionsRate <= 0 {
|
||||
newConnectionsRate = 100000
|
||||
// new connections rate (minutely)
|
||||
var newConnectionsMinutelyRate = tcpConfig.NewConnectionsMinutelyRate
|
||||
if newConnectionsMinutelyRate <= 0 {
|
||||
newConnectionsMinutelyRate = nodeconfigs.DefaultTCPNewConnectionsMinutelyRate
|
||||
if newConnectionsMinutelyRate <= 0 {
|
||||
newConnectionsMinutelyRate = 100000
|
||||
}
|
||||
}
|
||||
var newConnectionsMinutelyRateBlockTimeout = tcpConfig.NewConnectionsMinutelyRateBlockTimeout
|
||||
if newConnectionsMinutelyRateBlockTimeout < 0 {
|
||||
newConnectionsMinutelyRateBlockTimeout = 0
|
||||
}
|
||||
|
||||
// new connections rate (secondly)
|
||||
var newConnectionsSecondlyRate = tcpConfig.NewConnectionsSecondlyRate
|
||||
if newConnectionsSecondlyRate <= 0 {
|
||||
newConnectionsSecondlyRate = nodeconfigs.DefaultTCPNewConnectionsSecondlyRate
|
||||
if newConnectionsSecondlyRate <= 0 {
|
||||
newConnectionsSecondlyRate = 10000
|
||||
}
|
||||
}
|
||||
var newConnectionsSecondlyRateBlockTimeout = tcpConfig.NewConnectionsSecondlyRateBlockTimeout
|
||||
if newConnectionsSecondlyRateBlockTimeout < 0 {
|
||||
newConnectionsSecondlyRateBlockTimeout = 0
|
||||
}
|
||||
|
||||
// 检查是否有变化
|
||||
var hasChanges = false
|
||||
@@ -215,7 +232,11 @@ func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig)
|
||||
hasChanges = true
|
||||
break
|
||||
}
|
||||
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsRate)}) {
|
||||
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsMinutelyRate), types.String(newConnectionsMinutelyRateBlockTimeout)}) {
|
||||
hasChanges = true
|
||||
break
|
||||
}
|
||||
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "newConnectionsSecondlyRate", types.String(newConnectionsSecondlyRate), types.String(newConnectionsSecondlyRateBlockTimeout)}) {
|
||||
hasChanges = true
|
||||
break
|
||||
}
|
||||
@@ -242,33 +263,61 @@ func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig)
|
||||
// 添加新规则
|
||||
for _, port := range ports {
|
||||
if maxConnections > 0 {
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "count", "over", types.String(maxConnections), "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, nftExe, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "count", "over", types.String(maxConnections), "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}))
|
||||
cmd.WithStderr()
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
|
||||
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + cmd.Stderr() + ")")
|
||||
}
|
||||
}
|
||||
|
||||
// TODO 让用户选择是drop还是reject
|
||||
if maxConnectionsPerIP > 0 {
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "meter", "meter-"+protocol+"-"+types.String(port)+"-max-connections", "{ "+protocol+" saddr ct count over "+types.String(maxConnectionsPerIP)+" }", "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, nftExe, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "meter", "meter-"+protocol+"-"+types.String(port)+"-max-connections", "{ "+protocol+" saddr ct count over "+types.String(maxConnectionsPerIP)+" }", "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}))
|
||||
cmd.WithStderr()
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
|
||||
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + cmd.Stderr() + ")")
|
||||
}
|
||||
}
|
||||
|
||||
if newConnectionsRate > 0 {
|
||||
// TODO 思考是否有惩罚机制
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsRate)+"/minute burst "+types.String(newConnectionsRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsRate)}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
|
||||
// 超过一定速率就drop或者加入黑名单(分钟)
|
||||
// TODO 让用户选择是drop还是reject
|
||||
if newConnectionsMinutelyRate > 0 {
|
||||
if newConnectionsMinutelyRateBlockTimeout > 0 {
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, nftExe, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsMinutelyRate)+"/minute burst "+types.String(newConnectionsMinutelyRate+3)+" packets }", "add", "@deny_set", "{"+protocol+" saddr timeout "+types.String(newConnectionsMinutelyRateBlockTimeout)+"s}", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsMinutelyRate), types.String(newConnectionsMinutelyRateBlockTimeout)}))
|
||||
cmd.WithStderr()
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + cmd.Stderr() + ")")
|
||||
}
|
||||
} else {
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, nftExe, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsMinutelyRate)+"/minute burst "+types.String(newConnectionsMinutelyRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", "0"}))
|
||||
cmd.WithStderr()
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + cmd.Stderr() + ")")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 超过一定速率就drop或者加入黑名单(秒)
|
||||
// TODO 让用户选择是drop还是reject
|
||||
if newConnectionsSecondlyRate > 0 {
|
||||
if newConnectionsSecondlyRateBlockTimeout > 0 {
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, nftExe, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-secondly-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsSecondlyRate)+"/second burst "+types.String(newConnectionsSecondlyRate+3)+" packets }", "add", "@deny_set", "{"+protocol+" saddr timeout "+types.String(newConnectionsSecondlyRateBlockTimeout)+"s}", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsSecondlyRate", types.String(newConnectionsSecondlyRate), types.String(newConnectionsSecondlyRateBlockTimeout)}))
|
||||
cmd.WithStderr()
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + cmd.Stderr() + ")")
|
||||
}
|
||||
} else {
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, nftExe, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-secondly-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsSecondlyRate)+"/second burst "+types.String(newConnectionsSecondlyRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsSecondlyRate", "0"}))
|
||||
cmd.WithStderr()
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + cmd.Stderr() + ")")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -335,14 +384,14 @@ func (this *DDoSProtectionManager) decodeUserData(data []byte) []string {
|
||||
func (this *DDoSProtectionManager) removeOldTCPRules(chain *nftables.Chain, rules []*nftables.Rule) error {
|
||||
for _, rule := range rules {
|
||||
var pieces = this.decodeUserData(rule.UserData())
|
||||
if len(pieces) != 4 {
|
||||
if len(pieces) < 4 {
|
||||
continue
|
||||
}
|
||||
if pieces[0] != "tcp" {
|
||||
continue
|
||||
}
|
||||
switch pieces[2] {
|
||||
case "maxConnections", "maxConnectionsPerIP", "newConnectionsRate":
|
||||
case "maxConnections", "maxConnectionsPerIP", "newConnectionsRate", "newConnectionsSecondlyRate":
|
||||
err := chain.DeleteRule(rule)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -500,3 +549,8 @@ func (this *DDoSProtectionManager) updateAllowIPList(allIPList []string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DDoSProtectionManager) nftExe() string {
|
||||
path, _ := exec.LookPath("nft")
|
||||
return path
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
|
||||
|
||||
type DDoSProtectionManager struct {
|
||||
nftPath string
|
||||
}
|
||||
|
||||
func NewDDoSProtectionManager() *DDoSProtectionManager {
|
||||
|
||||
47
internal/firewalls/firewall_base.go
Normal file
47
internal/firewalls/firewall_base.go
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type BaseFirewall struct {
|
||||
locker sync.Mutex
|
||||
latestIPTimes []string // [ip@time, ....]
|
||||
}
|
||||
|
||||
// 检查是否在最近添加过
|
||||
func (this *BaseFirewall) checkLatestIP(ip string) bool {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
var expiredIndex = -1
|
||||
for index, ipTime := range this.latestIPTimes {
|
||||
var pieces = strings.Split(ipTime, "@")
|
||||
var oldIP = pieces[0]
|
||||
var oldTimestamp = pieces[1]
|
||||
if types.Int64(oldTimestamp) < time.Now().Unix()-3 /** 3秒外表示过期 **/ {
|
||||
expiredIndex = index
|
||||
continue
|
||||
}
|
||||
if oldIP == ip {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if expiredIndex > -1 {
|
||||
this.latestIPTimes = this.latestIPTimes[expiredIndex+1:]
|
||||
}
|
||||
|
||||
this.latestIPTimes = append(this.latestIPTimes, ip+"@"+types.String(time.Now().Unix()))
|
||||
const maxLen = 128
|
||||
if len(this.latestIPTimes) > maxLen {
|
||||
this.latestIPTimes = this.latestIPTimes[1:]
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -4,27 +4,37 @@ package firewalls
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/conns"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type firewalldCmd struct {
|
||||
cmd *executils.Cmd
|
||||
denyIP string
|
||||
}
|
||||
|
||||
type Firewalld struct {
|
||||
BaseFirewall
|
||||
|
||||
isReady bool
|
||||
exe string
|
||||
cmdQueue chan *exec.Cmd
|
||||
cmdQueue chan *firewalldCmd
|
||||
}
|
||||
|
||||
func NewFirewalld() *Firewalld {
|
||||
var firewalld = &Firewalld{
|
||||
cmdQueue: make(chan *exec.Cmd, 4096),
|
||||
cmdQueue: make(chan *firewalldCmd, 4096),
|
||||
}
|
||||
|
||||
path, err := exec.LookPath("firewall-cmd")
|
||||
if err == nil && len(path) > 0 {
|
||||
var cmd = exec.Command(path, "--state")
|
||||
var cmd = executils.NewTimeoutCmd(3*time.Second, path, "--state")
|
||||
err := cmd.Run()
|
||||
if err == nil {
|
||||
firewalld.exe = path
|
||||
@@ -41,13 +51,19 @@ func NewFirewalld() *Firewalld {
|
||||
|
||||
func (this *Firewalld) init() {
|
||||
goman.New(func() {
|
||||
for cmd := range this.cmdQueue {
|
||||
for c := range this.cmdQueue {
|
||||
var cmd = c.cmd
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
if strings.HasPrefix(err.Error(), "Warning:") {
|
||||
continue
|
||||
}
|
||||
remotelogs.Warn("FIREWALL", "run command failed '"+cmd.String()+"': "+err.Error())
|
||||
} else {
|
||||
// 关闭连接
|
||||
if len(c.denyIP) > 0 {
|
||||
conns.SharedMap.CloseIPConns(c.denyIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -71,8 +87,8 @@ func (this *Firewalld) AllowPort(port int, protocol string) error {
|
||||
if !this.isReady {
|
||||
return nil
|
||||
}
|
||||
var cmd = exec.Command(this.exe, "--add-port="+types.String(port)+"/"+protocol)
|
||||
this.pushCmd(cmd)
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, this.exe, "--add-port="+types.String(port)+"/"+protocol)
|
||||
this.pushCmd(cmd, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -81,13 +97,13 @@ func (this *Firewalld) AllowPortRangesPermanently(portRanges [][2]int, protocol
|
||||
var port = this.PortRangeString(portRange, protocol)
|
||||
|
||||
{
|
||||
var cmd = exec.Command(this.exe, "--add-port="+port, "--permanent")
|
||||
this.pushCmd(cmd)
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, this.exe, "--add-port="+port, "--permanent")
|
||||
this.pushCmd(cmd, "")
|
||||
}
|
||||
|
||||
{
|
||||
var cmd = exec.Command(this.exe, "--add-port="+port)
|
||||
this.pushCmd(cmd)
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, this.exe, "--add-port="+port)
|
||||
this.pushCmd(cmd, "")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,8 +114,8 @@ func (this *Firewalld) RemovePort(port int, protocol string) error {
|
||||
if !this.isReady {
|
||||
return nil
|
||||
}
|
||||
var cmd = exec.Command(this.exe, "--remove-port="+types.String(port)+"/"+protocol)
|
||||
this.pushCmd(cmd)
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, this.exe, "--remove-port="+types.String(port)+"/"+protocol)
|
||||
this.pushCmd(cmd, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -107,13 +123,13 @@ func (this *Firewalld) RemovePortRangePermanently(portRange [2]int, protocol str
|
||||
var port = this.PortRangeString(portRange, protocol)
|
||||
|
||||
{
|
||||
var cmd = exec.Command(this.exe, "--remove-port="+port, "--permanent")
|
||||
this.pushCmd(cmd)
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, this.exe, "--remove-port="+port, "--permanent")
|
||||
this.pushCmd(cmd, "")
|
||||
}
|
||||
|
||||
{
|
||||
var cmd = exec.Command(this.exe, "--remove-port="+port)
|
||||
this.pushCmd(cmd)
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, this.exe, "--remove-port="+port)
|
||||
this.pushCmd(cmd, "")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -131,6 +147,12 @@ func (this *Firewalld) RejectSourceIP(ip string, timeoutSeconds int) error {
|
||||
if !this.isReady {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 避免短时间内重复添加
|
||||
if this.checkLatestIP(ip) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var family = "ipv4"
|
||||
if strings.Contains(ip, ":") {
|
||||
family = "ipv6"
|
||||
@@ -139,8 +161,8 @@ func (this *Firewalld) RejectSourceIP(ip string, timeoutSeconds int) error {
|
||||
if timeoutSeconds > 0 {
|
||||
args = append(args, "--timeout="+types.String(timeoutSeconds)+"s")
|
||||
}
|
||||
var cmd = exec.Command(this.exe, args...)
|
||||
this.pushCmd(cmd)
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, this.exe, args...)
|
||||
this.pushCmd(cmd, ip)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -148,6 +170,12 @@ func (this *Firewalld) DropSourceIP(ip string, timeoutSeconds int, async bool) e
|
||||
if !this.isReady {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 避免短时间内重复添加
|
||||
if async && this.checkLatestIP(ip) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var family = "ipv4"
|
||||
if strings.Contains(ip, ":") {
|
||||
family = "ipv6"
|
||||
@@ -156,12 +184,15 @@ func (this *Firewalld) DropSourceIP(ip string, timeoutSeconds int, async bool) e
|
||||
if timeoutSeconds > 0 {
|
||||
args = append(args, "--timeout="+types.String(timeoutSeconds)+"s")
|
||||
}
|
||||
var cmd = exec.Command(this.exe, args...)
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, this.exe, args...)
|
||||
if async {
|
||||
this.pushCmd(cmd)
|
||||
this.pushCmd(cmd, ip)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 关闭连接
|
||||
defer conns.SharedMap.CloseIPConns(ip)
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return errors.New("run command failed '" + cmd.String() + "': " + err.Error())
|
||||
@@ -173,21 +204,22 @@ func (this *Firewalld) RemoveSourceIP(ip string) error {
|
||||
if !this.isReady {
|
||||
return nil
|
||||
}
|
||||
|
||||
var family = "ipv4"
|
||||
if strings.Contains(ip, ":") {
|
||||
family = "ipv6"
|
||||
}
|
||||
for _, action := range []string{"reject", "drop"} {
|
||||
var args = []string{"--remove-rich-rule=rule family='" + family + "' source address='" + ip + "' " + action}
|
||||
var cmd = exec.Command(this.exe, args...)
|
||||
this.pushCmd(cmd)
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, this.exe, args...)
|
||||
this.pushCmd(cmd, "")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Firewalld) pushCmd(cmd *exec.Cmd) {
|
||||
func (this *Firewalld) pushCmd(cmd *executils.Cmd, denyIP string) {
|
||||
select {
|
||||
case this.cmdQueue <- cmd:
|
||||
case this.cmdQueue <- &firewalldCmd{cmd: cmd, denyIP: denyIP}:
|
||||
default:
|
||||
// we discard the command
|
||||
}
|
||||
|
||||
@@ -5,13 +5,14 @@
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/conns"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"os/exec"
|
||||
@@ -100,6 +101,8 @@ func NewNFTablesFirewall() (*NFTablesFirewall, error) {
|
||||
}
|
||||
|
||||
type NFTablesFirewall struct {
|
||||
BaseFirewall
|
||||
|
||||
conn *nftables.Conn
|
||||
isReady bool
|
||||
version string
|
||||
@@ -344,6 +347,14 @@ func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
// 尝试关闭连接
|
||||
conns.SharedMap.CloseIPConns(ip)
|
||||
|
||||
// 避免短时间内重复添加
|
||||
if async && this.checkLatestIP(ip) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if async {
|
||||
select {
|
||||
case this.dropIPQueue <- &blockIPItem{
|
||||
@@ -357,6 +368,9 @@ func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async
|
||||
return nil
|
||||
}
|
||||
|
||||
// 再次尝试关闭连接
|
||||
defer conns.SharedMap.CloseIPConns(ip)
|
||||
|
||||
if strings.Contains(ip, ":") { // ipv6
|
||||
if this.denyIPv6Set == nil {
|
||||
return errors.New("ipv6 ip set is nil")
|
||||
@@ -418,18 +432,49 @@ func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
|
||||
|
||||
// 读取版本号
|
||||
func (this *NFTablesFirewall) readVersion(nftPath string) string {
|
||||
var cmd = exec.Command(nftPath, "--version")
|
||||
var output = &bytes.Buffer{}
|
||||
cmd.Stdout = output
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, nftPath, "--version")
|
||||
cmd.WithStdout()
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var outputString = output.String()
|
||||
var outputString = cmd.Stdout()
|
||||
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
|
||||
if len(versionMatches) <= 1 {
|
||||
return ""
|
||||
}
|
||||
return versionMatches[1]
|
||||
}
|
||||
|
||||
// 检查是否在最近添加过
|
||||
func (this *NFTablesFirewall) existLatestIP(ip string) bool {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
var expiredIndex = -1
|
||||
for index, ipTime := range this.latestIPTimes {
|
||||
var pieces = strings.Split(ipTime, "@")
|
||||
var oldIP = pieces[0]
|
||||
var oldTimestamp = pieces[1]
|
||||
if types.Int64(oldTimestamp) < time.Now().Unix()-3 /** 3秒外表示过期 **/ {
|
||||
expiredIndex = index
|
||||
continue
|
||||
}
|
||||
if oldIP == ip {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if expiredIndex > -1 {
|
||||
this.latestIPTimes = this.latestIPTimes[expiredIndex+1:]
|
||||
}
|
||||
|
||||
this.latestIPTimes = append(this.latestIPTimes, ip+"@"+types.String(time.Now().Unix()))
|
||||
const maxLen = 128
|
||||
if len(this.latestIPTimes) > maxLen {
|
||||
this.latestIPTimes = this.latestIPTimes[1:]
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables
|
||||
|
||||
|
||||
108
internal/firewalls/nftables/installer.go
Normal file
108
internal/firewalls/nftables/installer.go
Normal file
@@ -0,0 +1,108 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
events.On(events.EventReload, func() {
|
||||
// linux only
|
||||
if runtime.GOOS != "linux" {
|
||||
return
|
||||
}
|
||||
|
||||
nodeConfig, err := nodeconfigs.SharedNodeConfig()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if nodeConfig == nil || !nodeConfig.AutoInstallNftables {
|
||||
return
|
||||
}
|
||||
|
||||
if os.Getgid() == 0 { // root user only
|
||||
_, err := exec.LookPath("nft")
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
goman.New(func() {
|
||||
err := NewInstaller().Install()
|
||||
if err != nil {
|
||||
// 不需要传到API节点
|
||||
logs.Println("[NFTABLES]install nftables failed: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type Installer struct {
|
||||
}
|
||||
|
||||
func NewInstaller() *Installer {
|
||||
return &Installer{}
|
||||
}
|
||||
|
||||
func (this *Installer) Install() error {
|
||||
// linux only
|
||||
if runtime.GOOS != "linux" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查是否已经存在
|
||||
_, err := exec.LookPath("nft")
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var cmd *executils.Cmd
|
||||
|
||||
// check dnf
|
||||
dnfExe, err := exec.LookPath("dnf")
|
||||
if err == nil {
|
||||
cmd = executils.NewCmd(dnfExe, "-y", "install", "nftables")
|
||||
}
|
||||
|
||||
// check apt
|
||||
if cmd == nil {
|
||||
aptExe, err := exec.LookPath("apt")
|
||||
if err == nil {
|
||||
cmd = executils.NewCmd(aptExe, "install", "nftables")
|
||||
}
|
||||
}
|
||||
|
||||
// check yum
|
||||
if cmd == nil {
|
||||
yumExe, err := exec.LookPath("yum")
|
||||
if err == nil {
|
||||
cmd = executils.NewCmd(yumExe, "-y", "install", "nftables")
|
||||
}
|
||||
}
|
||||
|
||||
if cmd == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd.WithTimeout(10 * time.Minute)
|
||||
cmd.WithStderr()
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
return errors.New(err.Error() + ": " + cmd.Stderr())
|
||||
}
|
||||
|
||||
remotelogs.Println("NFTABLES", "installed nftables with command '"+cmd.String()+"' successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables_test
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"time"
|
||||
@@ -13,9 +13,9 @@ import (
|
||||
|
||||
// FirewalldAction Firewalld动作管理
|
||||
// 常用命令:
|
||||
// - 查询列表: firewall-cmd --list-all
|
||||
// - 添加IP:firewall-cmd --add-rich-rule="rule family='ipv4' source address='192.168.2.32' reject" --timeout=30s
|
||||
// - 删除IP:firewall-cmd --remove-rich-rule="rule family='ipv4' source address='192.168.2.32' reject" --timeout=30s
|
||||
// - 查询列表: firewall-cmd --list-all
|
||||
// - 添加IP:firewall-cmd --add-rich-rule="rule family='ipv4' source address='192.168.2.32' reject" --timeout=30s
|
||||
// - 删除IP:firewall-cmd --remove-rich-rule="rule family='ipv4' source address='192.168.2.32' reject" --timeout=30s
|
||||
type FirewalldAction struct {
|
||||
BaseAction
|
||||
|
||||
@@ -144,12 +144,11 @@ func (this *FirewalldAction) runActionSingleIP(action string, listType IPListTyp
|
||||
// MAC OS直接返回
|
||||
return nil
|
||||
}
|
||||
cmd := exec.Command(path, args...)
|
||||
stderr := bytes.NewBuffer([]byte{})
|
||||
cmd.Stderr = stderr
|
||||
cmd := executils.NewTimeoutCmd(30*time.Second, path, args...)
|
||||
cmd.WithStderr()
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
return errors.New(err.Error() + ", output: " + string(stderr.Bytes()))
|
||||
return errors.New(err.Error() + ", output: " + cmd.Stderr())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
@@ -16,12 +16,12 @@ import (
|
||||
// IPSetAction IPSet动作
|
||||
// 相关命令:
|
||||
// - 利用Firewalld管理set:
|
||||
// - 添加:firewall-cmd --permanent --new-ipset=edge_ip_list --type=hash:ip --option="timeout=0"
|
||||
// - 删除:firewall-cmd --permanent --delete-ipset=edge_ip_list
|
||||
// - 重载:firewall-cmd --reload
|
||||
// - firewalld+ipset: firewall-cmd --permanent --add-rich-rule="rule source ipset='edge_ip_list' reject"
|
||||
// - 添加:firewall-cmd --permanent --new-ipset=edge_ip_list --type=hash:ip --option="timeout=0"
|
||||
// - 删除:firewall-cmd --permanent --delete-ipset=edge_ip_list
|
||||
// - 重载:firewall-cmd --reload
|
||||
// - firewalld+ipset: firewall-cmd --permanent --add-rich-rule="rule source ipset='edge_ip_list' reject"
|
||||
// - 利用IPTables管理set:
|
||||
// - 添加:iptables -A INPUT -m set --match-set edge_ip_list src -j REJECT
|
||||
// - 添加:iptables -A INPUT -m set --match-set edge_ip_list src -j REJECT
|
||||
// - 添加Item:ipset add edge_ip_list 192.168.2.32 timeout 30
|
||||
// - 删除Item: ipset del edge_ip_list 192.168.2.32
|
||||
// - 创建set:ipset create edge_ip_list hash:ip timeout 0
|
||||
@@ -30,16 +30,13 @@ import (
|
||||
type IPSetAction struct {
|
||||
BaseAction
|
||||
|
||||
config *firewallconfigs.FirewallActionIPSetConfig
|
||||
errorBuf *bytes.Buffer
|
||||
config *firewallconfigs.FirewallActionIPSetConfig
|
||||
|
||||
ipsetNotfound bool
|
||||
}
|
||||
|
||||
func NewIPSetAction() *IPSetAction {
|
||||
return &IPSetAction{
|
||||
errorBuf: &bytes.Buffer{},
|
||||
}
|
||||
return &IPSetAction{}
|
||||
}
|
||||
|
||||
func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) error {
|
||||
@@ -68,14 +65,13 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
|
||||
if len(listName) == 0 {
|
||||
continue
|
||||
}
|
||||
var cmd = exec.Command(path, "create", listName, "hash:ip", "timeout", "0", "maxelem", "1000000")
|
||||
var stderr = bytes.NewBuffer([]byte{})
|
||||
cmd.Stderr = stderr
|
||||
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "create", listName, "hash:ip", "timeout", "0", "maxelem", "1000000")
|
||||
cmd.WithStderr()
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
var output = stderr.Bytes()
|
||||
if !bytes.Contains(output, []byte("already exists")) {
|
||||
return errors.New("create ipset '" + listName + "': " + err.Error() + ", output: " + string(output))
|
||||
var output = cmd.Stderr()
|
||||
if !strings.Contains(output, "already exists") {
|
||||
return errors.New("create ipset '" + listName + "': " + err.Error() + ", output: " + output)
|
||||
} else {
|
||||
err = nil
|
||||
}
|
||||
@@ -87,14 +83,13 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
|
||||
if len(listName) == 0 {
|
||||
continue
|
||||
}
|
||||
var cmd = exec.Command(path, "create", listName, "hash:ip", "family", "inet6", "timeout", "0", "maxelem", "1000000")
|
||||
var stderr = bytes.NewBuffer([]byte{})
|
||||
cmd.Stderr = stderr
|
||||
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "create", listName, "hash:ip", "family", "inet6", "timeout", "0", "maxelem", "1000000")
|
||||
cmd.WithStderr()
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
var output = stderr.Bytes()
|
||||
if !bytes.Contains(output, []byte("already exists")) {
|
||||
return errors.New("create ipset '" + listName + "': " + err.Error() + ", output: " + string(output))
|
||||
var output = cmd.Stderr()
|
||||
if !strings.Contains(output, "already exists") {
|
||||
return errors.New("create ipset '" + listName + "': " + err.Error() + ", output: " + output)
|
||||
} else {
|
||||
err = nil
|
||||
}
|
||||
@@ -114,16 +109,15 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
|
||||
if len(listName) == 0 {
|
||||
continue
|
||||
}
|
||||
cmd := exec.Command(path, "--permanent", "--new-ipset="+listName, "--type=hash:ip", "--option=timeout=0", "--option=maxelem=1000000")
|
||||
stderr := bytes.NewBuffer([]byte{})
|
||||
cmd.Stderr = stderr
|
||||
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "--permanent", "--new-ipset="+listName, "--type=hash:ip", "--option=timeout=0", "--option=maxelem=1000000")
|
||||
cmd.WithStderr()
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
output := stderr.Bytes()
|
||||
if bytes.Contains(output, []byte("NAME_CONFLICT")) {
|
||||
var output = cmd.Stderr()
|
||||
if strings.Contains(output, "NAME_CONFLICT") {
|
||||
err = nil
|
||||
} else {
|
||||
return errors.New("firewall-cmd add ipset '" + listName + "': " + err.Error() + ", output: " + string(output))
|
||||
return errors.New("firewall-cmd add ipset '" + listName + "': " + err.Error() + ", output: " + output)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -133,16 +127,15 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
|
||||
if len(listName) == 0 {
|
||||
continue
|
||||
}
|
||||
cmd := exec.Command(path, "--permanent", "--new-ipset="+listName, "--type=hash:ip", "--option=family=inet6", "--option=timeout=0", "--option=maxelem=1000000")
|
||||
stderr := bytes.NewBuffer([]byte{})
|
||||
cmd.Stderr = stderr
|
||||
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "--permanent", "--new-ipset="+listName, "--type=hash:ip", "--option=family=inet6", "--option=timeout=0", "--option=maxelem=1000000")
|
||||
cmd.WithStderr()
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
var output = stderr.Bytes()
|
||||
if bytes.Contains(output, []byte("NAME_CONFLICT")) {
|
||||
var output = cmd.Stderr()
|
||||
if strings.Contains(output, "NAME_CONFLICT") {
|
||||
err = nil
|
||||
} else {
|
||||
return errors.New("firewall-cmd add ipset '" + listName + "': " + err.Error() + ", output: " + string(output))
|
||||
return errors.New("firewall-cmd add ipset '" + listName + "': " + err.Error() + ", output: " + output)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -152,13 +145,11 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
|
||||
if len(listName) == 0 {
|
||||
continue
|
||||
}
|
||||
var cmd = exec.Command(path, "--permanent", "--add-rich-rule=rule source ipset='"+listName+"' accept")
|
||||
var stderr = bytes.NewBuffer([]byte{})
|
||||
cmd.Stderr = stderr
|
||||
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "--permanent", "--add-rich-rule=rule source ipset='"+listName+"' accept")
|
||||
cmd.WithStderr()
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
var output = stderr.Bytes()
|
||||
return errors.New("firewall-cmd add rich rule '" + listName + "': " + err.Error() + ", output: " + string(output))
|
||||
return errors.New("firewall-cmd add rich rule '" + listName + "': " + err.Error() + ", output: " + cmd.Stderr())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,25 +158,21 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
|
||||
if len(listName) == 0 {
|
||||
continue
|
||||
}
|
||||
var cmd = exec.Command(path, "--permanent", "--add-rich-rule=rule source ipset='"+listName+"' reject")
|
||||
var stderr = bytes.NewBuffer([]byte{})
|
||||
cmd.Stderr = stderr
|
||||
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "--permanent", "--add-rich-rule=rule source ipset='"+listName+"' reject")
|
||||
cmd.WithStderr()
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
var output = stderr.Bytes()
|
||||
return errors.New("firewall-cmd add rich rule '" + listName + "': " + err.Error() + ", output: " + string(output))
|
||||
return errors.New("firewall-cmd add rich rule '" + listName + "': " + err.Error() + ", output: " + cmd.Stderr())
|
||||
}
|
||||
}
|
||||
|
||||
// reload
|
||||
{
|
||||
cmd := exec.Command(path, "--reload")
|
||||
stderr := bytes.NewBuffer([]byte{})
|
||||
cmd.Stderr = stderr
|
||||
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "--reload")
|
||||
cmd.WithStderr()
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
var output = stderr.Bytes()
|
||||
return errors.New("firewall-cmd reload: " + err.Error() + ", output: " + string(output))
|
||||
return errors.New("firewall-cmd reload: " + err.Error() + ", output: " + cmd.Stderr())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -204,19 +191,17 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
|
||||
}
|
||||
|
||||
// 检查规则是否存在
|
||||
var cmd = exec.Command(path, "-C", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "ACCEPT")
|
||||
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "-C", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "ACCEPT")
|
||||
err := cmd.Run()
|
||||
var exists = err == nil
|
||||
|
||||
// 添加规则
|
||||
if !exists {
|
||||
var cmd = exec.Command(path, "-A", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "ACCEPT")
|
||||
var stderr = bytes.NewBuffer([]byte{})
|
||||
cmd.Stderr = stderr
|
||||
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "-A", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "ACCEPT")
|
||||
cmd.WithStderr()
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
var output = stderr.Bytes()
|
||||
return errors.New("iptables add rule: " + err.Error() + ", output: " + string(output))
|
||||
return errors.New("iptables add rule: " + err.Error() + ", output: " + cmd.Stderr())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -228,18 +213,16 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
|
||||
}
|
||||
|
||||
// 检查规则是否存在
|
||||
var cmd = exec.Command(path, "-C", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "REJECT")
|
||||
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "-C", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "REJECT")
|
||||
err := cmd.Run()
|
||||
var exists = err == nil
|
||||
|
||||
if !exists {
|
||||
var cmd = exec.Command(path, "-A", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "REJECT")
|
||||
var stderr = bytes.NewBuffer([]byte{})
|
||||
cmd.Stderr = stderr
|
||||
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "-A", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "REJECT")
|
||||
cmd.WithStderr()
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
var output = stderr.Bytes()
|
||||
return errors.New("iptables add rule: " + err.Error() + ", output: " + string(output))
|
||||
return errors.New("iptables add rule: " + err.Error() + ", output: " + cmd.Stderr())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -361,12 +344,11 @@ func (this *IPSetAction) runActionSingleIP(action string, listType IPListType, i
|
||||
return nil
|
||||
}
|
||||
|
||||
this.errorBuf.Reset()
|
||||
var cmd = exec.Command(path, args...)
|
||||
cmd.Stderr = this.errorBuf
|
||||
var cmd = executils.NewTimeoutCmd(30*time.Second, path, args...)
|
||||
cmd.WithStderr()
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
var errString = this.errorBuf.String()
|
||||
var errString = cmd.Stderr()
|
||||
if action == "deleteItem" && strings.Contains(errString, "not added") {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,20 +1,23 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IPTablesAction IPTables动作
|
||||
// 相关命令:
|
||||
// iptables -A INPUT -s "192.168.2.32" -j ACCEPT
|
||||
// iptables -A INPUT -s "192.168.2.32" -j REJECT
|
||||
// iptables -D INPUT ...
|
||||
// iptables -F INPUT
|
||||
//
|
||||
// iptables -A INPUT -s "192.168.2.32" -j ACCEPT
|
||||
// iptables -A INPUT -s "192.168.2.32" -j REJECT
|
||||
// iptables -D INPUT ...
|
||||
// iptables -F INPUT
|
||||
type IPTablesAction struct {
|
||||
BaseAction
|
||||
|
||||
@@ -110,16 +113,15 @@ func (this *IPTablesAction) runActionSingleIP(action string, listType IPListType
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := exec.Command(path, args...)
|
||||
stderr := bytes.NewBuffer([]byte{})
|
||||
cmd.Stderr = stderr
|
||||
var cmd = executils.NewTimeoutCmd(30*time.Second, path, args...)
|
||||
cmd.WithStderr()
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
output := stderr.Bytes()
|
||||
if bytes.Contains(output, []byte("No chain/target/match")) {
|
||||
var output = cmd.Stderr()
|
||||
if strings.Contains(output, "No chain/target/match") {
|
||||
err = nil
|
||||
} else {
|
||||
return errors.New(err.Error() + ", output: " + string(output))
|
||||
return errors.New(err.Error() + ", output: " + output)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -68,7 +68,7 @@ func (this *ActionManager) UpdateActions(actions []*firewallconfigs.FirewallActi
|
||||
remotelogs.Error("IPLIBRARY/ACTION_MANAGER", "action "+strconv.FormatInt(newAction.Id, 10)+", type:"+newAction.Type+": "+err.Error())
|
||||
continue
|
||||
}
|
||||
if bytes.Compare(newConfigJSON, oldConfigJSON) != 0 {
|
||||
if !bytes.Equal(newConfigJSON, oldConfigJSON) {
|
||||
_ = oldInstance.Close()
|
||||
|
||||
// 重新创建
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"os/exec"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ScriptAction 脚本命令动作
|
||||
@@ -45,25 +45,24 @@ func (this *ScriptAction) DeleteItem(listType IPListType, item *pb.IPItem) error
|
||||
|
||||
func (this *ScriptAction) runAction(action string, listType IPListType, item *pb.IPItem) error {
|
||||
// TODO 智能支持 .sh 脚本文件
|
||||
cmd := exec.Command(this.config.Path)
|
||||
cmd.Env = []string{
|
||||
var cmd = executils.NewTimeoutCmd(30*time.Second, this.config.Path)
|
||||
cmd.WithEnv([]string{
|
||||
"ACTION=" + action,
|
||||
"TYPE=" + item.Type,
|
||||
"IP_FROM=" + item.IpFrom,
|
||||
"IP_TO=" + item.IpTo,
|
||||
"EXPIRED_AT=" + fmt.Sprintf("%d", item.ExpiredAt),
|
||||
"LIST_TYPE=" + listType,
|
||||
}
|
||||
})
|
||||
if len(this.config.Cwd) > 0 {
|
||||
cmd.Dir = this.config.Cwd
|
||||
cmd.WithDir(this.config.Cwd)
|
||||
} else {
|
||||
cmd.Dir = filepath.Dir(this.config.Path)
|
||||
cmd.WithDir(filepath.Dir(this.config.Path))
|
||||
}
|
||||
stderr := bytes.NewBuffer([]byte{})
|
||||
cmd.Stderr = stderr
|
||||
cmd.WithStderr()
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return errors.New(err.Error() + ", output: " + string(stderr.Bytes()))
|
||||
return errors.New(err.Error() + ", output: " + cmd.Stderr())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
// 源码改自:https://github.com/lionsoul2014/ip2region/blob/master/binding/golang/ip2region/ip2Region.go
|
||||
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
IndexBlockLength = 12
|
||||
)
|
||||
|
||||
var err error
|
||||
|
||||
type IP2Region struct {
|
||||
headerSip []int64
|
||||
headerPtr []int64
|
||||
headerLen int64
|
||||
|
||||
// super block index info
|
||||
firstIndexPtr int64
|
||||
lastIndexPtr int64
|
||||
totalBlocks int64
|
||||
|
||||
dbData []byte
|
||||
}
|
||||
|
||||
type IpInfo struct {
|
||||
CityId int64
|
||||
Country string
|
||||
Region string
|
||||
Province string
|
||||
City string
|
||||
ISP string
|
||||
}
|
||||
|
||||
func (ip IpInfo) String() string {
|
||||
return strconv.FormatInt(ip.CityId, 10) + "|" + ip.Country + "|" + ip.Region + "|" + ip.Province + "|" + ip.City + "|" + ip.ISP
|
||||
}
|
||||
|
||||
func getIpInfo(cityId int64, line []byte) *IpInfo {
|
||||
lineSlice := strings.Split(string(line), "|")
|
||||
ipInfo := &IpInfo{}
|
||||
length := len(lineSlice)
|
||||
ipInfo.CityId = cityId
|
||||
if length < 5 {
|
||||
for i := 0; i <= 5-length; i++ {
|
||||
lineSlice = append(lineSlice, "")
|
||||
}
|
||||
}
|
||||
|
||||
ipInfo.Country = lineSlice[0]
|
||||
ipInfo.Region = lineSlice[1]
|
||||
ipInfo.Province = lineSlice[2]
|
||||
ipInfo.City = lineSlice[3]
|
||||
ipInfo.ISP = lineSlice[4]
|
||||
return ipInfo
|
||||
}
|
||||
|
||||
func NewIP2Region(path string) (*IP2Region, error) {
|
||||
var region = &IP2Region{}
|
||||
region.dbData, err = os.ReadFile(path)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
region.firstIndexPtr = region.ipLongAtOffset(0)
|
||||
region.lastIndexPtr = region.ipLongAtOffset(4)
|
||||
region.totalBlocks = (region.lastIndexPtr-region.firstIndexPtr)/IndexBlockLength + 1
|
||||
return region, nil
|
||||
}
|
||||
|
||||
func (this *IP2Region) MemorySearch(ipStr string) (ipInfo *IpInfo, err error) {
|
||||
ip, err := ip2long(ipStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h := this.totalBlocks
|
||||
var dataPtr, l int64
|
||||
for l <= h {
|
||||
m := (l + h) >> 1
|
||||
p := this.firstIndexPtr + m*IndexBlockLength
|
||||
sip := this.ipLongAtOffset(p)
|
||||
if ip < sip {
|
||||
h = m - 1
|
||||
} else {
|
||||
eip := this.ipLongAtOffset(p + 4)
|
||||
if ip > eip {
|
||||
l = m + 1
|
||||
} else {
|
||||
dataPtr = this.ipLongAtOffset(p + 8)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if dataPtr == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
dataLen := (dataPtr >> 24) & 0xFF
|
||||
dataPtr = dataPtr & 0x00FFFFFF
|
||||
return getIpInfo(this.ipLongAtOffset(dataPtr), this.dbData[(dataPtr)+4:dataPtr+dataLen]), nil
|
||||
}
|
||||
|
||||
func (this *IP2Region) ipLongAtOffset(offset int64) int64 {
|
||||
return int64(this.dbData[offset]) |
|
||||
int64(this.dbData[offset+1])<<8 |
|
||||
int64(this.dbData[offset+2])<<16 |
|
||||
int64(this.dbData[offset+3])<<24
|
||||
}
|
||||
|
||||
func ip2long(IpStr string) (int64, error) {
|
||||
bits := strings.Split(IpStr, ".")
|
||||
if len(bits) != 4 {
|
||||
return 0, errors.New("ip format error")
|
||||
}
|
||||
|
||||
var sum int64
|
||||
for i, n := range bits {
|
||||
bit, _ := strconv.ParseInt(n, 10, 64)
|
||||
sum += bit << uint(24-8*i)
|
||||
}
|
||||
|
||||
return sum, nil
|
||||
}
|
||||
@@ -18,7 +18,8 @@ import (
|
||||
type IPListDB struct {
|
||||
db *sql.DB
|
||||
|
||||
itemTableName string
|
||||
itemTableName string
|
||||
|
||||
deleteExpiredItemsStmt *sql.Stmt
|
||||
deleteItemStmt *sql.Stmt
|
||||
insertItemStmt *sql.Stmt
|
||||
@@ -53,7 +54,9 @@ func (this *IPListDB) init() error {
|
||||
remotelogs.Println("IP_LIST_DB", "create data dir '"+this.dir+"'")
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", "file:"+this.dir+"/ip_list.db?cache=shared&mode=rwc&_journal_mode=WAL&_sync=OFF")
|
||||
var path = this.dir + "/ip_list.db"
|
||||
|
||||
db, err := sql.Open("sqlite3", "file:"+path+"?cache=shared&mode=rwc&_journal_mode=WAL&_sync=OFF")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -66,6 +69,14 @@ func (this *IPListDB) init() error {
|
||||
|
||||
this.db = db
|
||||
|
||||
// 恢复数据库
|
||||
var recoverEnv, _ = os.LookupEnv("EdgeRecover")
|
||||
if len(recoverEnv) > 0 {
|
||||
for _, indexName := range []string{"ip_list_itemId", "ip_list_expiredAt"} {
|
||||
_, _ = db.Exec(`REINDEX "` + indexName + `"`)
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化数据库
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemTableName + `" (
|
||||
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
@@ -159,6 +170,12 @@ func (this *IPListDB) AddItem(item *pb.IPItem) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果是删除,则不再创建新记录
|
||||
if item.IsDeleted {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = this.insertItemStmt.Exec(item.ListId, item.ListType, item.IsGlobal, item.Type, item.Id, item.IpFrom, item.IpTo, item.ExpiredAt, item.EventLevel, item.IsDeleted, item.Version, item.NodeId, item.ServerId)
|
||||
return err
|
||||
}
|
||||
@@ -194,12 +211,12 @@ func (this *IPListDB) ReadMaxVersion() int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
row := this.selectMaxVersionStmt.QueryRow()
|
||||
var row = this.selectMaxVersionStmt.QueryRow()
|
||||
if row == nil {
|
||||
return 0
|
||||
}
|
||||
var version int64
|
||||
err = row.Scan(&version)
|
||||
err := row.Scan(&version)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ func TestIPListDB_AddItem(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.AddItem(&pb.IPItem{
|
||||
Id: 1,
|
||||
IpFrom: "192.168.1.101",
|
||||
@@ -45,6 +46,12 @@ func TestIPListDB_AddItem(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
package iplibrary
|
||||
|
||||
type LibraryInterface interface {
|
||||
// Load 加载数据库文件
|
||||
Load(dbPath string) error
|
||||
|
||||
// Lookup 查询IP
|
||||
// 返回结果有可能为空
|
||||
Lookup(ip string) (*Result, error)
|
||||
|
||||
// Close 关闭数据库文件
|
||||
Close()
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type IP2RegionLibrary struct {
|
||||
db *IP2Region
|
||||
}
|
||||
|
||||
func (this *IP2RegionLibrary) Load(dbPath string) error {
|
||||
db, err := NewIP2Region(dbPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.db = db
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *IP2RegionLibrary) Lookup(ip string) (*Result, error) {
|
||||
// 暂不支持IPv6
|
||||
if strings.Contains(ip, ":") {
|
||||
return nil, nil
|
||||
}
|
||||
if net.ParseIP(ip) == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if this.db == nil {
|
||||
return nil, errors.New("library has not been loaded")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// 防止panic发生
|
||||
err := recover()
|
||||
if err != nil {
|
||||
remotelogs.Error("IP2RegionLibrary", "panic: "+fmt.Sprintf("%#v", err))
|
||||
}
|
||||
}()
|
||||
|
||||
info, err := this.db.MemorySearch(ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if info == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if info.Country == "0" {
|
||||
info.Country = ""
|
||||
}
|
||||
if info.Region == "0" {
|
||||
info.Region = ""
|
||||
}
|
||||
if info.Province == "0" {
|
||||
info.Province = ""
|
||||
}
|
||||
if info.City == "0" {
|
||||
info.City = ""
|
||||
}
|
||||
if info.ISP == "0" {
|
||||
info.ISP = ""
|
||||
}
|
||||
|
||||
return &Result{
|
||||
CityId: info.CityId,
|
||||
Country: info.Country,
|
||||
Region: info.Region,
|
||||
Province: info.Province,
|
||||
City: info.City,
|
||||
ISP: info.ISP,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (this *IP2RegionLibrary) Close() {
|
||||
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIP2RegionLibrary_Lookup_MemoryUsage(t *testing.T) {
|
||||
var mem = &runtime.MemStats{}
|
||||
runtime.ReadMemStats(mem)
|
||||
|
||||
library := &IP2RegionLibrary{}
|
||||
err := library.Load(Tea.Root + "/resources/ipdata/ip2region/ip2region.db")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var mem2 = &runtime.MemStats{}
|
||||
runtime.ReadMemStats(mem2)
|
||||
t.Log((mem2.HeapInuse-mem.HeapInuse)/1024/1024, "MB")
|
||||
}
|
||||
|
||||
func TestIP2RegionLibrary_Lookup_Single(t *testing.T) {
|
||||
library := &IP2RegionLibrary{}
|
||||
err := library.Load(Tea.Root + "/resources/ipdata/ip2region/ip2region.db")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, ip := range []string{"8.8.9.9"} {
|
||||
result, err := library.Lookup(ip)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("IP:", ip, "result:", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIP2RegionLibrary_Lookup(t *testing.T) {
|
||||
library := &IP2RegionLibrary{}
|
||||
err := library.Load(Tea.Root + "/resources/ipdata/ip2region/ip2region.db")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, ip := range []string{"", "a", "1.1.1", "192.168.1.100", "114.240.223.47", "8.8.9.9", "::1"} {
|
||||
result, err := library.Lookup(ip)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("IP:", ip, "result:", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIP2RegionLibrary_Lookup_Concurrent(t *testing.T) {
|
||||
library := &IP2RegionLibrary{}
|
||||
err := library.Load(Tea.Root + "/resources/ipdata/ip2region/ip2region.db")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var count = 4000
|
||||
var wg = sync.WaitGroup{}
|
||||
wg.Add(count)
|
||||
for i := 0; i < count; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
_, _ = library.Lookup(strconv.Itoa(rands.Int(0, 254)) + "." + strconv.Itoa(rands.Int(0, 254)) + "." + strconv.Itoa(rands.Int(0, 254)) + "." + strconv.Itoa(rands.Int(0, 254)))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestIP2RegionLibrary_Memory(t *testing.T) {
|
||||
library := &IP2RegionLibrary{}
|
||||
err := library.Load(Tea.Root + "/resources/ipdata/ip2region/ip2region.db")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
before := time.Now()
|
||||
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
_, _ = library.Lookup(strconv.Itoa(rands.Int(0, 254)) + "." + strconv.Itoa(rands.Int(0, 254)) + "." + strconv.Itoa(rands.Int(0, 254)) + "." + strconv.Itoa(rands.Int(0, 254)))
|
||||
}
|
||||
|
||||
t.Log("cost:", time.Since(before).Seconds()*1000, "ms")
|
||||
}
|
||||
|
||||
func BenchmarkIP2RegionLibrary_Lookup(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
var library = &IP2RegionLibrary{}
|
||||
err := library.Load(Tea.Root + "/resources/ipdata/ip2region/ip2region.db")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = library.Lookup("8.8.8.8")
|
||||
}
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/files"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var SharedManager = NewManager()
|
||||
var SharedLibrary LibraryInterface
|
||||
|
||||
func init() {
|
||||
events.On(events.EventLoaded, func() {
|
||||
// 初始化
|
||||
library, err := SharedManager.Load()
|
||||
if err != nil {
|
||||
remotelogs.ErrorObject("IP_LIBRARY", err)
|
||||
return
|
||||
}
|
||||
SharedLibrary = library
|
||||
})
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
code string
|
||||
}
|
||||
|
||||
func NewManager() *Manager {
|
||||
return &Manager{}
|
||||
}
|
||||
|
||||
func (this *Manager) Load() (LibraryInterface, error) {
|
||||
nodeConfig, err := nodeconfigs.SharedNodeConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config := nodeConfig.GlobalConfig
|
||||
if config == nil {
|
||||
config = &serverconfigs.GlobalConfig{}
|
||||
}
|
||||
|
||||
// 当前正在使用的IP库代号
|
||||
code := config.IPLibrary.Code
|
||||
if len(code) == 0 {
|
||||
code = serverconfigs.DefaultIPLibraryType
|
||||
}
|
||||
|
||||
dir := Tea.Root + "/resources/ipdata/" + code
|
||||
var lastVersion int64 = -1
|
||||
lastFilename := ""
|
||||
for _, file := range files.NewFile(dir).List() {
|
||||
filename := file.Name()
|
||||
|
||||
reg := regexp.MustCompile(`^` + regexp.QuoteMeta(code) + `.(\d+)\.`)
|
||||
if reg.MatchString(filename) { // 先查找有版本号的
|
||||
result := reg.FindStringSubmatch(filename)
|
||||
version := types.Int64(result[1])
|
||||
if version > lastVersion {
|
||||
lastVersion = version
|
||||
lastFilename = filename
|
||||
}
|
||||
} else if strings.HasPrefix(filename, code+".") { // 后查找默认的
|
||||
if lastVersion == -1 {
|
||||
lastFilename = filename
|
||||
lastVersion = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(lastFilename) == 0 {
|
||||
return nil, errors.New("ip library file not found")
|
||||
}
|
||||
|
||||
var libraryPtr LibraryInterface
|
||||
switch code {
|
||||
case serverconfigs.IPLibraryTypeIP2Region:
|
||||
libraryPtr = &IP2RegionLibrary{}
|
||||
default:
|
||||
return nil, errors.New("invalid ip library code '" + code + "'")
|
||||
}
|
||||
|
||||
err = libraryPtr.Load(dir + "/" + lastFilename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return libraryPtr, nil
|
||||
}
|
||||
@@ -1,155 +0,0 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var SharedCityManager = NewCityManager()
|
||||
|
||||
func init() {
|
||||
events.On(events.EventLoaded, func() {
|
||||
goman.New(func() {
|
||||
SharedCityManager.Start()
|
||||
})
|
||||
})
|
||||
events.On(events.EventQuit, func() {
|
||||
SharedCityManager.Stop()
|
||||
})
|
||||
}
|
||||
|
||||
// CityManager 中国省份信息管理
|
||||
type CityManager struct {
|
||||
ticker *time.Ticker
|
||||
|
||||
cacheFile string
|
||||
|
||||
cityMap map[string]int64 // provinceName_cityName => cityName
|
||||
dataHash string // 国家JSON的md5
|
||||
|
||||
locker sync.RWMutex
|
||||
|
||||
isUpdated bool
|
||||
}
|
||||
|
||||
func NewCityManager() *CityManager {
|
||||
return &CityManager{
|
||||
cacheFile: Tea.Root + "/configs/region_city.json.cache",
|
||||
cityMap: map[string]int64{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *CityManager) Start() {
|
||||
// 从缓存中读取
|
||||
err := this.load()
|
||||
if err != nil {
|
||||
remotelogs.ErrorObject("CITY_MANAGER", err)
|
||||
}
|
||||
|
||||
// 第一次更新
|
||||
err = this.loop()
|
||||
if err != nil {
|
||||
remotelogs.ErrorObject("City_MANAGER", err)
|
||||
}
|
||||
|
||||
// 定时更新
|
||||
this.ticker = time.NewTicker(4 * time.Hour)
|
||||
for range this.ticker.C {
|
||||
err := this.loop()
|
||||
if err != nil {
|
||||
remotelogs.ErrorObject("CITY_MANAGER", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *CityManager) Stop() {
|
||||
if this.ticker != nil {
|
||||
this.ticker.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (this *CityManager) Lookup(provinceId int64, cityName string) (cityId int64) {
|
||||
this.locker.RLock()
|
||||
cityId, _ = this.cityMap[types.String(provinceId)+"_"+cityName]
|
||||
this.locker.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
// 从缓存中读取
|
||||
func (this *CityManager) load() error {
|
||||
data, err := os.ReadFile(this.cacheFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
m := map[string]int64{}
|
||||
err = json.Unmarshal(data, &m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if m != nil && len(m) > 0 {
|
||||
this.cityMap = m
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 更新城市信息
|
||||
func (this *CityManager) loop() error {
|
||||
if this.isUpdated {
|
||||
return nil
|
||||
}
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := rpcClient.RegionCityRPC().FindAllRegionCities(rpcClient.Context(), &pb.FindAllRegionCitiesRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m := map[string]int64{}
|
||||
for _, city := range resp.RegionCities {
|
||||
for _, code := range city.Codes {
|
||||
m[types.String(city.RegionProvinceId)+"_"+code] = city.Id
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否有更新
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hash := md5.New()
|
||||
hash.Write(data)
|
||||
dataHash := fmt.Sprintf("%x", hash.Sum(nil))
|
||||
if this.dataHash == dataHash {
|
||||
return nil
|
||||
}
|
||||
this.dataHash = dataHash
|
||||
|
||||
this.locker.Lock()
|
||||
this.cityMap = m
|
||||
this.isUpdated = true
|
||||
this.locker.Unlock()
|
||||
|
||||
// 保存到本地缓存
|
||||
|
||||
err = os.WriteFile(this.cacheFile, data, 0666)
|
||||
return err
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package iplibrary
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNewCityManager(t *testing.T) {
|
||||
var manager = NewCityManager()
|
||||
err := manager.loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(manager.Lookup(16, "许昌市"))
|
||||
}
|
||||
@@ -1,153 +0,0 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var SharedCountryManager = NewCountryManager()
|
||||
|
||||
func init() {
|
||||
events.On(events.EventLoaded, func() {
|
||||
goman.New(func() {
|
||||
SharedCountryManager.Start()
|
||||
})
|
||||
})
|
||||
events.On(events.EventQuit, func() {
|
||||
SharedCountryManager.Stop()
|
||||
})
|
||||
}
|
||||
|
||||
// CountryManager 国家/地区信息管理
|
||||
type CountryManager struct {
|
||||
ticker *time.Ticker
|
||||
|
||||
cacheFile string
|
||||
|
||||
countryMap map[string]int64 // countryName => countryId
|
||||
dataHash string // 国家JSON的md5
|
||||
|
||||
locker sync.RWMutex
|
||||
|
||||
isUpdated bool
|
||||
}
|
||||
|
||||
func NewCountryManager() *CountryManager {
|
||||
return &CountryManager{
|
||||
cacheFile: Tea.Root + "/configs/region_country.json.cache",
|
||||
countryMap: map[string]int64{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *CountryManager) Start() {
|
||||
// 从缓存中读取
|
||||
err := this.load()
|
||||
if err != nil {
|
||||
remotelogs.ErrorObject("COUNTRY_MANAGER", err)
|
||||
}
|
||||
|
||||
// 第一次更新
|
||||
err = this.loop()
|
||||
if err != nil {
|
||||
remotelogs.ErrorObject("COUNTRY_MANAGER", err)
|
||||
}
|
||||
|
||||
// 定时更新
|
||||
this.ticker = time.NewTicker(4 * time.Hour)
|
||||
for range this.ticker.C {
|
||||
err := this.loop()
|
||||
if err != nil {
|
||||
remotelogs.ErrorObject("COUNTRY_MANAGER", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *CountryManager) Stop() {
|
||||
if this.ticker != nil {
|
||||
this.ticker.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (this *CountryManager) Lookup(countryName string) (countryId int64) {
|
||||
this.locker.RLock()
|
||||
countryId, _ = this.countryMap[countryName]
|
||||
this.locker.RUnlock()
|
||||
return countryId
|
||||
}
|
||||
|
||||
// 从缓存中读取
|
||||
func (this *CountryManager) load() error {
|
||||
data, err := os.ReadFile(this.cacheFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
m := map[string]int64{}
|
||||
err = json.Unmarshal(data, &m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if m != nil && len(m) > 0 {
|
||||
this.countryMap = m
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 更新国家信息
|
||||
func (this *CountryManager) loop() error {
|
||||
if this.isUpdated {
|
||||
return nil
|
||||
}
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := rpcClient.RegionCountryRPC().FindAllRegionCountries(rpcClient.Context(), &pb.FindAllRegionCountriesRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m := map[string]int64{}
|
||||
for _, country := range resp.RegionCountries {
|
||||
for _, code := range country.Codes {
|
||||
m[code] = country.Id
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否有更新
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hash := md5.New()
|
||||
hash.Write(data)
|
||||
dataHash := fmt.Sprintf("%x", hash.Sum(nil))
|
||||
if this.dataHash == dataHash {
|
||||
return nil
|
||||
}
|
||||
this.dataHash = dataHash
|
||||
|
||||
this.locker.Lock()
|
||||
this.countryMap = m
|
||||
this.isUpdated = true
|
||||
this.locker.Unlock()
|
||||
|
||||
// 保存到本地缓存
|
||||
err = os.WriteFile(this.cacheFile, data, 0666)
|
||||
return err
|
||||
}
|
||||
@@ -1,57 +0,0 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCountryManager_load(t *testing.T) {
|
||||
manager := NewCountryManager()
|
||||
err := manager.load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok", manager.countryMap)
|
||||
}
|
||||
|
||||
func TestCountryManager_loop(t *testing.T) {
|
||||
manager := NewCountryManager()
|
||||
err := manager.loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok", manager.countryMap)
|
||||
}
|
||||
|
||||
func TestCountryManager_loop_skip(t *testing.T) {
|
||||
manager := NewCountryManager()
|
||||
for i := 0; i < 10; i++ {
|
||||
err := manager.loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountryManager_Lookup(t *testing.T) {
|
||||
manager := NewCountryManager()
|
||||
err := manager.load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(manager.Lookup("中国"), manager.Lookup("美国 "))
|
||||
}
|
||||
|
||||
func BenchmarkCountryManager_Lookup(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
manager := NewCountryManager()
|
||||
err := manager.load()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = manager.Lookup("中国")
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package iplibrary
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
@@ -18,6 +19,10 @@ var SharedIPListManager = NewIPListManager()
|
||||
var IPListUpdateNotify = make(chan bool, 1)
|
||||
|
||||
func init() {
|
||||
if teaconst.IsDaemon {
|
||||
return
|
||||
}
|
||||
|
||||
events.On(events.EventLoaded, func() {
|
||||
goman.New(func() {
|
||||
SharedIPListManager.Start()
|
||||
@@ -26,6 +31,13 @@ func init() {
|
||||
events.On(events.EventQuit, func() {
|
||||
SharedIPListManager.Stop()
|
||||
})
|
||||
|
||||
var ticker = time.NewTicker(24 * time.Hour)
|
||||
goman.New(func() {
|
||||
for range ticker.C {
|
||||
SharedIPListManager.DeleteExpiredItems()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// IPListManager IP名单管理
|
||||
@@ -43,7 +55,7 @@ type IPListManager struct {
|
||||
|
||||
func NewIPListManager() *IPListManager {
|
||||
return &IPListManager{
|
||||
pageSize: 500,
|
||||
pageSize: 1000,
|
||||
listMap: map[int64]*IPList{},
|
||||
}
|
||||
}
|
||||
@@ -111,15 +123,19 @@ func (this *IPListManager) init() {
|
||||
var size int64 = 1000
|
||||
for {
|
||||
items, err := db.ReadItems(offset, size)
|
||||
var l = len(items)
|
||||
if err != nil {
|
||||
remotelogs.Error("IP_LIST_MANAGER", "read ip list from local database failed: "+err.Error())
|
||||
} else {
|
||||
if len(items) == 0 {
|
||||
if l == 0 {
|
||||
break
|
||||
}
|
||||
this.processItems(items, false)
|
||||
if int64(l) < size {
|
||||
break
|
||||
}
|
||||
}
|
||||
offset += int64(len(items))
|
||||
offset += int64(l)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -144,7 +160,7 @@ func (this *IPListManager) fetch() (hasNext bool, err error) {
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
itemsResp, err := rpcClient.IPItemRPC().ListIPItemsAfterVersion(rpcClient.Context(), &pb.ListIPItemsAfterVersionRequest{
|
||||
itemsResp, err := rpcClient.IPItemRPC.ListIPItemsAfterVersion(rpcClient.Context(), &pb.ListIPItemsAfterVersionRequest{
|
||||
Version: this.version,
|
||||
Size: this.pageSize,
|
||||
})
|
||||
@@ -155,7 +171,7 @@ func (this *IPListManager) fetch() (hasNext bool, err error) {
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
items := itemsResp.IpItems
|
||||
var items = itemsResp.IpItems
|
||||
if len(items) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
@@ -182,8 +198,13 @@ func (this *IPListManager) FindList(listId int64) *IPList {
|
||||
return list
|
||||
}
|
||||
|
||||
func (this *IPListManager) DeleteExpiredItems() {
|
||||
if this.db != nil {
|
||||
_ = this.db.DeleteExpiredItems()
|
||||
}
|
||||
}
|
||||
|
||||
func (this *IPListManager) processItems(items []*pb.IPItem, fromRemote bool) {
|
||||
this.locker.Lock()
|
||||
var changedLists = map[*IPList]zero.Zero{}
|
||||
for _, item := range items {
|
||||
var list *IPList
|
||||
@@ -203,11 +224,15 @@ func (this *IPListManager) processItems(items []*pb.IPItem, fromRemote bool) {
|
||||
list = GlobalWhiteIPList
|
||||
}
|
||||
} else { // 其他List
|
||||
this.locker.Lock()
|
||||
list = this.listMap[item.ListId]
|
||||
this.locker.Unlock()
|
||||
}
|
||||
if list == nil {
|
||||
list = NewIPList()
|
||||
this.locker.Lock()
|
||||
this.listMap[item.ListId] = list
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
changedLists[list] = zero.New()
|
||||
@@ -246,8 +271,6 @@ func (this *IPListManager) processItems(items []*pb.IPItem, fromRemote bool) {
|
||||
changedList.Sort()
|
||||
}
|
||||
|
||||
this.locker.Unlock()
|
||||
|
||||
if fromRemote {
|
||||
var latestVersion = items[len(items)-1].Version
|
||||
if latestVersion > this.version {
|
||||
|
||||
@@ -1,154 +0,0 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var SharedProviderManager = NewProviderManager()
|
||||
|
||||
func init() {
|
||||
events.On(events.EventLoaded, func() {
|
||||
goman.New(func() {
|
||||
SharedProviderManager.Start()
|
||||
})
|
||||
})
|
||||
events.On(events.EventQuit, func() {
|
||||
SharedProviderManager.Stop()
|
||||
})
|
||||
}
|
||||
|
||||
// ProviderManager 中国省份信息管理
|
||||
type ProviderManager struct {
|
||||
ticker *time.Ticker
|
||||
|
||||
cacheFile string
|
||||
|
||||
providerMap map[string]int64 // name => id
|
||||
dataHash string // 国家JSON的md5
|
||||
|
||||
locker sync.RWMutex
|
||||
|
||||
isUpdated bool
|
||||
}
|
||||
|
||||
func NewProviderManager() *ProviderManager {
|
||||
return &ProviderManager{
|
||||
cacheFile: Tea.Root + "/configs/region_provider.json.cache",
|
||||
providerMap: map[string]int64{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *ProviderManager) Start() {
|
||||
// 从缓存中读取
|
||||
err := this.load()
|
||||
if err != nil {
|
||||
remotelogs.ErrorObject("PROVIDER_MANAGER", err)
|
||||
}
|
||||
|
||||
// 第一次更新
|
||||
err = this.loop()
|
||||
if err != nil {
|
||||
remotelogs.ErrorObject("PROVIDER_MANAGER", err)
|
||||
}
|
||||
|
||||
// 定时更新
|
||||
this.ticker = time.NewTicker(4 * time.Hour)
|
||||
for range this.ticker.C {
|
||||
err := this.loop()
|
||||
if err != nil {
|
||||
remotelogs.ErrorObject("PROVIDER_MANAGER", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *ProviderManager) Stop() {
|
||||
if this.ticker != nil {
|
||||
this.ticker.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (this *ProviderManager) Lookup(providerName string) (providerId int64) {
|
||||
this.locker.RLock()
|
||||
providerId, _ = this.providerMap[providerName]
|
||||
this.locker.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
// 从缓存中读取
|
||||
func (this *ProviderManager) load() error {
|
||||
data, err := os.ReadFile(this.cacheFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
m := map[string]int64{}
|
||||
err = json.Unmarshal(data, &m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if m != nil && len(m) > 0 {
|
||||
this.providerMap = m
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 更新服务商信息
|
||||
func (this *ProviderManager) loop() error {
|
||||
if this.isUpdated {
|
||||
return nil
|
||||
}
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := rpcClient.RegionProviderRPC().FindAllRegionProviders(rpcClient.Context(), &pb.FindAllRegionProvidersRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m := map[string]int64{}
|
||||
for _, provider := range resp.RegionProviders {
|
||||
for _, code := range provider.Codes {
|
||||
m[code] = provider.Id
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否有更新
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hash := md5.New()
|
||||
hash.Write(data)
|
||||
dataHash := fmt.Sprintf("%x", hash.Sum(nil))
|
||||
if this.dataHash == dataHash {
|
||||
return nil
|
||||
}
|
||||
this.dataHash = dataHash
|
||||
|
||||
this.locker.Lock()
|
||||
this.providerMap = m
|
||||
this.isUpdated = true
|
||||
this.locker.Unlock()
|
||||
|
||||
// 保存到本地缓存
|
||||
|
||||
err = os.WriteFile(this.cacheFile, data, 0666)
|
||||
return err
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package iplibrary
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNewProviderManager(t *testing.T) {
|
||||
var manager = NewProviderManager()
|
||||
err := manager.loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(manager.Lookup("阿里云"))
|
||||
t.Log(manager.Lookup("阿里云2"))
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
ChinaCountryId int64 = 1
|
||||
)
|
||||
|
||||
var SharedProvinceManager = NewProvinceManager()
|
||||
|
||||
func init() {
|
||||
events.On(events.EventLoaded, func() {
|
||||
goman.New(func() {
|
||||
SharedProvinceManager.Start()
|
||||
})
|
||||
})
|
||||
events.On(events.EventQuit, func() {
|
||||
SharedProvinceManager.Stop()
|
||||
})
|
||||
}
|
||||
|
||||
// ProvinceManager 中国省份信息管理
|
||||
type ProvinceManager struct {
|
||||
ticker *time.Ticker
|
||||
|
||||
cacheFile string
|
||||
|
||||
provinceMap map[string]int64 // provinceName => provinceId
|
||||
dataHash string // 国家JSON的md5
|
||||
|
||||
locker sync.RWMutex
|
||||
|
||||
isUpdated bool
|
||||
}
|
||||
|
||||
func NewProvinceManager() *ProvinceManager {
|
||||
return &ProvinceManager{
|
||||
cacheFile: Tea.Root + "/configs/region_province.json.cache",
|
||||
provinceMap: map[string]int64{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *ProvinceManager) Start() {
|
||||
// 从缓存中读取
|
||||
err := this.load()
|
||||
if err != nil {
|
||||
remotelogs.ErrorObject("PROVINCE_MANAGER", err)
|
||||
}
|
||||
|
||||
// 第一次更新
|
||||
err = this.loop()
|
||||
if err != nil {
|
||||
remotelogs.ErrorObject("PROVINCE_MANAGER", err)
|
||||
}
|
||||
|
||||
// 定时更新
|
||||
this.ticker = time.NewTicker(4 * time.Hour)
|
||||
for range this.ticker.C {
|
||||
err := this.loop()
|
||||
if err != nil {
|
||||
remotelogs.ErrorObject("PROVINCE_MANAGER", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *ProvinceManager) Stop() {
|
||||
if this.ticker != nil {
|
||||
this.ticker.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (this *ProvinceManager) Lookup(provinceName string) (provinceId int64) {
|
||||
this.locker.RLock()
|
||||
provinceId, _ = this.provinceMap[provinceName]
|
||||
this.locker.RUnlock()
|
||||
return provinceId
|
||||
}
|
||||
|
||||
// 从缓存中读取
|
||||
func (this *ProvinceManager) load() error {
|
||||
data, err := os.ReadFile(this.cacheFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
m := map[string]int64{}
|
||||
err = json.Unmarshal(data, &m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if m != nil && len(m) > 0 {
|
||||
this.provinceMap = m
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 更新省份信息
|
||||
func (this *ProvinceManager) loop() error {
|
||||
if this.isUpdated {
|
||||
return nil
|
||||
}
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := rpcClient.RegionProvinceRPC().FindAllRegionProvincesWithRegionCountryId(rpcClient.Context(), &pb.FindAllRegionProvincesWithRegionCountryIdRequest{
|
||||
RegionCountryId: ChinaCountryId,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m := map[string]int64{}
|
||||
for _, province := range resp.RegionProvinces {
|
||||
for _, code := range province.Codes {
|
||||
m[code] = province.Id
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否有更新
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hash := md5.New()
|
||||
hash.Write(data)
|
||||
dataHash := fmt.Sprintf("%x", hash.Sum(nil))
|
||||
if this.dataHash == dataHash {
|
||||
return nil
|
||||
}
|
||||
this.dataHash = dataHash
|
||||
|
||||
this.locker.Lock()
|
||||
this.provinceMap = m
|
||||
this.isUpdated = true
|
||||
this.locker.Unlock()
|
||||
|
||||
// 保存到本地缓存
|
||||
|
||||
err = os.WriteFile(this.cacheFile, data, 0666)
|
||||
return err
|
||||
}
|
||||
@@ -1,57 +0,0 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProvinceManager_load(t *testing.T) {
|
||||
manager := NewProvinceManager()
|
||||
err := manager.load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok", manager.provinceMap)
|
||||
}
|
||||
|
||||
func TestProvinceManager_loop(t *testing.T) {
|
||||
manager := NewProvinceManager()
|
||||
err := manager.loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok", manager.provinceMap)
|
||||
}
|
||||
|
||||
func TestProvinceManager_loop_skip(t *testing.T) {
|
||||
manager := NewProvinceManager()
|
||||
for i := 0; i < 10; i++ {
|
||||
err := manager.loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvinceManager_Lookup(t *testing.T) {
|
||||
manager := NewProvinceManager()
|
||||
err := manager.load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(manager.Lookup("安徽省"), manager.Lookup("北京市"))
|
||||
}
|
||||
|
||||
func BenchmarkProvinceManager_Lookup(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
manager := NewProvinceManager()
|
||||
err := manager.load()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = manager.Lookup("安徽省")
|
||||
}
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"github.com/iwind/TeaGo/dbs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestManager_Load(t *testing.T) {
|
||||
dbs.NotifyReady()
|
||||
|
||||
manager := NewManager()
|
||||
lib, err := manager.Load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(lib.Lookup("1.2.3.4"))
|
||||
t.Log(lib.Lookup("2.3.4.5"))
|
||||
t.Log(lib.Lookup("200.200.200.200"))
|
||||
t.Log(lib.Lookup("202.106.0.20"))
|
||||
}
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
dbs.NotifyReady()
|
||||
t.Log(SharedLibrary)
|
||||
}
|
||||
@@ -1,153 +0,0 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
var SharedUpdater = NewUpdater()
|
||||
|
||||
func init() {
|
||||
events.On(events.EventStart, func() {
|
||||
goman.New(func() {
|
||||
SharedUpdater.Start()
|
||||
})
|
||||
})
|
||||
events.On(events.EventQuit, func() {
|
||||
SharedUpdater.Stop()
|
||||
})
|
||||
}
|
||||
|
||||
// Updater IP库更新程序
|
||||
type Updater struct {
|
||||
ticker *time.Ticker
|
||||
}
|
||||
|
||||
// NewUpdater 获取新对象
|
||||
func NewUpdater() *Updater {
|
||||
return &Updater{}
|
||||
}
|
||||
|
||||
// Start 开始更新
|
||||
func (this *Updater) Start() {
|
||||
// 这里不需要太频繁检查更新,因为通常不需要更新IP库
|
||||
this.ticker = time.NewTicker(1 * time.Hour)
|
||||
for range this.ticker.C {
|
||||
err := this.loop()
|
||||
if err != nil {
|
||||
remotelogs.ErrorObject("IP_LIBRARY", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Updater) Stop() {
|
||||
if this.ticker != nil {
|
||||
this.ticker.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// 单次任务
|
||||
func (this *Updater) loop() error {
|
||||
nodeConfig, err := nodeconfigs.SharedNodeConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if nodeConfig.GlobalConfig == nil {
|
||||
return nil
|
||||
}
|
||||
code := nodeConfig.GlobalConfig.IPLibrary.Code
|
||||
if len(code) == 0 {
|
||||
code = serverconfigs.DefaultIPLibraryType
|
||||
}
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
libraryResp, err := rpcClient.IPLibraryRPC().FindLatestIPLibraryWithType(rpcClient.Context(), &pb.FindLatestIPLibraryWithTypeRequest{Type: code})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lib := libraryResp.IpLibrary
|
||||
if lib == nil || lib.File == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
typeInfo := serverconfigs.FindIPLibraryWithType(code)
|
||||
if typeInfo == nil {
|
||||
return errors.New("invalid ip library code '" + code + "'")
|
||||
}
|
||||
|
||||
path := Tea.Root + "/resources/ipdata/" + code + "/" + code + "." + fmt.Sprintf("%d", lib.CreatedAt) + typeInfo.GetString("ext")
|
||||
|
||||
// 是否已经存在
|
||||
_, err = os.Stat(path)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 开始下载
|
||||
fileChunkIdsResp, err := rpcClient.FileChunkRPC().FindAllFileChunkIds(rpcClient.Context(), &pb.FindAllFileChunkIdsRequest{FileId: lib.File.Id})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
chunkIds := fileChunkIdsResp.FileChunkIds
|
||||
if len(chunkIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
isOk := false
|
||||
|
||||
fp, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0666)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// 如果保存不成功就直接删除
|
||||
if !isOk {
|
||||
_ = fp.Close()
|
||||
_ = os.Remove(path)
|
||||
}
|
||||
}()
|
||||
for _, chunkId := range chunkIds {
|
||||
chunkResp, err := rpcClient.FileChunkRPC().DownloadFileChunk(rpcClient.Context(), &pb.DownloadFileChunkRequest{FileChunkId: chunkId})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
chunk := chunkResp.FileChunk
|
||||
|
||||
if chunk == nil {
|
||||
continue
|
||||
}
|
||||
_, err = fp.Write(chunk.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err = fp.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 重新加载
|
||||
library, err := SharedManager.Load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
SharedLibrary = library
|
||||
|
||||
isOk = true
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"github.com/iwind/TeaGo/dbs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUpdater_loop(t *testing.T) {
|
||||
dbs.NotifyReady()
|
||||
|
||||
updater := NewUpdater()
|
||||
err := updater.loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
@@ -4,6 +4,7 @@ package metrics
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"strconv"
|
||||
"sync"
|
||||
@@ -11,7 +12,15 @@ import (
|
||||
|
||||
var SharedManager = NewManager()
|
||||
|
||||
func init() {
|
||||
events.On(events.EventQuit, func() {
|
||||
SharedManager.Quit()
|
||||
})
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
isQuiting bool
|
||||
|
||||
tasks map[int64]*Task // itemId => *Task
|
||||
categoryTasks map[string][]*Task // category => []*Task
|
||||
locker sync.RWMutex
|
||||
@@ -29,6 +38,10 @@ func NewManager() *Manager {
|
||||
}
|
||||
|
||||
func (this *Manager) Update(items []*serverconfigs.MetricItemConfig) {
|
||||
if this.isQuiting {
|
||||
return
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
@@ -101,6 +114,10 @@ func (this *Manager) Update(items []*serverconfigs.MetricItemConfig) {
|
||||
|
||||
// Add 添加数据
|
||||
func (this *Manager) Add(obj MetricInterface) {
|
||||
if this.isQuiting {
|
||||
return
|
||||
}
|
||||
|
||||
this.locker.RLock()
|
||||
for _, task := range this.categoryTasks[obj.MetricCategory()] {
|
||||
task.Add(obj)
|
||||
@@ -119,3 +136,17 @@ func (this *Manager) HasTCPMetrics() bool {
|
||||
func (this *Manager) HasUDPMetrics() bool {
|
||||
return this.hasUDPMetrics
|
||||
}
|
||||
|
||||
// Quit 退出管理器
|
||||
func (this *Manager) Quit() {
|
||||
this.isQuiting = true
|
||||
|
||||
remotelogs.Println("METRIC_MANAGER", "quit")
|
||||
|
||||
this.locker.Lock()
|
||||
for _, task := range this.tasks {
|
||||
_ = task.Stop()
|
||||
}
|
||||
this.tasks = map[int64]*Task{}
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/dbs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/zero"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -89,13 +90,23 @@ func (this *Task) Init() error {
|
||||
remotelogs.Println("METRIC", "create data dir '"+dir+"'")
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", "file:"+dir+"/metric."+strconv.FormatInt(this.item.Id, 10)+".db?cache=shared&mode=rwc&_journal_mode=WAL&_sync=OFF")
|
||||
var path = dir + "/metric." + types.String(this.item.Id) + ".db"
|
||||
|
||||
db, err := sql.Open("sqlite3", "file:"+path+"?cache=shared&mode=rwc&_journal_mode=WAL&_sync=OFF")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.SetMaxOpenConns(1)
|
||||
this.db = dbs.NewDB(db)
|
||||
|
||||
// 恢复数据库
|
||||
var recoverEnv, _ = os.LookupEnv("EdgeRecover")
|
||||
if len(recoverEnv) > 0 {
|
||||
for _, indexName := range []string{"serverId", "hash"} {
|
||||
_, _ = db.Exec(`REINDEX "` + indexName + `"`)
|
||||
}
|
||||
}
|
||||
|
||||
if teaconst.EnableDBStat {
|
||||
this.db.EnableStat(true)
|
||||
}
|
||||
@@ -424,7 +435,7 @@ func (this *Task) Upload(pauseDuration time.Duration) error {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = rpcClient.MetricStatRPC().UploadMetricStats(rpcClient.Context(), &pb.UploadMetricStatsRequest{
|
||||
_, err = rpcClient.MetricStatRPC.UploadMetricStats(rpcClient.Context(), &pb.UploadMetricStatsRequest{
|
||||
MetricStats: pbStats,
|
||||
Time: currentTime,
|
||||
ServerId: serverId,
|
||||
|
||||
@@ -69,7 +69,7 @@ func (this *ValueQueue) Loop() error {
|
||||
}
|
||||
|
||||
for value := range this.valuesChan {
|
||||
_, err = rpcClient.NodeValueRPC().CreateNodeValue(rpcClient.Context(), &pb.CreateNodeValueRequest{
|
||||
_, err = rpcClient.NodeValueRPC.CreateNodeValue(rpcClient.Context(), &pb.CreateNodeValueRequest{
|
||||
Item: value.Item,
|
||||
ValueJSON: value.ValueJSON,
|
||||
CreatedAt: value.CreatedAt,
|
||||
|
||||
@@ -16,9 +16,9 @@ func TestValueQueue_RPC(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = rpcClient.NodeValueRPC().CreateNodeValue(rpcClient.Context(), &pb.CreateNodeValueRequest{})
|
||||
_, err = rpcClient.NodeValueRPC.CreateNodeValue(rpcClient.Context(), &pb.CreateNodeValueRequest{})
|
||||
if err != nil {
|
||||
statusErr, ok:= status.FromError(err)
|
||||
statusErr, ok := status.FromError(err)
|
||||
if ok {
|
||||
logs.Println(statusErr.Code())
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -17,7 +16,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"net/url"
|
||||
@@ -73,7 +72,7 @@ func (this *APIStream) loop() error {
|
||||
cancelFunc()
|
||||
}()
|
||||
|
||||
nodeStream, err := rpcClient.NodeRPC().NodeStream(ctx)
|
||||
nodeStream, err := rpcClient.NodeRPC.NodeStream(ctx)
|
||||
if err != nil {
|
||||
if this.isQuiting {
|
||||
return nil
|
||||
@@ -347,15 +346,15 @@ func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage)
|
||||
return nil
|
||||
}
|
||||
|
||||
var cmd = utils.NewCommandExecutor()
|
||||
shortName := teaconst.SystemdServiceName
|
||||
cmd.Add(systemctl, "is-enabled", shortName)
|
||||
output, err := cmd.Run()
|
||||
var shortName = teaconst.SystemdServiceName
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, systemctl, "is-enabled", shortName)
|
||||
cmd.WithStdout()
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "'systemctl' command error: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
if output == "enabled" {
|
||||
if cmd.Stdout() == "enabled" {
|
||||
this.replyOk(message.RequestId, "ok")
|
||||
} else {
|
||||
this.replyFail(message.RequestId, "not installed")
|
||||
@@ -385,16 +384,15 @@ func (this *APIStream) handleCheckLocalFirewall(message *pb.NodeStreamMessage) e
|
||||
return nil
|
||||
}
|
||||
|
||||
var cmd = exec.Command(nft, "--version")
|
||||
var output = &bytes.Buffer{}
|
||||
cmd.Stdout = output
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, nft, "--version")
|
||||
cmd.WithStdout()
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "get version failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
var outputString = output.String()
|
||||
var outputString = cmd.Stdout()
|
||||
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
|
||||
if len(versionMatches) <= 1 {
|
||||
this.replyFail(message.RequestId, "can not get nft version")
|
||||
|
||||
@@ -5,55 +5,57 @@ package nodes
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/conns"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClientConn 客户端连接
|
||||
type ClientConn struct {
|
||||
once sync.Once
|
||||
BaseClientConn
|
||||
|
||||
isTLS bool
|
||||
hasDeadline bool
|
||||
hasRead bool
|
||||
|
||||
isLO bool // 是否为环路
|
||||
isLO bool // 是否为环路
|
||||
isInAllowList bool
|
||||
|
||||
hasResetSYNFlood bool
|
||||
|
||||
BaseClientConn
|
||||
}
|
||||
|
||||
func NewClientConn(conn net.Conn, isTLS bool, quickClose bool) net.Conn {
|
||||
if quickClose {
|
||||
// TCP
|
||||
tcpConn, ok := conn.(*net.TCPConn)
|
||||
if ok {
|
||||
// TODO 可以在配置中设置此值
|
||||
_ = tcpConn.SetLinger(nodeconfigs.DefaultTCPLinger)
|
||||
}
|
||||
}
|
||||
|
||||
func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList bool) net.Conn {
|
||||
// 是否为环路
|
||||
var remoteAddr = conn.RemoteAddr().String()
|
||||
var remoteAddr = rawConn.RemoteAddr().String()
|
||||
var isLO = strings.HasPrefix(remoteAddr, "127.0.0.1:") || strings.HasPrefix(remoteAddr, "[::1]:")
|
||||
|
||||
return &ClientConn{
|
||||
BaseClientConn: BaseClientConn{rawConn: conn},
|
||||
var conn = &ClientConn{
|
||||
BaseClientConn: BaseClientConn{rawConn: rawConn},
|
||||
isTLS: isTLS,
|
||||
isLO: isLO,
|
||||
isInAllowList: isInAllowList,
|
||||
}
|
||||
|
||||
if quickClose {
|
||||
// TODO 可以在配置中设置此值
|
||||
_ = conn.SetLinger(nodeconfigs.DefaultTCPLinger)
|
||||
}
|
||||
|
||||
// 加入到Map
|
||||
conns.SharedMap.Add(conn)
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
func (this *ClientConn) Read(b []byte) (n int, err error) {
|
||||
@@ -86,20 +88,24 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// 检测是否为握手错误
|
||||
var isHandshakeError = err != nil && os.IsTimeout(err) && !this.hasRead
|
||||
if isHandshakeError {
|
||||
_ = this.SetLinger(0)
|
||||
}
|
||||
|
||||
// SYN Flood检测
|
||||
if this.serverId == 0 || !this.hasResetSYNFlood {
|
||||
var synFloodConfig = sharedNodeConfig.SYNFloodConfig()
|
||||
if synFloodConfig != nil && synFloodConfig.IsOn {
|
||||
if isHandshakeError {
|
||||
this.increaseSYNFlood(synFloodConfig)
|
||||
} else if err == nil && !this.hasResetSYNFlood {
|
||||
this.hasResetSYNFlood = true
|
||||
this.resetSYNFlood()
|
||||
// 忽略白名单和局域网
|
||||
if !this.isInAllowList && !utils.IsLocalIP(this.RawIP()) {
|
||||
// SYN Flood检测
|
||||
if this.serverId == 0 || !this.hasResetSYNFlood {
|
||||
var synFloodConfig = sharedNodeConfig.SYNFloodConfig()
|
||||
if synFloodConfig != nil && synFloodConfig.IsOn {
|
||||
if isHandshakeError {
|
||||
this.increaseSYNFlood(synFloodConfig)
|
||||
} else if err == nil && !this.hasResetSYNFlood {
|
||||
this.hasResetSYNFlood = true
|
||||
this.resetSYNFlood()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -110,11 +116,10 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
|
||||
func (this *ClientConn) Write(b []byte) (n int, err error) {
|
||||
n, err = this.rawConn.Write(b)
|
||||
if n > 0 {
|
||||
atomic.AddUint64(&teaconst.OutTrafficBytes, uint64(n))
|
||||
|
||||
// 统计当前服务带宽
|
||||
if this.serverId > 0 {
|
||||
if !this.isLO { // 环路不统计带宽,避免缓存预热等行为产生带宽
|
||||
if !this.isLO || Tea.IsTesting() { // 环路不统计带宽,避免缓存预热等行为产生带宽
|
||||
atomic.AddUint64(&teaconst.OutTrafficBytes, uint64(n))
|
||||
stats.SharedBandwidthStatManager.Add(this.userId, this.serverId, int64(n))
|
||||
}
|
||||
}
|
||||
@@ -132,6 +137,9 @@ func (this *ClientConn) Close() error {
|
||||
// 不能加条件限制,因为服务配置随时有变化
|
||||
sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String())
|
||||
|
||||
// 从conn map中移除
|
||||
conns.SharedMap.Remove(this)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -177,6 +185,11 @@ func (this *ClientConn) increaseSYNFlood(synFloodConfig *firewallconfigs.SYNFloo
|
||||
if timeout <= 0 {
|
||||
timeout = 600
|
||||
}
|
||||
|
||||
// 关闭当前连接
|
||||
_ = this.SetLinger(0)
|
||||
_ = this.Close()
|
||||
|
||||
waf.SharedIPBlackList.RecordIP(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip, time.Now().Unix()+int64(timeout), 0, true, 0, 0, "疑似SYN Flood攻击,当前1分钟"+types.String(result)+"次空连接")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
|
||||
package nodes
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
)
|
||||
|
||||
type BaseClientConn struct {
|
||||
rawConn net.Conn
|
||||
@@ -14,6 +17,8 @@ type BaseClientConn struct {
|
||||
hasLimit bool
|
||||
|
||||
isClosed bool
|
||||
|
||||
rawIP string
|
||||
}
|
||||
|
||||
func (this *BaseClientConn) IsClosed() bool {
|
||||
@@ -42,6 +47,17 @@ func (this *BaseClientConn) Bind(serverId int64, remoteAddr string, maxConnsPerS
|
||||
// SetServerId 设置服务ID
|
||||
func (this *BaseClientConn) SetServerId(serverId int64) {
|
||||
this.serverId = serverId
|
||||
|
||||
// 设置包装前连接
|
||||
switch conn := this.rawConn.(type) {
|
||||
case *tls.Conn:
|
||||
nativeConn, ok := conn.NetConn().(ClientConnInterface)
|
||||
if ok {
|
||||
nativeConn.SetServerId(serverId)
|
||||
}
|
||||
case *ClientConn:
|
||||
conn.SetServerId(serverId)
|
||||
}
|
||||
}
|
||||
|
||||
// ServerId 读取当前连接绑定的服务ID
|
||||
@@ -52,6 +68,17 @@ func (this *BaseClientConn) ServerId() int64 {
|
||||
// SetUserId 设置所属服务的用户ID
|
||||
func (this *BaseClientConn) SetUserId(userId int64) {
|
||||
this.userId = userId
|
||||
|
||||
// 设置包装前连接
|
||||
switch conn := this.rawConn.(type) {
|
||||
case *tls.Conn:
|
||||
nativeConn, ok := conn.NetConn().(ClientConnInterface)
|
||||
if ok {
|
||||
nativeConn.SetUserId(userId)
|
||||
}
|
||||
case *ClientConn:
|
||||
conn.SetUserId(userId)
|
||||
}
|
||||
}
|
||||
|
||||
// UserId 获取当前连接所属服务的用户ID
|
||||
@@ -61,14 +88,30 @@ func (this *BaseClientConn) UserId() int64 {
|
||||
|
||||
// RawIP 原本IP
|
||||
func (this *BaseClientConn) RawIP() string {
|
||||
if len(this.rawIP) > 0 {
|
||||
return this.rawIP
|
||||
}
|
||||
|
||||
ip, _, _ := net.SplitHostPort(this.rawConn.RemoteAddr().String())
|
||||
this.rawIP = ip
|
||||
return ip
|
||||
}
|
||||
|
||||
// TCPConn 转换为TCPConn
|
||||
func (this *BaseClientConn) TCPConn() (*net.TCPConn, bool) {
|
||||
conn, ok := this.rawConn.(*net.TCPConn)
|
||||
return conn, ok
|
||||
func (this *BaseClientConn) TCPConn() (tcpConn *net.TCPConn, ok bool) {
|
||||
// 设置包装前连接
|
||||
switch conn := this.rawConn.(type) {
|
||||
case *tls.Conn:
|
||||
var internalConn = conn.NetConn()
|
||||
clientConn, ok := internalConn.(*ClientConn)
|
||||
if ok {
|
||||
return clientConn.TCPConn()
|
||||
}
|
||||
tcpConn, ok = internalConn.(*net.TCPConn)
|
||||
default:
|
||||
tcpConn, ok = this.rawConn.(*net.TCPConn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SetLinger 设置Linger
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClientListener 客户端网络监听
|
||||
@@ -40,12 +41,32 @@ func (this *ClientListener) Accept() (net.Conn, error) {
|
||||
|
||||
// 是否在WAF名单中
|
||||
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
var isInAllowList = false
|
||||
if err == nil {
|
||||
canGoNext, _ := iplibrary.AllowIP(ip, 0)
|
||||
var beingDenied = !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) &&
|
||||
waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
|
||||
canGoNext, inAllowList := iplibrary.AllowIP(ip, 0)
|
||||
isInAllowList = inAllowList
|
||||
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
|
||||
expiresAt, ok := waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
|
||||
if ok {
|
||||
var timeout = expiresAt - time.Now().Unix()
|
||||
if timeout > 0 {
|
||||
canGoNext = false
|
||||
|
||||
if !canGoNext || beingDenied {
|
||||
if timeout > 3600 {
|
||||
timeout = 3600
|
||||
}
|
||||
|
||||
// 使用本地防火墙延长封禁
|
||||
var fw = firewalls.Firewall()
|
||||
if fw != nil && !fw.IsMock() {
|
||||
// 这里 int(int64) 转换的前提是限制了 timeout <= 3600,否则将有整型溢出的风险
|
||||
_ = fw.DropSourceIP(ip, int(timeout), true)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !canGoNext {
|
||||
tcpConn, ok := conn.(*net.TCPConn)
|
||||
if ok {
|
||||
_ = tcpConn.SetLinger(0)
|
||||
@@ -53,19 +74,11 @@ func (this *ClientListener) Accept() (net.Conn, error) {
|
||||
|
||||
_ = conn.Close()
|
||||
|
||||
// 使用本地防火墙延长封禁
|
||||
if beingDenied {
|
||||
var fw = firewalls.Firewall()
|
||||
if fw != nil && !fw.IsMock() {
|
||||
_ = fw.DropSourceIP(ip, 120, true)
|
||||
}
|
||||
}
|
||||
|
||||
return this.Accept()
|
||||
}
|
||||
}
|
||||
|
||||
return NewClientConn(conn, this.isTLS, this.quickClose), nil
|
||||
return NewClientConn(conn, this.isTLS, this.quickClose, isInAllowList), nil
|
||||
}
|
||||
|
||||
func (this *ClientListener) Close() error {
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -96,7 +98,7 @@ Loop:
|
||||
this.rpcClient = client
|
||||
}
|
||||
|
||||
_, err := this.rpcClient.HTTPAccessLogRPC().CreateHTTPAccessLogs(this.rpcClient.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: accessLogs})
|
||||
_, err := this.rpcClient.HTTPAccessLogRPC.CreateHTTPAccessLogs(this.rpcClient.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: accessLogs})
|
||||
if err != nil {
|
||||
// 是否包含了invalid UTF-8
|
||||
if strings.Contains(err.Error(), "string field contains invalid UTF-8") {
|
||||
@@ -105,7 +107,20 @@ Loop:
|
||||
}
|
||||
|
||||
// 重新提交
|
||||
_, err = this.rpcClient.HTTPAccessLogRPC().CreateHTTPAccessLogs(this.rpcClient.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: accessLogs})
|
||||
_, err = this.rpcClient.HTTPAccessLogRPC.CreateHTTPAccessLogs(this.rpcClient.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: accessLogs})
|
||||
return err
|
||||
}
|
||||
|
||||
// 是否请求内容过大
|
||||
statusCode, ok := status.FromError(err)
|
||||
if ok && statusCode.Code() == codes.ResourceExhausted {
|
||||
// 去除Body
|
||||
for _, accessLog := range accessLogs {
|
||||
accessLog.RequestBody = nil
|
||||
}
|
||||
|
||||
// 重新提交
|
||||
_, err = this.rpcClient.HTTPAccessLogRPC.CreateHTTPAccessLogs(this.rpcClient.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: accessLogs})
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ func TestHTTPAccessLogQueue_Push(t *testing.T) {
|
||||
// logs.PrintAsJSON(accessLog)
|
||||
|
||||
//t.Log(strings.ToValidUTF8(string(utf8Bytes), ""))
|
||||
_, err = client.HTTPAccessLogRPC().CreateHTTPAccessLogs(client.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: []*pb.HTTPAccessLog{
|
||||
_, err = client.HTTPAccessLogRPC.CreateHTTPAccessLogs(client.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: []*pb.HTTPAccessLog{
|
||||
accessLog,
|
||||
}})
|
||||
if err != nil {
|
||||
@@ -99,7 +99,7 @@ func TestHTTPAccessLogQueue_Push2(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = client.HTTPAccessLogRPC().CreateHTTPAccessLogs(client.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: []*pb.HTTPAccessLog{
|
||||
_, err = client.HTTPAccessLogRPC.CreateHTTPAccessLogs(client.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: []*pb.HTTPAccessLog{
|
||||
accessLog,
|
||||
}})
|
||||
if err != nil {
|
||||
|
||||
@@ -81,7 +81,7 @@ func (this *HTTPCacheTaskManager) Start() {
|
||||
|
||||
if rpcClient != nil {
|
||||
for taskReq := range this.taskQueue {
|
||||
_, err := rpcClient.ServerRPC().PurgeServerCache(rpcClient.Context(), taskReq)
|
||||
_, err := rpcClient.ServerRPC.PurgeServerCache(rpcClient.Context(), taskReq)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_CACHE_TASK_MANAGER", "create purge task failed: "+err.Error())
|
||||
}
|
||||
@@ -104,8 +104,12 @@ func (this *HTTPCacheTaskManager) Loop() error {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := rpcClient.HTTPCacheTaskKeyRPC().FindDoingHTTPCacheTaskKeys(rpcClient.Context(), &pb.FindDoingHTTPCacheTaskKeysRequest{})
|
||||
resp, err := rpcClient.HTTPCacheTaskKeyRPC.FindDoingHTTPCacheTaskKeys(rpcClient.Context(), &pb.FindDoingHTTPCacheTaskKeysRequest{})
|
||||
if err != nil {
|
||||
// 忽略连接错误
|
||||
if rpc.IsConnError(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -131,7 +135,7 @@ func (this *HTTPCacheTaskManager) Loop() error {
|
||||
pbResults = append(pbResults, pbResult)
|
||||
}
|
||||
|
||||
_, err = rpcClient.HTTPCacheTaskKeyRPC().UpdateHTTPCacheTaskKeysStatus(rpcClient.Context(), &pb.UpdateHTTPCacheTaskKeysStatusRequest{KeyResults: pbResults})
|
||||
_, err = rpcClient.HTTPCacheTaskKeyRPC.UpdateHTTPCacheTaskKeysStatus(rpcClient.Context(), &pb.UpdateHTTPCacheTaskKeysStatusRequest{KeyResults: pbResults})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -6,10 +6,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
iplib "github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/metrics"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
@@ -68,6 +68,7 @@ type HTTPRequest struct {
|
||||
filePath string // 请求的文件名,仅在读取Root目录下的内容时不为空
|
||||
origin *serverconfigs.OriginConfig // 源站
|
||||
originAddr string // 源站实际地址
|
||||
originStatus int32 // 源站响应代码
|
||||
errors []string // 错误信息
|
||||
rewriteRule *serverconfigs.HTTPRewriteRule // 匹配到的重写规则
|
||||
rewriteReplace string // 重写规则的目标
|
||||
@@ -228,6 +229,14 @@ func (this *HTTPRequest) Do() {
|
||||
}
|
||||
}
|
||||
|
||||
// 防盗链
|
||||
if !this.isSubRequest && this.web.Referers != nil && this.web.Referers.IsOn {
|
||||
if this.doCheckReferers() {
|
||||
this.doEnd()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 访问控制
|
||||
if !this.isSubRequest && this.web.Auth != nil && this.web.Auth.IsOn {
|
||||
if this.doAuth() {
|
||||
@@ -512,6 +521,11 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
|
||||
this.web.Auth = web.Auth
|
||||
}
|
||||
|
||||
// referers
|
||||
if web.Referers != nil && (web.Referers.IsPrior || isTop) {
|
||||
this.web.Referers = web.Referers
|
||||
}
|
||||
|
||||
// request limit
|
||||
if web.RequestLimit != nil && (web.RequestLimit.IsPrior || isTop) {
|
||||
this.web.RequestLimit = web.RequestLimit
|
||||
@@ -677,7 +691,7 @@ func (this *HTTPRequest) Format(source string) string {
|
||||
case "remoteAddrValue":
|
||||
return this.requestRemoteAddr(false)
|
||||
case "rawRemoteAddr":
|
||||
addr := this.RawReq.RemoteAddr
|
||||
var addr = this.RawReq.RemoteAddr
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err == nil {
|
||||
addr = host
|
||||
@@ -928,42 +942,47 @@ func (this *HTTPRequest) Format(source string) string {
|
||||
|
||||
// geo
|
||||
if prefix == "geo" {
|
||||
result, _ := iplibrary.SharedLibrary.Lookup(this.requestRemoteAddr(true))
|
||||
var result = iplib.LookupIP(this.requestRemoteAddr(true))
|
||||
|
||||
switch suffix {
|
||||
case "country.name":
|
||||
if result != nil {
|
||||
return result.Country
|
||||
if result != nil && result.IsOk() {
|
||||
return result.CountryName()
|
||||
}
|
||||
return ""
|
||||
case "country.id":
|
||||
if result != nil {
|
||||
return types.String(iplibrary.SharedCountryManager.Lookup(result.Country))
|
||||
if result != nil && result.IsOk() {
|
||||
return types.String(result.CountryId())
|
||||
}
|
||||
return "0"
|
||||
case "province.name":
|
||||
if result != nil {
|
||||
return result.Province
|
||||
if result != nil && result.IsOk() {
|
||||
return result.ProvinceName()
|
||||
}
|
||||
return ""
|
||||
case "province.id":
|
||||
if result != nil {
|
||||
return types.String(iplibrary.SharedProvinceManager.Lookup(result.Province))
|
||||
if result != nil && result.IsOk() {
|
||||
return types.String(result.ProvinceId())
|
||||
}
|
||||
return "0"
|
||||
case "city.name":
|
||||
if result != nil {
|
||||
return result.City
|
||||
if result != nil && result.IsOk() {
|
||||
return result.CityName()
|
||||
}
|
||||
return ""
|
||||
case "city.id":
|
||||
if result != nil {
|
||||
var provinceId = iplibrary.SharedProvinceManager.Lookup(result.Province)
|
||||
if provinceId > 0 {
|
||||
return types.String(iplibrary.SharedCityManager.Lookup(provinceId, result.City))
|
||||
} else {
|
||||
return "0"
|
||||
}
|
||||
if result != nil && result.IsOk() {
|
||||
return types.String(result.CityId())
|
||||
}
|
||||
return "0"
|
||||
case "town.name":
|
||||
if result != nil && result.IsOk() {
|
||||
return result.TownName()
|
||||
}
|
||||
return ""
|
||||
case "town.id":
|
||||
if result != nil && result.IsOk() {
|
||||
return types.String(result.TownId())
|
||||
}
|
||||
return "0"
|
||||
}
|
||||
@@ -971,16 +990,16 @@ func (this *HTTPRequest) Format(source string) string {
|
||||
|
||||
// ips
|
||||
if prefix == "isp" {
|
||||
result, _ := iplibrary.SharedLibrary.Lookup(this.requestRemoteAddr(true))
|
||||
var result = iplib.LookupIP(this.requestRemoteAddr(true))
|
||||
|
||||
switch suffix {
|
||||
case "name":
|
||||
if result != nil {
|
||||
return result.ISP
|
||||
if result != nil && result.IsOk() {
|
||||
return result.ProviderName()
|
||||
}
|
||||
case "id":
|
||||
if result != nil {
|
||||
return types.String(iplibrary.SharedProviderManager.Lookup(result.ISP))
|
||||
if result != nil && result.IsOk() {
|
||||
return types.String(result.ProviderId())
|
||||
}
|
||||
return "0"
|
||||
}
|
||||
@@ -1098,7 +1117,7 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string {
|
||||
}
|
||||
|
||||
// Remote-Addr
|
||||
remoteAddr := this.RawReq.RemoteAddr
|
||||
var remoteAddr = this.RawReq.RemoteAddr
|
||||
host, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err == nil {
|
||||
if supportVar {
|
||||
@@ -1167,7 +1186,7 @@ func (this *HTTPRequest) requestRemoteUser() string {
|
||||
|
||||
// Path 请求的URL中路径部分
|
||||
func (this *HTTPRequest) Path() string {
|
||||
uri, err := url.ParseRequestURI(this.rawURI)
|
||||
uri, err := url.ParseRequestURI(this.uri)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
@@ -1315,7 +1334,7 @@ func (this *HTTPRequest) RemoteAddr() string {
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) RawRemoteAddr() string {
|
||||
addr := this.RawReq.RemoteAddr
|
||||
var addr = this.RawReq.RemoteAddr
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err == nil {
|
||||
addr = host
|
||||
@@ -1423,18 +1442,21 @@ func (this *HTTPRequest) Done() {
|
||||
func (this *HTTPRequest) Close() {
|
||||
this.Done()
|
||||
|
||||
requestConn := this.RawReq.Context().Value(HTTPConnContextKey)
|
||||
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
|
||||
if requestConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
lingerConn, ok := requestConn.(LingerConn)
|
||||
if ok {
|
||||
_ = lingerConn.SetLinger(0)
|
||||
}
|
||||
|
||||
conn, ok := requestConn.(net.Conn)
|
||||
if ok {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Allow 放行
|
||||
@@ -1589,9 +1611,7 @@ func (this *HTTPRequest) fixRequestHeader(header http.Header) {
|
||||
}
|
||||
|
||||
// 处理自定义Response Header
|
||||
func (this *HTTPRequest) processResponseHeaders(statusCode int) {
|
||||
var responseHeader = this.writer.Header()
|
||||
|
||||
func (this *HTTPRequest) processResponseHeaders(responseHeader http.Header, statusCode int) {
|
||||
// 删除/添加/替换Header
|
||||
// TODO 实现AddTrailers
|
||||
if this.web.ResponseHeaderPolicy != nil && this.web.ResponseHeaderPolicy.IsOn {
|
||||
@@ -1721,7 +1741,7 @@ func (this *HTTPRequest) canIgnore(err error) bool {
|
||||
}
|
||||
|
||||
// HTTP/2流错误
|
||||
if err.Error() == "http2: stream closed" || err.Error() == "client disconnected" { // errStreamClosed, errClientDisconnected
|
||||
if err.Error() == "http2: stream closed" || strings.Contains(err.Error(), "stream error") || err.Error() == "client disconnected" { // errStreamClosed, errClientDisconnected
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ func (this *HTTPRequest) doACME() (shouldStop bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
keyResp, err := rpcClient.ACMEAuthenticationRPC().FindACMEAuthenticationKeyWithToken(rpcClient.Context(), &pb.FindACMEAuthenticationKeyWithTokenRequest{Token: token})
|
||||
keyResp, err := rpcClient.ACMEAuthenticationRPC.FindACMEAuthenticationKeyWithToken(rpcClient.Context(), &pb.FindACMEAuthenticationKeyWithTokenRequest{Token: token})
|
||||
if err != nil {
|
||||
remotelogs.Error("RPC", "[ACME]read key for token failed: "+err.Error())
|
||||
return false
|
||||
|
||||
@@ -19,7 +19,10 @@ func (this *HTTPRequest) doAuth() (shouldStop bool) {
|
||||
if !ref.IsOn || ref.AuthPolicy == nil || !ref.AuthPolicy.IsOn {
|
||||
continue
|
||||
}
|
||||
b, err := ref.AuthPolicy.Filter(this.RawReq, func(subReq *http.Request) (status int, err error) {
|
||||
if !ref.AuthPolicy.MatchRequest(this.RawReq) {
|
||||
continue
|
||||
}
|
||||
ok, newURI, uriChanged, err := ref.AuthPolicy.Filter(this.RawReq, func(subReq *http.Request) (status int, err error) {
|
||||
subReq.TLS = this.RawReq.TLS
|
||||
subReq.RemoteAddr = this.RawReq.RemoteAddr
|
||||
subReq.Host = this.RawReq.Host
|
||||
@@ -36,24 +39,32 @@ func (this *HTTPRequest) doAuth() (shouldStop bool) {
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to execute the AuthPolicy", "认证策略执行失败", false)
|
||||
return
|
||||
}
|
||||
if b {
|
||||
if ok {
|
||||
if uriChanged {
|
||||
this.uri = newURI
|
||||
}
|
||||
this.tags = append(this.tags, ref.AuthPolicy.Type)
|
||||
return
|
||||
} else {
|
||||
// Basic Auth比较特殊
|
||||
if ref.AuthPolicy.Type == serverconfigs.HTTPAuthTypeBasicAuth {
|
||||
var method = ref.AuthPolicy.Method().(*serverconfigs.HTTPAuthBasicMethod)
|
||||
var headerValue = "Basic realm=\""
|
||||
if len(method.Realm) > 0 {
|
||||
headerValue += method.Realm
|
||||
} else {
|
||||
headerValue += this.ReqHost
|
||||
method, ok := ref.AuthPolicy.Method().(*serverconfigs.HTTPAuthBasicMethod)
|
||||
if ok {
|
||||
var headerValue = "Basic realm=\""
|
||||
if len(method.Realm) > 0 {
|
||||
headerValue += method.Realm
|
||||
} else {
|
||||
headerValue += this.ReqHost
|
||||
}
|
||||
headerValue += "\""
|
||||
if len(method.Charset) > 0 {
|
||||
headerValue += ", charset=\"" + method.Charset + "\""
|
||||
}
|
||||
this.writer.Header()["WWW-Authenticate"] = []string{headerValue}
|
||||
}
|
||||
headerValue += "\""
|
||||
if len(method.Charset) > 0 {
|
||||
headerValue += ", charset=\"" + method.Charset + "\""
|
||||
}
|
||||
this.writer.Header()["WWW-Authenticate"] = []string{headerValue}
|
||||
}
|
||||
this.writer.WriteHeader(http.StatusUnauthorized)
|
||||
this.tags = append(this.tags, ref.AuthPolicy.Type)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,12 +44,11 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
// 检查服务独立的缓存条件
|
||||
refType := ""
|
||||
for _, cacheRef := range this.web.Cache.CacheRefs {
|
||||
if !cacheRef.IsOn ||
|
||||
cacheRef.Conds == nil ||
|
||||
!cacheRef.Conds.HasRequestConds() {
|
||||
if !cacheRef.IsOn {
|
||||
continue
|
||||
}
|
||||
if cacheRef.Conds.MatchRequest(this.Format) {
|
||||
if (cacheRef.Conds != nil && cacheRef.Conds.HasRequestConds() && cacheRef.Conds.MatchRequest(this.Format)) ||
|
||||
(cacheRef.SimpleCond != nil && cacheRef.SimpleCond.Match(this.Format)) {
|
||||
if cacheRef.IsReverse {
|
||||
return
|
||||
}
|
||||
@@ -61,12 +60,11 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
if this.cacheRef == nil && !this.web.Cache.DisablePolicyRefs {
|
||||
// 检查策略默认的缓存条件
|
||||
for _, cacheRef := range cachePolicy.CacheRefs {
|
||||
if !cacheRef.IsOn ||
|
||||
cacheRef.Conds == nil ||
|
||||
!cacheRef.Conds.HasRequestConds() {
|
||||
if !cacheRef.IsOn {
|
||||
continue
|
||||
}
|
||||
if cacheRef.Conds.MatchRequest(this.Format) {
|
||||
if (cacheRef.Conds != nil && cacheRef.Conds.HasRequestConds() && cacheRef.Conds.MatchRequest(this.Format)) ||
|
||||
(cacheRef.SimpleCond != nil && cacheRef.SimpleCond.Match(this.Format)) {
|
||||
if cacheRef.IsReverse {
|
||||
return
|
||||
}
|
||||
@@ -157,7 +155,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
for _, subKey := range subKeys {
|
||||
err := storage.Delete(subKey)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST_CACHE", "purge failed: "+err.Error())
|
||||
remotelogs.ErrorServer("HTTP_REQUEST_CACHE", "purge failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -262,7 +260,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
}
|
||||
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.Warn("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: open cache failed: "+err.Error())
|
||||
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: open cache failed: "+err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -322,7 +320,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
})
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.Warn("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: read header failed: "+err.Error())
|
||||
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: read header failed: "+err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -374,7 +372,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
// 支持 If-None-Match
|
||||
if !this.isLnRequest && !isPartialCache && len(eTag) > 0 && this.requestHeader("If-None-Match") == eTag {
|
||||
// 自定义Header
|
||||
this.processResponseHeaders(http.StatusNotModified)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
|
||||
this.addExpiresHeader(reader.ExpiresAt())
|
||||
this.writer.WriteHeader(http.StatusNotModified)
|
||||
this.isCached = true
|
||||
@@ -386,7 +384,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
// 支持 If-Modified-Since
|
||||
if !this.isLnRequest && !isPartialCache && len(modifiedTime) > 0 && this.requestHeader("If-Modified-Since") == modifiedTime {
|
||||
// 自定义Header
|
||||
this.processResponseHeaders(http.StatusNotModified)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
|
||||
this.addExpiresHeader(reader.ExpiresAt())
|
||||
this.writer.WriteHeader(http.StatusNotModified)
|
||||
this.isCached = true
|
||||
@@ -395,7 +393,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
this.processResponseHeaders(reader.Status())
|
||||
this.processResponseHeaders(this.writer.Header(), reader.Status())
|
||||
this.addExpiresHeader(reader.ExpiresAt())
|
||||
|
||||
// 返回上级节点过期时间
|
||||
@@ -424,7 +422,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
if supportRange {
|
||||
if len(rangeHeader) > 0 {
|
||||
if fileSize == 0 {
|
||||
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
|
||||
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return true
|
||||
}
|
||||
@@ -432,7 +430,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
if len(ranges) == 0 {
|
||||
ranges, ok = httpRequestParseRangeHeader(rangeHeader)
|
||||
if !ok {
|
||||
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
|
||||
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return true
|
||||
}
|
||||
@@ -441,7 +439,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
for k, r := range ranges {
|
||||
r2, ok := r.Convert(fileSize)
|
||||
if !ok {
|
||||
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
|
||||
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return true
|
||||
}
|
||||
@@ -468,12 +466,12 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
this.varMapping["cache.status"] = "MISS"
|
||||
|
||||
if err == caches.ErrInvalidRange {
|
||||
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
|
||||
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return true
|
||||
}
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.Warn("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: "+err.Error())
|
||||
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: "+err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -519,7 +517,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
})
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.Warn("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: "+err.Error())
|
||||
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: "+err.Error())
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -556,7 +554,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
this.varMapping["cache.status"] = "MISS"
|
||||
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.Warn("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: read body failed: "+err.Error())
|
||||
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: read body failed: "+err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -25,10 +25,10 @@ const httpStatusPageTemplate = `<!DOCTYPE html>
|
||||
</html>`
|
||||
|
||||
func (this *HTTPRequest) write404() {
|
||||
this.writeCode(http.StatusNotFound)
|
||||
this.writeCode(http.StatusNotFound, "", "")
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) writeCode(statusCode int) {
|
||||
func (this *HTTPRequest) writeCode(statusCode int, enMessage string, zhMessage string) {
|
||||
if this.doPage(statusCode) {
|
||||
return
|
||||
}
|
||||
@@ -42,12 +42,22 @@ func (this *HTTPRequest) writeCode(statusCode int) {
|
||||
case "requestId":
|
||||
return this.requestId
|
||||
case "message":
|
||||
return "" // 空
|
||||
var acceptLanguages = this.RawReq.Header.Get("Accept-Language")
|
||||
if len(acceptLanguages) > 0 {
|
||||
var index = strings.Index(acceptLanguages, ",")
|
||||
if index > 0 {
|
||||
var firstLanguage = acceptLanguages[:index]
|
||||
if firstLanguage == "zh-CN" {
|
||||
return zhMessage
|
||||
}
|
||||
}
|
||||
}
|
||||
return enMessage
|
||||
}
|
||||
return "${" + varName + "}"
|
||||
})
|
||||
|
||||
this.processResponseHeaders(statusCode)
|
||||
this.processResponseHeaders(this.writer.Header(), statusCode)
|
||||
this.writer.WriteHeader(statusCode)
|
||||
|
||||
_, _ = this.writer.Write([]byte(pageContent))
|
||||
@@ -100,7 +110,7 @@ func (this *HTTPRequest) write50x(err error, statusCode int, enMessage string, z
|
||||
return "${" + varName + "}"
|
||||
})
|
||||
|
||||
this.processResponseHeaders(statusCode)
|
||||
this.processResponseHeaders(this.writer.Header(), statusCode)
|
||||
this.writer.WriteHeader(statusCode)
|
||||
|
||||
_, _ = this.writer.Write([]byte(pageContent))
|
||||
|
||||
@@ -187,7 +187,7 @@ func (this *HTTPRequest) doFastcgi() (shouldStop bool) {
|
||||
|
||||
// 响应Header
|
||||
this.writer.AddHeaders(resp.Header)
|
||||
this.processResponseHeaders(resp.StatusCode)
|
||||
this.processResponseHeaders(this.writer.Header(), resp.StatusCode)
|
||||
|
||||
// 准备
|
||||
this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true)
|
||||
|
||||
@@ -34,10 +34,10 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
|
||||
}
|
||||
|
||||
if u.Status <= 0 {
|
||||
this.processResponseHeaders(http.StatusTemporaryRedirect)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
|
||||
} else {
|
||||
this.processResponseHeaders(u.Status)
|
||||
this.processResponseHeaders(this.writer.Header(), u.Status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
|
||||
}
|
||||
return true
|
||||
@@ -81,10 +81,10 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
|
||||
}
|
||||
|
||||
if u.Status <= 0 {
|
||||
this.processResponseHeaders(http.StatusTemporaryRedirect)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
|
||||
} else {
|
||||
this.processResponseHeaders(u.Status)
|
||||
this.processResponseHeaders(this.writer.Header(), u.Status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
|
||||
}
|
||||
return true
|
||||
@@ -104,10 +104,10 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
|
||||
}
|
||||
|
||||
if u.Status <= 0 {
|
||||
this.processResponseHeaders(http.StatusTemporaryRedirect)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
|
||||
} else {
|
||||
this.processResponseHeaders(u.Status)
|
||||
this.processResponseHeaders(this.writer.Header(), u.Status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
|
||||
}
|
||||
return true
|
||||
|
||||
@@ -18,7 +18,7 @@ func (this *HTTPRequest) doRequestLimit() (shouldStop bool) {
|
||||
// TODO 处理分片提交的内容
|
||||
if this.web.RequestLimit.MaxBodyBytes() > 0 &&
|
||||
this.RawReq.ContentLength > this.web.RequestLimit.MaxBodyBytes() {
|
||||
this.writeCode(http.StatusRequestEntityTooLarge)
|
||||
this.writeCode(http.StatusRequestEntityTooLarge, "", "")
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ func (this *HTTPRequest) doRequestLimit() (shouldStop bool) {
|
||||
clientConn, ok := requestConn.(ClientConnInterface)
|
||||
if ok && !clientConn.IsBound() {
|
||||
if !clientConn.Bind(this.ReqServer.Id, this.requestRemoteAddr(true), this.web.RequestLimit.MaxConns, this.web.RequestLimit.MaxConnsPerIP) {
|
||||
this.writeCode(http.StatusTooManyRequests)
|
||||
this.writeCode(http.StatusTooManyRequests, "", "")
|
||||
this.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -146,6 +146,7 @@ func (this *HTTPRequest) log() {
|
||||
if this.origin != nil {
|
||||
accessLog.OriginId = this.origin.Id
|
||||
accessLog.OriginAddress = this.originAddr
|
||||
accessLog.OriginStatus = this.originStatus
|
||||
}
|
||||
|
||||
// 请求Body
|
||||
|
||||
@@ -34,17 +34,7 @@ func (this *HTTPRequest) MetricValue(value string) (result int64, ok bool) {
|
||||
}
|
||||
return this.RawReq.ContentLength + hl, true
|
||||
case "${countConnection}":
|
||||
metricNewConnMapLocker.Lock()
|
||||
_, ok := metricNewConnMap[this.RawReq.RemoteAddr]
|
||||
if ok {
|
||||
delete(metricNewConnMap, this.RawReq.RemoteAddr)
|
||||
}
|
||||
metricNewConnMapLocker.Unlock()
|
||||
if ok {
|
||||
return 1, true
|
||||
} else {
|
||||
return 0, false
|
||||
}
|
||||
return 1, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ func (this *HTTPRequest) doMismatch() {
|
||||
}
|
||||
|
||||
// 根据配置进行相应的处理
|
||||
if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {
|
||||
if sharedNodeConfig.GlobalServerConfig != nil && sharedNodeConfig.GlobalServerConfig.HTTPAll.MatchDomainStrictly {
|
||||
// 检查cc
|
||||
// TODO 可以在管理端配置是否开启以及最多尝试次数
|
||||
if len(remoteIP) > 0 {
|
||||
@@ -46,7 +46,7 @@ func (this *HTTPRequest) doMismatch() {
|
||||
}
|
||||
|
||||
// 处理当前连接
|
||||
var httpAllConfig = sharedNodeConfig.GlobalConfig.HTTPAll
|
||||
var httpAllConfig = sharedNodeConfig.GlobalServerConfig.HTTPAll
|
||||
var mismatchAction = httpAllConfig.DomainMismatchAction
|
||||
if mismatchAction != nil && mismatchAction.Code == "page" {
|
||||
if mismatchAction.Options != nil {
|
||||
|
||||
@@ -60,11 +60,11 @@ func (this *HTTPRequest) doPage(status int) (shouldStop bool) {
|
||||
// 修改状态码
|
||||
if page.NewStatus > 0 {
|
||||
// 自定义响应Headers
|
||||
this.processResponseHeaders(page.NewStatus)
|
||||
this.processResponseHeaders(this.writer.Header(), page.NewStatus)
|
||||
this.writer.Prepare(nil, stat.Size(), page.NewStatus, true)
|
||||
this.writer.WriteHeader(page.NewStatus)
|
||||
} else {
|
||||
this.processResponseHeaders(status)
|
||||
this.processResponseHeaders(this.writer.Header(), status)
|
||||
this.writer.Prepare(nil, stat.Size(), status, true)
|
||||
this.writer.WriteHeader(status)
|
||||
}
|
||||
@@ -99,11 +99,11 @@ func (this *HTTPRequest) doPage(status int) (shouldStop bool) {
|
||||
// 修改状态码
|
||||
if page.NewStatus > 0 {
|
||||
// 自定义响应Headers
|
||||
this.processResponseHeaders(page.NewStatus)
|
||||
this.processResponseHeaders(this.writer.Header(), page.NewStatus)
|
||||
this.writer.Prepare(nil, int64(len(content)), page.NewStatus, true)
|
||||
this.writer.WriteHeader(page.NewStatus)
|
||||
} else {
|
||||
this.processResponseHeaders(status)
|
||||
this.processResponseHeaders(this.writer.Header(), status)
|
||||
this.writer.Prepare(nil, int64(len(content)), status, true)
|
||||
this.writer.WriteHeader(status)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ func (this *HTTPRequest) doPlanExpires() {
|
||||
this.tags = append(this.tags, "plan")
|
||||
|
||||
var statusCode = http.StatusNotFound
|
||||
this.processResponseHeaders(statusCode)
|
||||
this.processResponseHeaders(this.writer.Header(), statusCode)
|
||||
|
||||
this.writer.WriteHeader(statusCode)
|
||||
_, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultPlanExpireNoticePageBody))
|
||||
|
||||
@@ -42,7 +42,7 @@ func (this *HTTPRequest) doRedirectToHTTPS(redirectToHTTPSConfig *serverconfigs.
|
||||
}
|
||||
|
||||
newURL := "https://" + host + this.RawReq.RequestURI
|
||||
this.processResponseHeaders(statusCode)
|
||||
this.processResponseHeaders(this.writer.Header(), statusCode)
|
||||
http.Redirect(this.writer, this.RawReq, newURL, statusCode)
|
||||
|
||||
return true
|
||||
|
||||
45
internal/nodes/http_request_referers.go
Normal file
45
internal/nodes/http_request_referers.go
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func (this *HTTPRequest) doCheckReferers() (shouldStop bool) {
|
||||
if this.web.Referers == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var refererURL = this.RawReq.Header.Get("Referer")
|
||||
if len(refererURL) == 0 {
|
||||
if this.web.Referers.MatchDomain(this.ReqHost, "") {
|
||||
return
|
||||
}
|
||||
|
||||
this.tags = append(this.tags, "refererCheck")
|
||||
this.writeCode(http.StatusForbidden, "The referer has been blocked.", "当前访问已被防盗链系统拦截。")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
u, err := url.Parse(refererURL)
|
||||
if err != nil {
|
||||
if this.web.Referers.MatchDomain(this.ReqHost, "") {
|
||||
return
|
||||
}
|
||||
|
||||
this.tags = append(this.tags, "refererCheck")
|
||||
this.writeCode(http.StatusForbidden, "The referer has been blocked.", "当前访问已被防盗链系统拦截。")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
if !this.web.Referers.MatchDomain(this.ReqHost, u.Host) {
|
||||
this.tags = append(this.tags, "refererCheck")
|
||||
this.writeCode(http.StatusForbidden, "The referer has been blocked.", "当前访问已被防盗链系统拦截。")
|
||||
return true
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -116,7 +116,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
|
||||
// 处理Scheme
|
||||
if origin.Addr == nil {
|
||||
err := errors.New(this.URL() + ": Origin '" + strconv.FormatInt(origin.Id, 10) + "' does not has a address")
|
||||
remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", err.Error())
|
||||
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", err.Error())
|
||||
this.write50x(err, http.StatusBadGateway, "Origin site did not has a valid address", "源站尚未配置地址", true)
|
||||
return
|
||||
}
|
||||
@@ -168,7 +168,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
|
||||
var originHostIndex = strings.Index(originAddr, ":")
|
||||
if originHostIndex < 0 {
|
||||
var originErr = errors.New(this.URL() + ": Invalid origin address '" + originAddr + "', lacking port")
|
||||
remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", originErr.Error())
|
||||
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", originErr.Error())
|
||||
this.write50x(originErr, http.StatusBadGateway, "No port in origin site address", "源站地址中没有配置端口", true)
|
||||
return
|
||||
}
|
||||
@@ -240,24 +240,18 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
|
||||
|
||||
// 判断是否为Websocket请求
|
||||
if this.RawReq.Header.Get("Upgrade") == "websocket" {
|
||||
this.doWebsocket(requestHost)
|
||||
shouldRetry = this.doWebsocket(requestHost, isLastRetry)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取请求客户端
|
||||
client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol, this.reverseProxy.FollowRedirects)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Create client failed: "+err.Error())
|
||||
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Create client failed: "+err.Error())
|
||||
this.write50x(err, http.StatusBadGateway, "Failed to create origin site client", "构造源站客户端失败", true)
|
||||
return
|
||||
}
|
||||
|
||||
// 在HTTP/2下需要防止因为requestBody而导致Content-Length为空的问题
|
||||
if this.RawReq.ProtoMajor == 2 && this.RawReq.ContentLength == 0 && this.RawReq.Body != nil {
|
||||
_ = this.RawReq.Body.Close()
|
||||
this.RawReq.Body = nil
|
||||
}
|
||||
|
||||
// 开始请求
|
||||
resp, err := client.Do(this.RawReq)
|
||||
if err != nil {
|
||||
@@ -268,7 +262,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
|
||||
this.reverseProxy.ResetScheduling()
|
||||
})
|
||||
this.write50x(err, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
|
||||
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.RawReq.URL.String()+": Request origin server failed: "+err.Error())
|
||||
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.RawReq.URL.String()+": Request origin server failed: "+err.Error())
|
||||
} else if httpErr.Err != context.Canceled {
|
||||
SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() {
|
||||
this.reverseProxy.ResetScheduling()
|
||||
@@ -284,7 +278,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
|
||||
}
|
||||
|
||||
if httpErr.Err != io.EOF {
|
||||
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Request origin server failed: "+err.Error())
|
||||
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Request origin server failed: "+err.Error())
|
||||
}
|
||||
|
||||
return
|
||||
@@ -298,7 +292,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
|
||||
this.write50x(err, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
|
||||
}
|
||||
if httpErr.Err != io.EOF {
|
||||
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Request origin server failed: "+err.Error())
|
||||
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Request origin server failed: "+err.Error())
|
||||
}
|
||||
} else {
|
||||
// 是否为客户端方面的错误
|
||||
@@ -326,6 +320,11 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 记录相关数据
|
||||
this.originStatus = int32(resp.StatusCode)
|
||||
|
||||
// 恢复源站状态
|
||||
if !origin.IsOk {
|
||||
SharedOriginStateManager.Success(origin, func() {
|
||||
this.reverseProxy.ResetScheduling()
|
||||
@@ -337,7 +336,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
|
||||
if this.doWAFResponse(resp) {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Closing Error (WAF): "+err.Error())
|
||||
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Closing Error (WAF): "+err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -347,7 +346,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
|
||||
if len(this.web.Pages) > 0 && this.doPage(resp.StatusCode) {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Closing error (Page): "+err.Error())
|
||||
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Closing error (Page): "+err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -398,7 +397,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
|
||||
|
||||
// 响应Header
|
||||
this.writer.AddHeaders(resp.Header)
|
||||
this.processResponseHeaders(resp.StatusCode)
|
||||
this.processResponseHeaders(this.writer.Header(), resp.StatusCode)
|
||||
|
||||
// 是否需要刷新
|
||||
var shouldAutoFlush = this.reverseProxy.AutoFlush || this.RawReq.Header.Get("Accept") == "text/event-stream"
|
||||
@@ -449,13 +448,13 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
|
||||
var closeErr = resp.Body.Close()
|
||||
if closeErr != nil {
|
||||
if !this.canIgnore(closeErr) {
|
||||
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Closing error: "+closeErr.Error())
|
||||
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Closing error: "+closeErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil && err != io.EOF {
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Writing error: "+err.Error())
|
||||
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Writing error: "+err.Error())
|
||||
this.addError(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,10 +30,10 @@ func (this *HTTPRequest) doRewrite() (shouldShop bool) {
|
||||
// 跳转
|
||||
if this.rewriteRule.Mode == serverconfigs.HTTPRewriteModeRedirect {
|
||||
if this.rewriteRule.RedirectStatus > 0 {
|
||||
this.processResponseHeaders(this.rewriteRule.RedirectStatus)
|
||||
this.processResponseHeaders(this.writer.Header(), this.rewriteRule.RedirectStatus)
|
||||
http.Redirect(this.writer, this.RawReq, this.rewriteReplace, this.rewriteRule.RedirectStatus)
|
||||
} else {
|
||||
this.processResponseHeaders(http.StatusTemporaryRedirect)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
|
||||
http.Redirect(this.writer, this.RawReq, this.rewriteReplace, http.StatusTemporaryRedirect)
|
||||
}
|
||||
return true
|
||||
|
||||
@@ -217,7 +217,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
// 支持 If-None-Match
|
||||
if this.requestHeader("If-None-Match") == eTag {
|
||||
// 自定义Header
|
||||
this.processResponseHeaders(http.StatusNotModified)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
|
||||
this.writer.WriteHeader(http.StatusNotModified)
|
||||
return true
|
||||
}
|
||||
@@ -225,7 +225,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
// 支持 If-Modified-Since
|
||||
if this.requestHeader("If-Modified-Since") == modifiedTime {
|
||||
// 自定义Header
|
||||
this.processResponseHeaders(http.StatusNotModified)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
|
||||
this.writer.WriteHeader(http.StatusNotModified)
|
||||
return true
|
||||
}
|
||||
@@ -253,14 +253,14 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
var contentRange = this.RawReq.Header.Get("Range")
|
||||
if len(contentRange) > 0 {
|
||||
if fileSize == 0 {
|
||||
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
|
||||
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return true
|
||||
}
|
||||
|
||||
set, ok := httpRequestParseRangeHeader(contentRange)
|
||||
if !ok {
|
||||
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
|
||||
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return true
|
||||
}
|
||||
@@ -269,7 +269,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
for k, r := range ranges {
|
||||
r2, ok := r.Convert(fileSize)
|
||||
if !ok {
|
||||
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
|
||||
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return true
|
||||
}
|
||||
@@ -290,7 +290,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
}
|
||||
|
||||
// 自定义Header
|
||||
this.processResponseHeaders(http.StatusOK)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
|
||||
|
||||
// 在Range请求中不能缓存
|
||||
if len(ranges) > 0 {
|
||||
@@ -325,7 +325,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
return true
|
||||
}
|
||||
if !ok {
|
||||
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
|
||||
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return true
|
||||
}
|
||||
@@ -377,7 +377,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
return true
|
||||
}
|
||||
if !ok {
|
||||
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
|
||||
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -28,10 +28,10 @@ func (this *HTTPRequest) doShutdown() {
|
||||
if len(shutdown.URL) == 0 {
|
||||
// 自定义响应Headers
|
||||
if shutdown.Status > 0 {
|
||||
this.processResponseHeaders(shutdown.Status)
|
||||
this.processResponseHeaders(this.writer.Header(), shutdown.Status)
|
||||
this.writer.WriteHeader(shutdown.Status)
|
||||
} else {
|
||||
this.processResponseHeaders(http.StatusOK)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
|
||||
this.writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
_, err := this.writer.WriteString("The site have been shutdown.")
|
||||
@@ -59,10 +59,10 @@ func (this *HTTPRequest) doShutdown() {
|
||||
|
||||
// 自定义响应Headers
|
||||
if shutdown.Status > 0 {
|
||||
this.processResponseHeaders(shutdown.Status)
|
||||
this.processResponseHeaders(this.writer.Header(), shutdown.Status)
|
||||
this.writer.WriteHeader(shutdown.Status)
|
||||
} else {
|
||||
this.processResponseHeaders(http.StatusOK)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
|
||||
this.writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
buf := utils.BytePool1k.Get()
|
||||
@@ -85,10 +85,10 @@ func (this *HTTPRequest) doShutdown() {
|
||||
} else if shutdown.BodyType == shared.BodyTypeHTML {
|
||||
// 自定义响应Headers
|
||||
if shutdown.Status > 0 {
|
||||
this.processResponseHeaders(shutdown.Status)
|
||||
this.processResponseHeaders(this.writer.Header(), shutdown.Status)
|
||||
this.writer.WriteHeader(shutdown.Status)
|
||||
} else {
|
||||
this.processResponseHeaders(http.StatusOK)
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
|
||||
this.writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ func (this *HTTPRequest) doTrafficLimit() {
|
||||
this.tags = append(this.tags, "bandwidth")
|
||||
|
||||
var statusCode = 509
|
||||
this.processResponseHeaders(statusCode)
|
||||
this.processResponseHeaders(this.writer.Header(), statusCode)
|
||||
|
||||
this.writer.WriteHeader(statusCode)
|
||||
if len(config.NoticePageBody) != 0 {
|
||||
|
||||
@@ -44,9 +44,9 @@ func (this *HTTPRequest) doURL(method string, url string, host string, statusCod
|
||||
|
||||
// Header
|
||||
if statusCode <= 0 {
|
||||
this.processResponseHeaders(resp.StatusCode)
|
||||
this.processResponseHeaders(this.writer.Header(), resp.StatusCode)
|
||||
} else {
|
||||
this.processResponseHeaders(statusCode)
|
||||
this.processResponseHeaders(this.writer.Header(), statusCode)
|
||||
}
|
||||
|
||||
if supportVariables {
|
||||
|
||||
@@ -2,6 +2,7 @@ package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
iplib "github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
@@ -161,56 +162,48 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
|
||||
|
||||
// 检查地区封禁
|
||||
if firewallPolicy.Mode == firewallconfigs.FirewallModeDefend {
|
||||
if iplibrary.SharedLibrary != nil {
|
||||
if firewallPolicy.Inbound.Region != nil && firewallPolicy.Inbound.Region.IsOn {
|
||||
regionConfig := firewallPolicy.Inbound.Region
|
||||
if regionConfig.IsNotEmpty() {
|
||||
for _, remoteAddr := range remoteAddrs {
|
||||
result, err := iplibrary.SharedLibrary.Lookup(remoteAddr)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST_WAF", "iplibrary lookup failed: "+err.Error())
|
||||
} else if result != nil {
|
||||
// 检查国家级别封禁
|
||||
if len(regionConfig.DenyCountryIds) > 0 && len(result.Country) > 0 {
|
||||
countryId := iplibrary.SharedCountryManager.Lookup(result.Country)
|
||||
if countryId > 0 && lists.ContainsInt64(regionConfig.DenyCountryIds, countryId) {
|
||||
this.firewallPolicyId = firewallPolicy.Id
|
||||
if firewallPolicy.Inbound.Region != nil && firewallPolicy.Inbound.Region.IsOn {
|
||||
regionConfig := firewallPolicy.Inbound.Region
|
||||
if regionConfig.IsNotEmpty() {
|
||||
for _, remoteAddr := range remoteAddrs {
|
||||
var result = iplib.LookupIP(remoteAddr)
|
||||
if result != nil && result.IsOk() {
|
||||
// 检查国家/地区级别封禁
|
||||
var countryId = result.CountryId()
|
||||
if countryId > 0 && lists.ContainsInt64(regionConfig.DenyCountryIds, countryId) {
|
||||
this.firewallPolicyId = firewallPolicy.Id
|
||||
|
||||
this.writeCode(http.StatusForbidden)
|
||||
this.writer.Flush()
|
||||
this.writer.Close()
|
||||
this.writeCode(http.StatusForbidden, "", "")
|
||||
this.writer.Flush()
|
||||
this.writer.Close()
|
||||
|
||||
// 停止日志
|
||||
if !logDenying {
|
||||
this.disableLog = true
|
||||
} else {
|
||||
this.tags = append(this.tags, "denyCountry")
|
||||
}
|
||||
|
||||
return true, false
|
||||
}
|
||||
// 停止日志
|
||||
if !logDenying {
|
||||
this.disableLog = true
|
||||
} else {
|
||||
this.tags = append(this.tags, "denyCountry")
|
||||
}
|
||||
|
||||
// 检查省份封禁
|
||||
if len(regionConfig.DenyProvinceIds) > 0 && len(result.Province) > 0 {
|
||||
var provinceId = iplibrary.SharedProvinceManager.Lookup(result.Province)
|
||||
if provinceId > 0 && lists.ContainsInt64(regionConfig.DenyProvinceIds, provinceId) {
|
||||
this.firewallPolicyId = firewallPolicy.Id
|
||||
return true, false
|
||||
}
|
||||
|
||||
this.writeCode(http.StatusForbidden)
|
||||
this.writer.Flush()
|
||||
this.writer.Close()
|
||||
// 检查省份封禁
|
||||
var provinceId = result.ProvinceId()
|
||||
if provinceId > 0 && lists.ContainsInt64(regionConfig.DenyProvinceIds, provinceId) {
|
||||
this.firewallPolicyId = firewallPolicy.Id
|
||||
|
||||
// 停止日志
|
||||
if !logDenying {
|
||||
this.disableLog = true
|
||||
} else {
|
||||
this.tags = append(this.tags, "denyProvince")
|
||||
}
|
||||
this.writeCode(http.StatusForbidden, "", "")
|
||||
this.writer.Flush()
|
||||
this.writer.Close()
|
||||
|
||||
return true, false
|
||||
}
|
||||
// 停止日志
|
||||
if !logDenying {
|
||||
this.disableLog = true
|
||||
} else {
|
||||
this.tags = append(this.tags, "denyProvince")
|
||||
}
|
||||
|
||||
return true, false
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -225,7 +218,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
|
||||
}
|
||||
|
||||
goNext, hasRequestBody, ruleGroup, ruleSet, err := w.MatchRequest(this, this.writer)
|
||||
if forceLog && logRequestBody && hasRequestBody {
|
||||
if forceLog && logRequestBody && hasRequestBody && ruleSet != nil && ruleSet.HasAttackActions() {
|
||||
this.wafHasRequestBody = true
|
||||
}
|
||||
if err != nil {
|
||||
@@ -301,7 +294,7 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi
|
||||
}
|
||||
|
||||
goNext, hasRequestBody, ruleGroup, ruleSet, err := w.MatchResponse(this, resp, this.writer)
|
||||
if forceLog && logRequestBody && hasRequestBody {
|
||||
if forceLog && logRequestBody && hasRequestBody && ruleSet != nil && ruleSet.HasAttackActions() {
|
||||
this.wafHasRequestBody = true
|
||||
}
|
||||
if err != nil {
|
||||
@@ -379,6 +372,8 @@ func (this *HTTPRequest) WAFServerId() int64 {
|
||||
// WAFClose 关闭连接
|
||||
func (this *HTTPRequest) WAFClose() {
|
||||
this.Close()
|
||||
|
||||
// 这里不要强关IP所有连接,避免因为单个服务而影响所有
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) WAFOnAction(action interface{}) (goNext bool) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"io"
|
||||
@@ -9,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
// 处理Websocket请求
|
||||
func (this *HTTPRequest) doWebsocket(requestHost string) {
|
||||
func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shouldRetry bool) {
|
||||
if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn {
|
||||
this.writer.WriteHeader(http.StatusForbidden)
|
||||
this.addError(errors.New("websocket have not been enabled yet"))
|
||||
@@ -43,13 +44,16 @@ func (this *HTTPRequest) doWebsocket(requestHost string) {
|
||||
// TODO 增加N次错误重试,重试的时候需要尝试不同的源站
|
||||
originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost)
|
||||
if err != nil {
|
||||
this.write50x(err, http.StatusBadGateway, "Failed to connect origin site", "源站连接失败", false)
|
||||
if isLastRetry {
|
||||
this.write50x(err, http.StatusBadGateway, "Failed to connect origin site", "源站连接失败", false)
|
||||
}
|
||||
|
||||
// 增加失败次数
|
||||
SharedOriginStateManager.Fail(this.origin, requestHost, this.reverseProxy, func() {
|
||||
this.reverseProxy.ResetScheduling()
|
||||
})
|
||||
|
||||
shouldRetry = true
|
||||
return
|
||||
}
|
||||
|
||||
@@ -79,6 +83,33 @@ func (this *HTTPRequest) doWebsocket(requestHost string) {
|
||||
}()
|
||||
|
||||
go func() {
|
||||
// 读取第一个响应
|
||||
resp, err := http.ReadResponse(bufio.NewReader(originConn), this.RawReq)
|
||||
if err != nil {
|
||||
_ = clientConn.Close()
|
||||
_ = originConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
this.processResponseHeaders(resp.Header, resp.StatusCode)
|
||||
|
||||
// 将响应写回客户端
|
||||
err = resp.Write(clientConn)
|
||||
if err != nil {
|
||||
if resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
_ = clientConn.Close()
|
||||
_ = originConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
// 复制剩余的数据
|
||||
var buf = utils.BytePool4k.Get()
|
||||
defer utils.BytePool4k.Put(buf)
|
||||
for {
|
||||
@@ -98,4 +129,6 @@ func (this *HTTPRequest) doWebsocket(requestHost string) {
|
||||
_ = originConn.Close()
|
||||
}()
|
||||
_, _ = io.Copy(originConn, clientConn)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
103
internal/nodes/ip_library_updater.go
Normal file
103
internal/nodes/ip_library_updater.go
Normal file
@@ -0,0 +1,103 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
type IPLibraryUpdater struct {
|
||||
}
|
||||
|
||||
func NewIPLibraryUpdater() *IPLibraryUpdater {
|
||||
return &IPLibraryUpdater{}
|
||||
}
|
||||
|
||||
// DataDir 文件目录
|
||||
func (this *IPLibraryUpdater) DataDir() string {
|
||||
// data/
|
||||
var dir = Tea.Root + "/data"
|
||||
stat, err := os.Stat(dir)
|
||||
if err == nil && stat.IsDir() {
|
||||
return dir
|
||||
}
|
||||
|
||||
err = os.Mkdir(dir, 0666)
|
||||
if err == nil {
|
||||
return dir
|
||||
}
|
||||
|
||||
remotelogs.Error("IP_LIBRARY_UPDATER", "create directory '"+dir+"' failed: "+err.Error())
|
||||
|
||||
// 如果不能创建 data/ 目录,那么使用临时目录
|
||||
return os.TempDir()
|
||||
}
|
||||
|
||||
// FindLatestFile 检查最新的IP库文件
|
||||
func (this *IPLibraryUpdater) FindLatestFile() (code string, fileId int64, err error) {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
resp, err := rpcClient.IPLibraryArtifactRPC.FindPublicIPLibraryArtifact(rpcClient.Context(), &pb.FindPublicIPLibraryArtifactRequest{})
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
var artifact = resp.IpLibraryArtifact
|
||||
if artifact == nil {
|
||||
return
|
||||
}
|
||||
return artifact.Code, artifact.FileId, nil
|
||||
}
|
||||
|
||||
// DownloadFile 下载文件
|
||||
func (this *IPLibraryUpdater) DownloadFile(fileId int64, writer io.Writer) error {
|
||||
if fileId <= 0 {
|
||||
return errors.New("invalid fileId: " + types.String(fileId))
|
||||
}
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chunkIdsResp, err := rpcClient.FileChunkRPC.FindAllFileChunkIds(rpcClient.Context(), &pb.FindAllFileChunkIdsRequest{FileId: fileId})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, chunkId := range chunkIdsResp.FileChunkIds {
|
||||
chunkResp, err := rpcClient.FileChunkRPC.DownloadFileChunk(rpcClient.Context(), &pb.DownloadFileChunkRequest{FileChunkId: chunkId})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var chunk = chunkResp.FileChunk
|
||||
if chunk == nil {
|
||||
return errors.New("can not find file chunk with chunk id '" + types.String(chunkId) + "'")
|
||||
}
|
||||
_, err = writer.Write(chunk.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogInfo 普通日志
|
||||
func (this *IPLibraryUpdater) LogInfo(message string) {
|
||||
remotelogs.Println("IP_LIBRARY_UPDATER", message)
|
||||
}
|
||||
|
||||
// LogError 错误日志
|
||||
func (this *IPLibraryUpdater) LogError(err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
remotelogs.Error("IP_LIBRARY_UPDATER", err.Error())
|
||||
}
|
||||
@@ -7,14 +7,16 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
group *serverconfigs.ServerAddressGroup
|
||||
isListening bool
|
||||
listener ListenerInterface // 监听器
|
||||
group *serverconfigs.ServerAddressGroup
|
||||
listener ListenerInterface // 监听器
|
||||
|
||||
locker sync.RWMutex
|
||||
}
|
||||
@@ -118,18 +120,64 @@ func (this *Listener) listenTCP() error {
|
||||
}
|
||||
|
||||
func (this *Listener) listenUDP() error {
|
||||
listener, err := this.createUDPListener()
|
||||
var addr = this.group.Addr()
|
||||
|
||||
var ipv4PacketListener *ipv4.PacketConn
|
||||
var ipv6PacketListener *ipv6.PacketConn
|
||||
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(host) == 0 {
|
||||
// ipv4
|
||||
ipv4Listener, err := this.createUDPIPv4Listener()
|
||||
if err == nil {
|
||||
ipv4PacketListener = ipv4.NewPacketConn(ipv4Listener)
|
||||
} else {
|
||||
remotelogs.Error("LISTENER", "create udp ipv4 listener '"+addr+"': "+err.Error())
|
||||
}
|
||||
|
||||
// ipv6
|
||||
ipv6Listener, err := this.createUDPIPv6Listener()
|
||||
if err == nil {
|
||||
ipv6PacketListener = ipv6.NewPacketConn(ipv6Listener)
|
||||
} else {
|
||||
remotelogs.Error("LISTENER", "create udp ipv6 listener '"+addr+"': "+err.Error())
|
||||
}
|
||||
} else if strings.Contains(host, ":") { // ipv6
|
||||
ipv6Listener, err := this.createUDPIPv6Listener()
|
||||
if err == nil {
|
||||
ipv6PacketListener = ipv6.NewPacketConn(ipv6Listener)
|
||||
} else {
|
||||
remotelogs.Error("LISTENER", "create udp ipv6 listener '"+addr+"': "+err.Error())
|
||||
}
|
||||
} else { // ipv4
|
||||
ipv4Listener, err := this.createUDPIPv4Listener()
|
||||
if err == nil {
|
||||
ipv4PacketListener = ipv4.NewPacketConn(ipv4Listener)
|
||||
} else {
|
||||
remotelogs.Error("LISTENER", "create udp ipv4 listener '"+addr+"': "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
events.OnKey(events.EventQuit, this, func() {
|
||||
remotelogs.Println("LISTENER", "quit "+this.group.FullAddr())
|
||||
_ = listener.Close()
|
||||
|
||||
if ipv4PacketListener != nil {
|
||||
_ = ipv4PacketListener.Close()
|
||||
}
|
||||
|
||||
if ipv6PacketListener != nil {
|
||||
_ = ipv6PacketListener.Close()
|
||||
}
|
||||
})
|
||||
|
||||
this.listener = &UDPListener{
|
||||
BaseListener: BaseListener{Group: this.group},
|
||||
Listener: listener,
|
||||
IPv4Listener: ipv4PacketListener,
|
||||
IPv6Listener: ipv6PacketListener,
|
||||
}
|
||||
|
||||
goman.New(func() {
|
||||
@@ -168,12 +216,20 @@ func (this *Listener) createTCPListener() (net.Listener, error) {
|
||||
return listenConfig.Listen(context.Background(), "tcp", this.group.Addr())
|
||||
}
|
||||
|
||||
// 创建UDP监听器
|
||||
func (this *Listener) createUDPListener() (*net.UDPConn, error) {
|
||||
// TODO 将来支持udp4/udp6
|
||||
// 创建UDP IPv4监听器
|
||||
func (this *Listener) createUDPIPv4Listener() (*net.UDPConn, error) {
|
||||
addr, err := net.ResolveUDPAddr("udp", this.group.Addr())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.ListenUDP("udp", addr)
|
||||
return net.ListenUDP("udp4", addr)
|
||||
}
|
||||
|
||||
// 创建UDP监听器
|
||||
func (this *Listener) createUDPIPv6Listener() (*net.UDPConn, error) {
|
||||
addr, err := net.ResolveUDPAddr("udp", this.group.Addr())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.ListenUDP("udp6", addr)
|
||||
}
|
||||
|
||||
@@ -3,11 +3,12 @@ package nodes
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
)
|
||||
|
||||
type BaseListener struct {
|
||||
@@ -75,7 +76,7 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
|
||||
// 如果域名为空,则取第一个
|
||||
// 通常域名为空是因为是直接通过IP访问的
|
||||
if len(domain) == 0 {
|
||||
if group.IsHTTPS() && sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {
|
||||
if group.IsHTTPS() && sharedNodeConfig.GlobalServerConfig != nil && sharedNodeConfig.GlobalServerConfig.HTTPAll.MatchDomainStrictly {
|
||||
return nil, nil, errors.New("no tls server name matched")
|
||||
}
|
||||
|
||||
@@ -131,19 +132,19 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf
|
||||
return
|
||||
}
|
||||
|
||||
var matchDomainStrictly = sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly
|
||||
var matchDomainStrictly = sharedNodeConfig.GlobalServerConfig != nil && sharedNodeConfig.GlobalServerConfig.HTTPAll.MatchDomainStrictly
|
||||
|
||||
if sharedNodeConfig.GlobalConfig != nil &&
|
||||
len(sharedNodeConfig.GlobalConfig.HTTPAll.DefaultDomain) > 0 &&
|
||||
(!matchDomainStrictly || lists.ContainsString(sharedNodeConfig.GlobalConfig.HTTPAll.AllowMismatchDomains, name)) {
|
||||
defaultDomain := sharedNodeConfig.GlobalConfig.HTTPAll.DefaultDomain
|
||||
if sharedNodeConfig.GlobalServerConfig != nil &&
|
||||
len(sharedNodeConfig.GlobalServerConfig.HTTPAll.DefaultDomain) > 0 &&
|
||||
(!matchDomainStrictly || configutils.MatchDomains(sharedNodeConfig.GlobalServerConfig.HTTPAll.AllowMismatchDomains, name) || (sharedNodeConfig.GlobalServerConfig.HTTPAll.AllowNodeIP && net.ParseIP(name) != nil)) {
|
||||
var defaultDomain = sharedNodeConfig.GlobalServerConfig.HTTPAll.DefaultDomain
|
||||
serverConfig, serverName = this.findNamedServerMatched(defaultDomain)
|
||||
if serverConfig != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if matchDomainStrictly && !lists.ContainsString(sharedNodeConfig.GlobalConfig.HTTPAll.AllowMismatchDomains, name) {
|
||||
if matchDomainStrictly && !configutils.MatchDomains(sharedNodeConfig.GlobalServerConfig.HTTPAll.AllowMismatchDomains, name) && (!sharedNodeConfig.GlobalServerConfig.HTTPAll.AllowNodeIP || net.ParseIP(name) == nil) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -170,7 +171,7 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser
|
||||
}
|
||||
|
||||
// 是否严格匹配域名
|
||||
matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly
|
||||
var matchDomainStrictly = sharedNodeConfig.GlobalServerConfig != nil && sharedNodeConfig.GlobalServerConfig.HTTPAll.MatchDomainStrictly
|
||||
|
||||
// 如果只有一个server,则默认为这个
|
||||
var currentServers = group.Servers()
|
||||
@@ -181,23 +182,3 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser
|
||||
|
||||
return nil, name
|
||||
}
|
||||
|
||||
// 使用CNAME来查找服务
|
||||
// TODO 防止单IP随机生成域名攻击
|
||||
func (this *BaseListener) findServerWithCNAME(domain string) *serverconfigs.ServerConfig {
|
||||
if !sharedNodeConfig.SupportCNAME {
|
||||
return nil
|
||||
}
|
||||
|
||||
var realName = sharedCNAMEManager.Lookup(domain)
|
||||
if len(realName) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
group := this.Group
|
||||
if group == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return group.MatchServerCNAME(realName)
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"crypto/tls"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/zero"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"golang.org/x/net/http2"
|
||||
"io"
|
||||
@@ -13,14 +12,11 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var httpErrorLogger = log.New(io.Discard, "", 0)
|
||||
var metricNewConnMap = map[string]zero.Zero{} // remoteAddr => bool
|
||||
var metricNewConnMapLocker = &sync.Mutex{}
|
||||
|
||||
type contextKey struct {
|
||||
key string
|
||||
@@ -55,23 +51,10 @@ func (this *HTTPListener) Serve() error {
|
||||
switch state {
|
||||
case http.StateNew:
|
||||
atomic.AddInt64(&this.countActiveConnections, 1)
|
||||
|
||||
// 为指标存储连接信息
|
||||
if sharedNodeConfig.HasHTTPConnectionMetrics() {
|
||||
metricNewConnMapLocker.Lock()
|
||||
metricNewConnMap[conn.RemoteAddr().String()] = zero.New()
|
||||
metricNewConnMapLocker.Unlock()
|
||||
}
|
||||
case http.StateActive, http.StateIdle, http.StateHijacked:
|
||||
// Nothing to do
|
||||
case http.StateClosed:
|
||||
atomic.AddInt64(&this.countActiveConnections, -1)
|
||||
|
||||
// 移除指标存储连接信息
|
||||
// 因为中途配置可能有改变,所以暂时不添加条件
|
||||
metricNewConnMapLocker.Lock()
|
||||
delete(metricNewConnMap, conn.RemoteAddr().String())
|
||||
metricNewConnMapLocker.Unlock()
|
||||
}
|
||||
},
|
||||
ConnContext: func(ctx context.Context, conn net.Conn) context.Context {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
@@ -9,6 +8,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
@@ -213,15 +213,14 @@ func (this *ListenerManager) findProcessNameWithPort(isUdp bool, port string) st
|
||||
option = "u"
|
||||
}
|
||||
|
||||
var cmd = exec.Command(path, "-"+option+"lpn", "sport = :"+port)
|
||||
var output = &bytes.Buffer{}
|
||||
cmd.Stdout = output
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, path, "-"+option+"lpn", "sport = :"+port)
|
||||
cmd.WithStdout()
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var matches = regexp.MustCompile(`(?U)\(\("(.+)",pid=\d+,fd=\d+\)\)`).FindStringSubmatch(output.String())
|
||||
var matches = regexp.MustCompile(`(?U)\(\("(.+)",pid=\d+,fd=\d+\)\)`).FindStringSubmatch(cmd.Stdout())
|
||||
if len(matches) > 1 {
|
||||
return matches[1]
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/pires/go-proxyproto"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -19,10 +21,57 @@ const (
|
||||
UDPConnLifeSeconds = 30
|
||||
)
|
||||
|
||||
type UDPPacketListener interface {
|
||||
ReadFrom(b []byte) (n int, cm any, src net.Addr, err error)
|
||||
WriteTo(b []byte, cm any, dst net.Addr) (n int, err error)
|
||||
LocalAddr() net.Addr
|
||||
}
|
||||
|
||||
type UDPIPv4Listener struct {
|
||||
rawListener *ipv4.PacketConn
|
||||
}
|
||||
|
||||
func NewUDPIPv4Listener(rawListener *ipv4.PacketConn) *UDPIPv4Listener {
|
||||
return &UDPIPv4Listener{rawListener: rawListener}
|
||||
}
|
||||
|
||||
func (this *UDPIPv4Listener) ReadFrom(b []byte) (n int, cm any, src net.Addr, err error) {
|
||||
return this.rawListener.ReadFrom(b)
|
||||
}
|
||||
|
||||
func (this *UDPIPv4Listener) WriteTo(b []byte, cm any, dst net.Addr) (n int, err error) {
|
||||
return this.rawListener.WriteTo(b, cm.(*ipv4.ControlMessage), dst)
|
||||
}
|
||||
|
||||
func (this *UDPIPv4Listener) LocalAddr() net.Addr {
|
||||
return this.rawListener.LocalAddr()
|
||||
}
|
||||
|
||||
type UDPIPv6Listener struct {
|
||||
rawListener *ipv6.PacketConn
|
||||
}
|
||||
|
||||
func NewUDPIPv6Listener(rawListener *ipv6.PacketConn) *UDPIPv6Listener {
|
||||
return &UDPIPv6Listener{rawListener: rawListener}
|
||||
}
|
||||
|
||||
func (this *UDPIPv6Listener) ReadFrom(b []byte) (n int, cm any, src net.Addr, err error) {
|
||||
return this.rawListener.ReadFrom(b)
|
||||
}
|
||||
|
||||
func (this *UDPIPv6Listener) WriteTo(b []byte, cm any, dst net.Addr) (n int, err error) {
|
||||
return this.rawListener.WriteTo(b, cm.(*ipv6.ControlMessage), dst)
|
||||
}
|
||||
|
||||
func (this *UDPIPv6Listener) LocalAddr() net.Addr {
|
||||
return this.rawListener.LocalAddr()
|
||||
}
|
||||
|
||||
type UDPListener struct {
|
||||
BaseListener
|
||||
|
||||
Listener *net.UDPConn
|
||||
IPv4Listener *ipv4.PacketConn
|
||||
IPv6Listener *ipv6.PacketConn
|
||||
|
||||
connMap map[string]*UDPConn
|
||||
connLocker sync.Mutex
|
||||
@@ -36,6 +85,60 @@ type UDPListener struct {
|
||||
}
|
||||
|
||||
func (this *UDPListener) Serve() error {
|
||||
if this.Group == nil {
|
||||
return nil
|
||||
}
|
||||
var server = this.Group.FirstServer()
|
||||
if server == nil {
|
||||
return nil
|
||||
}
|
||||
var serverId = server.Id
|
||||
|
||||
var wg = &sync.WaitGroup{}
|
||||
wg.Add(2) // 2 = ipv4 + ipv6
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
if this.IPv4Listener != nil {
|
||||
err := this.IPv4Listener.SetControlMessage(ipv4.FlagDst, true)
|
||||
if err != nil {
|
||||
remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv4 listener: "+err.Error(), "", nil)
|
||||
return
|
||||
}
|
||||
|
||||
err = this.servePacketListener(NewUDPIPv4Listener(this.IPv4Listener))
|
||||
if err != nil {
|
||||
remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv4 listener: "+err.Error(), "", nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
if this.IPv6Listener != nil {
|
||||
err := this.IPv6Listener.SetControlMessage(ipv6.FlagDst, true)
|
||||
if err != nil {
|
||||
remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv6 listener: "+err.Error(), "", nil)
|
||||
return
|
||||
}
|
||||
|
||||
err = this.servePacketListener(NewUDPIPv6Listener(this.IPv6Listener))
|
||||
if err != nil {
|
||||
remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv6 listener: "+err.Error(), "", nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *UDPListener) servePacketListener(listener UDPPacketListener) error {
|
||||
// 获取分组端口
|
||||
var groupAddr = this.Group.Addr()
|
||||
var portIndex = strings.LastIndex(groupAddr, ":")
|
||||
@@ -67,7 +170,7 @@ func (this *UDPListener) Serve() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
n, addr, err := this.Listener.ReadFrom(buffer)
|
||||
n, cm, clientAddr, err := listener.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
if this.isClosed {
|
||||
return nil
|
||||
@@ -77,14 +180,14 @@ func (this *UDPListener) Serve() error {
|
||||
|
||||
if n > 0 {
|
||||
this.connLocker.Lock()
|
||||
conn, ok := this.connMap[addr.String()]
|
||||
conn, ok := this.connMap[clientAddr.String()]
|
||||
this.connLocker.Unlock()
|
||||
if ok && !conn.IsOk() {
|
||||
_ = conn.Close()
|
||||
ok = false
|
||||
}
|
||||
if !ok {
|
||||
originConn, err := this.connectOrigin(firstServer.Id, this.reverseProxy, addr)
|
||||
originConn, err := this.connectOrigin(firstServer.Id, this.reverseProxy, listener.LocalAddr(), clientAddr)
|
||||
if err != nil {
|
||||
remotelogs.Error("UDP_LISTENER", "unable to connect to origin server: "+err.Error())
|
||||
continue
|
||||
@@ -93,9 +196,9 @@ func (this *UDPListener) Serve() error {
|
||||
remotelogs.Error("UDP_LISTENER", "unable to find a origin server")
|
||||
continue
|
||||
}
|
||||
conn = NewUDPConn(firstServer, addr, this.Listener, originConn.(*net.UDPConn))
|
||||
conn = NewUDPConn(firstServer, clientAddr, listener, cm, originConn.(*net.UDPConn))
|
||||
this.connLocker.Lock()
|
||||
this.connMap[addr.String()] = conn
|
||||
this.connMap[clientAddr.String()] = conn
|
||||
this.connLocker.Unlock()
|
||||
}
|
||||
_, _ = conn.Write(buffer[:n])
|
||||
@@ -117,7 +220,26 @@ func (this *UDPListener) Close() error {
|
||||
}
|
||||
this.connLocker.Unlock()
|
||||
|
||||
return this.Listener.Close()
|
||||
var errorStrings = []string{}
|
||||
if this.IPv4Listener != nil {
|
||||
err := this.IPv4Listener.Close()
|
||||
if err != nil {
|
||||
errorStrings = append(errorStrings, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if this.IPv6Listener != nil {
|
||||
err := this.IPv6Listener.Close()
|
||||
if err != nil {
|
||||
errorStrings = append(errorStrings, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if len(errorStrings) > 0 {
|
||||
return errors.New(errorStrings[0])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *UDPListener) Reload(group *serverconfigs.ServerAddressGroup) {
|
||||
@@ -132,7 +254,7 @@ func (this *UDPListener) Reload(group *serverconfigs.ServerAddressGroup) {
|
||||
this.reverseProxy = firstServer.ReverseProxy
|
||||
}
|
||||
|
||||
func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr net.Addr) (conn net.Conn, err error) {
|
||||
func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, localAddr net.Addr, remoteAddr net.Addr) (conn net.Conn, err error) {
|
||||
if reverseProxy == nil {
|
||||
return nil, errors.New("no reverse proxy config")
|
||||
}
|
||||
@@ -181,12 +303,12 @@ func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfi
|
||||
if strings.Contains(remoteAddr.String(), "[") {
|
||||
transportProtocol = proxyproto.UDPv6
|
||||
}
|
||||
header := proxyproto.Header{
|
||||
var header = proxyproto.Header{
|
||||
Version: byte(reverseProxy.ProxyProtocol.Version),
|
||||
Command: proxyproto.PROXY,
|
||||
TransportProtocol: transportProtocol,
|
||||
SourceAddr: remoteAddr,
|
||||
DestinationAddr: this.Listener.LocalAddr(),
|
||||
DestinationAddr: localAddr,
|
||||
}
|
||||
_, err = header.WriteTo(conn)
|
||||
if err != nil {
|
||||
@@ -224,21 +346,21 @@ func (this *UDPListener) gcConns() {
|
||||
|
||||
// UDPConn 自定义的UDP连接管理
|
||||
type UDPConn struct {
|
||||
addr net.Addr
|
||||
proxyConn net.Conn
|
||||
serverConn net.Conn
|
||||
activatedAt int64
|
||||
isOk bool
|
||||
isClosed bool
|
||||
addr net.Addr
|
||||
proxyListener UDPPacketListener
|
||||
serverConn net.Conn
|
||||
activatedAt int64
|
||||
isOk bool
|
||||
isClosed bool
|
||||
}
|
||||
|
||||
func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyConn *net.UDPConn, serverConn *net.UDPConn) *UDPConn {
|
||||
func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyListener UDPPacketListener, cm any, serverConn *net.UDPConn) *UDPConn {
|
||||
var conn = &UDPConn{
|
||||
addr: addr,
|
||||
proxyConn: proxyConn,
|
||||
serverConn: serverConn,
|
||||
activatedAt: time.Now().Unix(),
|
||||
isOk: true,
|
||||
addr: addr,
|
||||
proxyListener: proxyListener,
|
||||
serverConn: serverConn,
|
||||
activatedAt: time.Now().Unix(),
|
||||
isOk: true,
|
||||
}
|
||||
|
||||
// 统计
|
||||
@@ -246,6 +368,14 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyConn *ne
|
||||
stats.SharedTrafficStatManager.Add(server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
}
|
||||
|
||||
// 处理ControlMessage
|
||||
switch controlMessage := cm.(type) {
|
||||
case *ipv4.ControlMessage:
|
||||
controlMessage.Src = controlMessage.Dst
|
||||
case *ipv6.ControlMessage:
|
||||
controlMessage.Src = controlMessage.Dst
|
||||
}
|
||||
|
||||
goman.New(func() {
|
||||
var buffer = utils.BytePool4k.Get()
|
||||
defer func() {
|
||||
@@ -256,7 +386,8 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyConn *ne
|
||||
n, err := serverConn.Read(buffer)
|
||||
if n > 0 {
|
||||
conn.activatedAt = time.Now().Unix()
|
||||
_, writingErr := proxyConn.WriteTo(buffer[:n], addr)
|
||||
|
||||
_, writingErr := proxyListener.WriteTo(buffer[:n], cm, addr)
|
||||
if writingErr != nil {
|
||||
conn.isOk = false
|
||||
break
|
||||
|
||||
@@ -2,14 +2,17 @@ package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
iplib "github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/conns"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
||||
@@ -21,6 +24,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/trackers"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
_ "github.com/TeaOSLab/EdgeNode/internal/utils/clock" // 触发时钟更新
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
@@ -78,7 +82,7 @@ func (this *Node) Test() error {
|
||||
if err != nil {
|
||||
return errors.New("test rpc failed: " + err.Error())
|
||||
}
|
||||
_, err = rpcClient.APINodeRPC().FindCurrentAPINodeVersion(rpcClient.Context(), &pb.FindCurrentAPINodeVersionRequest{})
|
||||
_, err = rpcClient.APINodeRPC.FindCurrentAPINodeVersion(rpcClient.Context(), &pb.FindCurrentAPINodeVersionRequest{})
|
||||
if err != nil {
|
||||
return errors.New("test rpc failed: " + err.Error())
|
||||
}
|
||||
@@ -88,6 +92,10 @@ func (this *Node) Test() error {
|
||||
|
||||
// Start 启动
|
||||
func (this *Node) Start() {
|
||||
// 设置netdns
|
||||
// 这个需要放在所有网络访问的最前面
|
||||
_ = os.Setenv("GODEBUG", "netdns=go")
|
||||
|
||||
_, ok := os.LookupEnv("EdgeDaemon")
|
||||
if ok {
|
||||
remotelogs.Println("NODE", "start from daemon")
|
||||
@@ -101,9 +109,6 @@ func (this *Node) Start() {
|
||||
// 监听signal
|
||||
this.listenSignals()
|
||||
|
||||
// 启动事件
|
||||
events.Notify(events.EventStart)
|
||||
|
||||
// 本地Sock
|
||||
err := this.listenSock()
|
||||
if err != nil {
|
||||
@@ -111,9 +116,19 @@ func (this *Node) Start() {
|
||||
return
|
||||
}
|
||||
|
||||
// 启动IP库
|
||||
remotelogs.Println("NODE", "initializing ip library ...")
|
||||
err = iplib.InitDefault()
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "initialize ip library failed: "+err.Error())
|
||||
}
|
||||
|
||||
// 检查硬盘类型
|
||||
this.checkDisk()
|
||||
|
||||
// 启动事件
|
||||
events.Notify(events.EventStart)
|
||||
|
||||
// 读取API配置
|
||||
remotelogs.Println("NODE", "init config ...")
|
||||
err = this.syncConfig(0)
|
||||
@@ -146,7 +161,12 @@ func (this *Node) Start() {
|
||||
// 启动同步计时器
|
||||
this.startSyncTimer()
|
||||
|
||||
// 状态变更计时器
|
||||
// 更新IP库
|
||||
goman.New(func() {
|
||||
iplib.NewUpdater(NewIPLibraryUpdater(), 10*time.Minute).Start()
|
||||
})
|
||||
|
||||
// 监控节点运行状态
|
||||
goman.New(func() {
|
||||
NewNodeStatusExecutor().Listen()
|
||||
})
|
||||
@@ -294,7 +314,7 @@ func (this *Node) loop() error {
|
||||
}
|
||||
|
||||
var nodeCtx = rpcClient.Context()
|
||||
tasksResp, err := rpcClient.NodeTaskRPC().FindNodeTasks(nodeCtx, &pb.FindNodeTasksRequest{})
|
||||
tasksResp, err := rpcClient.NodeTaskRPC.FindNodeTasks(nodeCtx, &pb.FindNodeTasksRequest{})
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) && !Tea.IsTesting() {
|
||||
return nil
|
||||
@@ -302,139 +322,157 @@ func (this *Node) loop() error {
|
||||
return errors.New("read node tasks failed: " + err.Error())
|
||||
}
|
||||
for _, task := range tasksResp.NodeTasks {
|
||||
switch task.Type {
|
||||
case "ipItemChanged":
|
||||
// 防止阻塞
|
||||
select {
|
||||
case iplibrary.IPListUpdateNotify <- true:
|
||||
default:
|
||||
|
||||
}
|
||||
|
||||
// 修改为已同步
|
||||
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
|
||||
NodeTaskId: task.Id,
|
||||
IsOk: true,
|
||||
Error: "",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case "configChanged":
|
||||
if task.ServerId > 0 {
|
||||
err = this.syncServerConfig(task.ServerId)
|
||||
} else {
|
||||
if !task.IsPrimary {
|
||||
// 我们等等主节点配置准备完毕
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
err = this.syncConfig(task.Version)
|
||||
}
|
||||
if err != nil {
|
||||
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
|
||||
NodeTaskId: task.Id,
|
||||
IsOk: false,
|
||||
Error: err.Error(),
|
||||
})
|
||||
} else {
|
||||
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
|
||||
NodeTaskId: task.Id,
|
||||
IsOk: true,
|
||||
Error: "",
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case "nodeVersionChanged":
|
||||
if !sharedUpgradeManager.IsInstalling() {
|
||||
goman.New(func() {
|
||||
sharedUpgradeManager.Start()
|
||||
})
|
||||
}
|
||||
case "scriptsChanged":
|
||||
err = this.reloadCommonScripts()
|
||||
if err != nil {
|
||||
return errors.New("reload common scripts failed: " + err.Error())
|
||||
}
|
||||
|
||||
// 修改为已同步
|
||||
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
|
||||
NodeTaskId: task.Id,
|
||||
IsOk: true,
|
||||
Error: "",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case "nodeLevelChanged":
|
||||
levelInfoResp, err := rpcClient.NodeRPC().FindNodeLevelInfo(nodeCtx, &pb.FindNodeLevelInfoRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sharedNodeConfig.Level = levelInfoResp.Level
|
||||
|
||||
var parentNodes = map[int64][]*nodeconfigs.ParentNodeConfig{}
|
||||
if len(levelInfoResp.ParentNodesMapJSON) > 0 {
|
||||
err = json.Unmarshal(levelInfoResp.ParentNodesMapJSON, &parentNodes)
|
||||
if err != nil {
|
||||
return errors.New("decode level info failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
sharedNodeConfig.ParentNodes = parentNodes
|
||||
|
||||
// 修改为已同步
|
||||
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
|
||||
NodeTaskId: task.Id,
|
||||
IsOk: true,
|
||||
Error: "",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case "ddosProtectionChanged":
|
||||
resp, err := rpcClient.NodeRPC().FindNodeDDoSProtection(nodeCtx, &pb.FindNodeDDoSProtectionRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resp.DdosProtectionJSON) == 0 {
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.DDoSProtection = nil
|
||||
}
|
||||
} else {
|
||||
var ddosProtectionConfig = &ddosconfigs.ProtectionConfig{}
|
||||
err = json.Unmarshal(resp.DdosProtectionJSON, ddosProtectionConfig)
|
||||
if err != nil {
|
||||
return errors.New("decode DDoS protection config failed: " + err.Error())
|
||||
}
|
||||
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.DDoSProtection = ddosProtectionConfig
|
||||
}
|
||||
|
||||
err = firewalls.SharedDDoSProtectionManager.Apply(ddosProtectionConfig)
|
||||
if err != nil {
|
||||
// 不阻塞
|
||||
remotelogs.Error("NODE", "apply DDoS protection failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 修改为已同步
|
||||
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
|
||||
NodeTaskId: task.Id,
|
||||
IsOk: true,
|
||||
Error: "",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err := this.execTask(rpcClient, nodeCtx, task)
|
||||
this.finishTask(task.Id, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 执行任务
|
||||
func (this *Node) execTask(rpcClient *rpc.RPCClient, nodeCtx context.Context, task *pb.NodeTask) error {
|
||||
switch task.Type {
|
||||
case "ipItemChanged":
|
||||
// 防止阻塞
|
||||
select {
|
||||
case iplibrary.IPListUpdateNotify <- true:
|
||||
default:
|
||||
|
||||
}
|
||||
case "configChanged":
|
||||
if task.ServerId > 0 {
|
||||
return this.syncServerConfig(task.ServerId)
|
||||
}
|
||||
if !task.IsPrimary {
|
||||
// 我们等等主节点配置准备完毕
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
return this.syncConfig(task.Version)
|
||||
case "nodeVersionChanged":
|
||||
if !sharedUpgradeManager.IsInstalling() {
|
||||
goman.New(func() {
|
||||
sharedUpgradeManager.Start()
|
||||
})
|
||||
}
|
||||
case "scriptsChanged":
|
||||
err := this.reloadCommonScripts()
|
||||
if err != nil {
|
||||
return errors.New("reload common scripts failed: " + err.Error())
|
||||
}
|
||||
case "nodeLevelChanged":
|
||||
levelInfoResp, err := rpcClient.NodeRPC.FindNodeLevelInfo(nodeCtx, &pb.FindNodeLevelInfoRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.Level = levelInfoResp.Level
|
||||
}
|
||||
|
||||
var parentNodes = map[int64][]*nodeconfigs.ParentNodeConfig{}
|
||||
if len(levelInfoResp.ParentNodesMapJSON) > 0 {
|
||||
err = json.Unmarshal(levelInfoResp.ParentNodesMapJSON, &parentNodes)
|
||||
if err != nil {
|
||||
return errors.New("decode level info failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.ParentNodes = parentNodes
|
||||
}
|
||||
case "ddosProtectionChanged":
|
||||
resp, err := rpcClient.NodeRPC.FindNodeDDoSProtection(nodeCtx, &pb.FindNodeDDoSProtectionRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resp.DdosProtectionJSON) == 0 {
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.DDoSProtection = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var ddosProtectionConfig = &ddosconfigs.ProtectionConfig{}
|
||||
err = json.Unmarshal(resp.DdosProtectionJSON, ddosProtectionConfig)
|
||||
if err != nil {
|
||||
return errors.New("decode DDoS protection config failed: " + err.Error())
|
||||
}
|
||||
|
||||
if ddosProtectionConfig != nil && sharedNodeConfig != nil {
|
||||
sharedNodeConfig.DDoSProtection = ddosProtectionConfig
|
||||
}
|
||||
|
||||
err = firewalls.SharedDDoSProtectionManager.Apply(ddosProtectionConfig)
|
||||
if err != nil {
|
||||
// 不阻塞
|
||||
remotelogs.Warn("NODE", "apply DDoS protection failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
case "globalServerConfigChanged":
|
||||
resp, err := rpcClient.NodeRPC.FindNodeGlobalServerConfig(nodeCtx, &pb.FindNodeGlobalServerConfigRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resp.GlobalServerConfigJSON) > 0 {
|
||||
var globalServerConfig = serverconfigs.DefaultGlobalServerConfig()
|
||||
err = json.Unmarshal(resp.GlobalServerConfigJSON, globalServerConfig)
|
||||
if err != nil {
|
||||
return errors.New("decode global server config failed: " + err.Error())
|
||||
}
|
||||
|
||||
if globalServerConfig != nil {
|
||||
err = globalServerConfig.Init()
|
||||
if err != nil {
|
||||
return errors.New("validate global server config failed: " + err.Error())
|
||||
}
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.GlobalServerConfig = globalServerConfig
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
remotelogs.Error("NODE", "task '"+types.String(task.Id)+"', type '"+task.Type+"' has not been handled")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 标记任务完成
|
||||
func (this *Node) finishTask(taskId int64, err error) {
|
||||
if taskId <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
logs.Println("[NODE]", "create rpc client failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var nodeCtx = rpcClient.Context()
|
||||
|
||||
var isOk = err == nil
|
||||
var errMsg = ""
|
||||
if err != nil {
|
||||
errMsg = err.Error()
|
||||
}
|
||||
|
||||
_, err = rpcClient.NodeTaskRPC.ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
|
||||
NodeTaskId: taskId,
|
||||
IsOk: isOk,
|
||||
Error: errMsg,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
// 不需要上报到服务中心
|
||||
if rpc.IsConnError(err) {
|
||||
logs.Println("[NODE]", "report task done failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("NODE", "report task done failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 读取API配置
|
||||
func (this *Node) syncConfig(taskVersion int64) error {
|
||||
this.locker.Lock()
|
||||
@@ -463,10 +501,10 @@ func (this *Node) syncConfig(taskVersion int64) error {
|
||||
}
|
||||
|
||||
// 获取同步任务
|
||||
nodeCtx := rpcClient.Context()
|
||||
var nodeCtx = rpcClient.Context()
|
||||
|
||||
// TODO 这里考虑只同步版本号有变更的
|
||||
configResp, err := rpcClient.NodeRPC().FindCurrentNodeConfig(nodeCtx, &pb.FindCurrentNodeConfigRequest{
|
||||
configResp, err := rpcClient.NodeRPC.FindCurrentNodeConfig(nodeCtx, &pb.FindCurrentNodeConfigRequest{
|
||||
Version: -1, // 更新所有版本
|
||||
Compress: true,
|
||||
NodeTaskVersion: taskVersion,
|
||||
@@ -557,7 +595,7 @@ func (this *Node) syncServerConfig(serverId int64) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := rpcClient.ServerRPC().ComposeServerConfig(rpcClient.Context(), &pb.ComposeServerConfigRequest{ServerId: serverId})
|
||||
resp, err := rpcClient.ServerRPC.ComposeServerConfig(rpcClient.Context(), &pb.ComposeServerConfigRequest{ServerId: serverId})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -639,7 +677,7 @@ func (this *Node) checkClusterConfig() error {
|
||||
}
|
||||
|
||||
logs.Println("[NODE]registering node to cluster ...")
|
||||
resp, err := rpcClient.NodeRPC().RegisterClusterNode(rpcClient.ClusterContext(config.ClusterId, config.Secret), &pb.RegisterClusterNodeRequest{Name: HOSTNAME})
|
||||
resp, err := rpcClient.NodeRPC.RegisterClusterNode(rpcClient.ClusterContext(config.ClusterId, config.Secret), &pb.RegisterClusterNodeRequest{Name: HOSTNAME})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -721,6 +759,7 @@ func (this *Node) listenSock() error {
|
||||
|
||||
// 退出主进程
|
||||
events.Notify(events.EventQuit)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
utils.Exit()
|
||||
case "quit":
|
||||
_ = cmd.ReplyOk()
|
||||
@@ -778,13 +817,16 @@ func (this *Node) listenSock() error {
|
||||
},
|
||||
})
|
||||
case "conns":
|
||||
ipConns, serverConns := sharedClientConnLimiter.Conns()
|
||||
var addrs = []string{}
|
||||
var connMap = conns.SharedMap.AllConns()
|
||||
for _, conn := range connMap {
|
||||
addrs = append(addrs, conn.RemoteAddr().String())
|
||||
}
|
||||
|
||||
_ = cmd.Reply(&gosock.Command{
|
||||
Params: map[string]interface{}{
|
||||
"ipConns": ipConns,
|
||||
"serverConns": serverConns,
|
||||
"total": sharedListenerManager.TotalActiveConnections(),
|
||||
"addrs": addrs,
|
||||
"total": len(addrs),
|
||||
},
|
||||
})
|
||||
case "dropIP":
|
||||
@@ -856,6 +898,11 @@ func (this *Node) listenSock() error {
|
||||
} else {
|
||||
_ = cmd.ReplyOk()
|
||||
}
|
||||
case "bandwidth":
|
||||
var m = stats.SharedBandwidthStatManager.Map()
|
||||
_ = cmd.Reply(&gosock.Command{Params: maps.Map{
|
||||
"stats": m,
|
||||
}})
|
||||
}
|
||||
})
|
||||
|
||||
@@ -866,7 +913,7 @@ func (this *Node) listenSock() error {
|
||||
})
|
||||
|
||||
events.OnKey(events.EventQuit, this, func() {
|
||||
logs.Println("NODE", "quit unix sock")
|
||||
remotelogs.Println("NODE", "quit unix sock")
|
||||
_ = this.sock.Close()
|
||||
})
|
||||
|
||||
@@ -1010,10 +1057,13 @@ func (this *Node) reloadServer() {
|
||||
}
|
||||
|
||||
func (this *Node) checkDisk() {
|
||||
if runtime.GOOS == "linux" {
|
||||
if runtime.GOOS != "linux" {
|
||||
return
|
||||
}
|
||||
for n := 'a'; n <= 'z'; n++ {
|
||||
for _, path := range []string{
|
||||
"/sys/block/vda/queue/rotational",
|
||||
"/sys/block/sda/queue/rotational",
|
||||
"/sys/block/vd" + string(n) + "/queue/rotational",
|
||||
"/sys/block/sd" + string(n) + "/queue/rotational",
|
||||
} {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
@@ -1022,7 +1072,7 @@ func (this *Node) checkDisk() {
|
||||
if string(data) == "0" {
|
||||
teaconst.DiskIsFast = true
|
||||
}
|
||||
break
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,7 +125,7 @@ func (this *NodeStatusExecutor) update() {
|
||||
remotelogs.Error("NODE_STATUS", "failed to open rpc: "+err.Error())
|
||||
return
|
||||
}
|
||||
_, err = rpcClient.NodeRPC().UpdateNodeStatus(rpcClient.Context(), &pb.UpdateNodeStatusRequest{
|
||||
_, err = rpcClient.NodeRPC.UpdateNodeStatus(rpcClient.Context(), &pb.UpdateNodeStatusRequest{
|
||||
StatusJSON: jsonData,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var sharedCNAMEManager = NewServerCNAMEManager()
|
||||
|
||||
// ServerCNAMEManager 服务CNAME管理
|
||||
// TODO 需要自动更新缓存里的记录
|
||||
type ServerCNAMEManager struct {
|
||||
ttlCache *ttlcache.Cache
|
||||
|
||||
locker sync.Mutex
|
||||
}
|
||||
|
||||
func NewServerCNAMEManager() *ServerCNAMEManager {
|
||||
return &ServerCNAMEManager{
|
||||
ttlCache: ttlcache.NewCache(),
|
||||
}
|
||||
}
|
||||
|
||||
func (this *ServerCNAMEManager) Lookup(domain string) string {
|
||||
if len(domain) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var item = this.ttlCache.Read(domain)
|
||||
if item != nil {
|
||||
return types.String(item.Value)
|
||||
}
|
||||
|
||||
cname, _ := utils.LookupCNAME(domain)
|
||||
if len(cname) > 0 {
|
||||
cname = strings.TrimSuffix(cname, ".")
|
||||
}
|
||||
|
||||
this.ttlCache.Write(domain, cname, time.Now().Unix()+600)
|
||||
|
||||
return cname
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestServerCNameManager_Lookup(t *testing.T) {
|
||||
var cnameManager = NewServerCNAMEManager()
|
||||
t.Log(cnameManager.Lookup("www.yun4s.cn"))
|
||||
|
||||
var before = time.Now()
|
||||
defer func() {
|
||||
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||
}()
|
||||
t.Log(cnameManager.Lookup("www.yun4s.cn"))
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -20,15 +21,18 @@ import (
|
||||
func init() {
|
||||
var manager = NewSystemServiceManager()
|
||||
events.On(events.EventReload, func() {
|
||||
err := manager.Setup()
|
||||
if err != nil {
|
||||
remotelogs.Error("SYSTEM_SERVICE", "setup system services failed: "+err.Error())
|
||||
}
|
||||
goman.New(func() {
|
||||
err := manager.Setup()
|
||||
if err != nil {
|
||||
remotelogs.Error("SYSTEM_SERVICE", "setup system services failed: "+err.Error())
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// SystemServiceManager 系统服务管理
|
||||
type SystemServiceManager struct {
|
||||
lastIsOn int // -1, 0, 1
|
||||
}
|
||||
|
||||
func NewSystemServiceManager() *SystemServiceManager {
|
||||
@@ -68,7 +72,8 @@ func (this *SystemServiceManager) setupSystemd(params maps.Map) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config := &nodeconfigs.SystemdServiceConfig{}
|
||||
|
||||
var config = &nodeconfigs.SystemdServiceConfig{}
|
||||
err = json.Unmarshal(data, config)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -82,42 +87,60 @@ func (this *SystemServiceManager) setupSystemd(params maps.Map) error {
|
||||
if len(systemctl) == 0 {
|
||||
return errors.New("can not find 'systemctl' on the system")
|
||||
}
|
||||
cmd := utils.NewCommandExecutor()
|
||||
shortName := teaconst.SystemdServiceName
|
||||
cmd.Add(systemctl, "is-enabled", shortName)
|
||||
output, err := cmd.Run()
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
// 记录上次状态
|
||||
var isOnInt int
|
||||
if config.IsOn {
|
||||
isOnInt = 1
|
||||
} else {
|
||||
isOnInt = 0
|
||||
}
|
||||
|
||||
if this.lastIsOn == isOnInt {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
this.lastIsOn = isOnInt
|
||||
}()
|
||||
|
||||
var shortName = teaconst.SystemdServiceName
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, systemctl, "is-enabled", shortName)
|
||||
cmd.WithStdout()
|
||||
err = cmd.Run()
|
||||
var hasInstalled = err == nil
|
||||
if config.IsOn {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 启动Service
|
||||
goman.New(func() {
|
||||
time.Sleep(5 * time.Second)
|
||||
_ = exec.Command(systemctl, "start", teaconst.SystemdServiceName).Start()
|
||||
})
|
||||
|
||||
if output == "enabled" {
|
||||
// 检查文件路径是否变化
|
||||
// 检查文件路径是否变化
|
||||
if hasInstalled && cmd.Stdout() == "enabled" {
|
||||
data, err := os.ReadFile("/etc/systemd/system/" + teaconst.SystemdServiceName + ".service")
|
||||
if err == nil && bytes.Index(data, []byte(exe)) > 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
manager := utils.NewServiceManager(shortName, teaconst.ProductName)
|
||||
|
||||
// 安装服务
|
||||
var manager = utils.NewServiceManager(shortName, teaconst.ProductName)
|
||||
err = manager.Install(exe, []string{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 启动服务
|
||||
goman.New(func() {
|
||||
time.Sleep(5 * time.Second)
|
||||
_ = executils.NewTimeoutCmd(30*time.Second, systemctl, "start", teaconst.SystemdServiceName).Start()
|
||||
})
|
||||
} else {
|
||||
manager := utils.NewServiceManager(shortName, teaconst.ProductName)
|
||||
err = manager.Uninstall()
|
||||
if err != nil {
|
||||
return err
|
||||
if hasInstalled {
|
||||
var manager = utils.NewServiceManager(shortName, teaconst.ProductName)
|
||||
err = manager.Uninstall()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ func (this *OCSPUpdateTask) Loop() error {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := rpcClient.SSLCertRPC().ListUpdatedSSLCertOCSP(rpcClient.Context(), &pb.ListUpdatedSSLCertOCSPRequest{
|
||||
resp, err := rpcClient.SSLCertRPC.ListUpdatedSSLCertOCSP(rpcClient.Context(), &pb.ListUpdatedSSLCertOCSPRequest{
|
||||
Version: this.version,
|
||||
Size: 100,
|
||||
})
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -81,7 +82,7 @@ func (this *SyncAPINodesTask) Loop() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := rpcClient.APINodeRPC().FindAllEnabledAPINodes(rpcClient.Context(), &pb.FindAllEnabledAPINodesRequest{})
|
||||
resp, err := rpcClient.APINodeRPC.FindAllEnabledAPINodes(rpcClient.Context(), &pb.FindAllEnabledAPINodesRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -152,7 +153,7 @@ func (this *SyncAPINodesTask) testEndpoints(endpoints []string) bool {
|
||||
}()
|
||||
var conn *grpc.ClientConn
|
||||
if u.Scheme == "http" {
|
||||
conn, err = grpc.DialContext(ctx, u.Host, grpc.WithInsecure(), grpc.WithBlock())
|
||||
conn, err = grpc.DialContext(ctx, u.Host, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock())
|
||||
} else if u.Scheme == "https" {
|
||||
conn, err = grpc.DialContext(ctx, u.Host, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -61,12 +62,16 @@ func (this *TOAManager) Run(config *nodeconfigs.TOAConfig) error {
|
||||
}
|
||||
remotelogs.Println("TOA", "starting ...")
|
||||
remotelogs.Println("TOA", "args: "+strings.Join(config.AsArgs(), " "))
|
||||
cmd := exec.Command(binPath, config.AsArgs()...)
|
||||
cmd := executils.NewCmd(binPath, config.AsArgs()...)
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.pid = cmd.Process.Pid
|
||||
var process = cmd.Process()
|
||||
if process == nil {
|
||||
return errors.New("start failed")
|
||||
}
|
||||
this.pid = process.Pid
|
||||
|
||||
goman.New(func() {
|
||||
_ = cmd.Wait()
|
||||
|
||||
@@ -12,11 +12,11 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"github.com/iwind/gosock/pkg/gosock"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
@@ -126,7 +126,7 @@ func (this *UpgradeManager) install() error {
|
||||
var sum = ""
|
||||
var filename = ""
|
||||
for {
|
||||
resp, err := client.NodeRPC().DownloadNodeInstallationFile(client.Context(), &pb.DownloadNodeInstallationFileRequest{
|
||||
resp, err := client.NodeRPC.DownloadNodeInstallationFile(client.Context(), &pb.DownloadNodeInstallationFileRequest{
|
||||
Os: runtime.GOOS,
|
||||
Arch: runtime.GOARCH,
|
||||
ChunkOffset: offset,
|
||||
@@ -252,7 +252,7 @@ func (this *UpgradeManager) restart() error {
|
||||
// 启动
|
||||
exe = filepath.Dir(exe) + "/" + teaconst.ProcessName
|
||||
|
||||
var cmd = exec.Command(exe, "start")
|
||||
var cmd = executils.NewCmd(exe, "start")
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
)
|
||||
|
||||
var prefixReg = regexp.MustCompile(`^\(\?([\w\s]+)\)`) // (?x)
|
||||
var prefixReg2 = regexp.MustCompile(`^\(\?([\w\s]*:)`) // (?x: ...
|
||||
var braceZeroReg = regexp.MustCompile(`^{\s*0*\s*}`) // {0}
|
||||
var braceZeroReg2 = regexp.MustCompile(`^{\s*0*\s*,`) // {0, x}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ var logChan = make(chan *pb.NodeLog, 1024)
|
||||
|
||||
func init() {
|
||||
// 定期上传日志
|
||||
ticker := time.NewTicker(60 * time.Second)
|
||||
var ticker = time.NewTicker(60 * time.Second)
|
||||
if Tea.IsTesting() {
|
||||
ticker = time.NewTicker(10 * time.Second)
|
||||
}
|
||||
@@ -37,6 +37,11 @@ func init() {
|
||||
})
|
||||
}
|
||||
|
||||
// Debug 打印调试信息
|
||||
func Debug(tag string, description string) {
|
||||
logs.Println("[" + tag + "]" + description)
|
||||
}
|
||||
|
||||
// Println 打印普通信息
|
||||
func Println(tag string, description string) {
|
||||
logs.Println("[" + tag + "]" + description)
|
||||
@@ -73,6 +78,31 @@ func Warn(tag string, description string) {
|
||||
}
|
||||
}
|
||||
|
||||
// WarnServer 打印服务相关警告
|
||||
func WarnServer(tag string, description string) {
|
||||
if Tea.IsTesting() {
|
||||
logs.Println("[" + tag + "]" + description)
|
||||
}
|
||||
|
||||
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
|
||||
if nodeConfig != nil && nodeConfig.GlobalServerConfig != nil && !nodeConfig.GlobalServerConfig.Log.RecordServerError {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case logChan <- &pb.NodeLog{
|
||||
Role: teaconst.Role,
|
||||
Tag: tag,
|
||||
Description: description,
|
||||
Level: "warning",
|
||||
NodeId: teaconst.NodeId,
|
||||
CreatedAt: time.Now().Unix(),
|
||||
}:
|
||||
default:
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Error 打印错误信息
|
||||
func Error(tag string, description string) {
|
||||
logs.Println("[" + tag + "]" + description)
|
||||
@@ -97,6 +127,37 @@ func Error(tag string, description string) {
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorServer 打印服务相关错误信息
|
||||
func ErrorServer(tag string, description string) {
|
||||
if Tea.IsTesting() {
|
||||
logs.Println("[" + tag + "]" + description)
|
||||
}
|
||||
|
||||
// 忽略RPC连接错误
|
||||
var level = "error"
|
||||
if strings.Contains(description, "code = Unavailable desc") {
|
||||
level = "warning"
|
||||
}
|
||||
|
||||
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
|
||||
if nodeConfig != nil && nodeConfig.GlobalServerConfig != nil && !nodeConfig.GlobalServerConfig.Log.RecordServerError {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case logChan <- &pb.NodeLog{
|
||||
Role: teaconst.Role,
|
||||
Tag: tag,
|
||||
Description: description,
|
||||
Level: level,
|
||||
NodeId: teaconst.NodeId,
|
||||
CreatedAt: time.Now().Unix(),
|
||||
}:
|
||||
default:
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorObject 打印错误对象
|
||||
func ErrorObject(tag string, err error) {
|
||||
if err == nil {
|
||||
@@ -111,7 +172,15 @@ func ErrorObject(tag string, err error) {
|
||||
|
||||
// ServerError 打印服务相关错误信息
|
||||
func ServerError(serverId int64, tag string, description string, logType nodeconfigs.NodeLogType, params maps.Map) {
|
||||
logs.Println("[" + tag + "]" + description)
|
||||
if Tea.IsTesting() {
|
||||
logs.Println("[" + tag + "]" + description)
|
||||
}
|
||||
|
||||
// 是否记录服务相关错误
|
||||
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
|
||||
if nodeConfig != nil && nodeConfig.GlobalServerConfig != nil && !nodeConfig.GlobalServerConfig.Log.RecordServerError {
|
||||
return
|
||||
}
|
||||
|
||||
// 参数
|
||||
var paramsJSON []byte
|
||||
@@ -173,7 +242,6 @@ func ServerSuccess(serverId int64, tag string, description string, logType nodec
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ServerLog 打印服务相关日志信息
|
||||
func ServerLog(serverId int64, tag string, description string, logType nodeconfigs.NodeLogType, params maps.Map) {
|
||||
logs.Println("[" + tag + "]" + description)
|
||||
@@ -208,7 +276,7 @@ func ServerLog(serverId int64, tag string, description string, logType nodeconfi
|
||||
|
||||
// 上传日志
|
||||
func uploadLogs() error {
|
||||
logList := []*pb.NodeLog{}
|
||||
var logList = []*pb.NodeLog{}
|
||||
|
||||
const hashSize = 5
|
||||
var hashList = []uint64{}
|
||||
@@ -243,6 +311,7 @@ Loop:
|
||||
if len(logList) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -253,6 +322,6 @@ Loop:
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = rpcClient.NodeLogRPC().CreateNodeLogs(rpcClient.Context(), &pb.CreateNodeLogsRequest{NodeLogs: logList})
|
||||
_, err = rpcClient.NodeLogRPC.CreateNodeLogs(rpcClient.Context(), &pb.CreateNodeLogsRequest{NodeLogs: logList})
|
||||
return err
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user