Compare commits

...

32 Commits

Author SHA1 Message Date
刘祥超
8bec1cf68e 优化时钟同步相关代码 2022-09-25 14:26:46 +08:00
刘祥超
2cd1bb7f95 时钟同步错误提示改成警告 2022-09-25 09:40:28 +08:00
刘祥超
19e6329a2b 部分提示支持繁体中文 2022-09-24 20:02:29 +08:00
刘祥超
fce2879567 Websocket支持自定义响应Header 2022-09-23 14:21:53 +08:00
刘祥超
0973765919 增加防盗链拦截时默认提示文字 2022-09-22 17:52:57 +08:00
刘祥超
827679721e 增加防盗链功能 2022-09-22 16:33:53 +08:00
刘祥超
735279bc7a 删除的IP名单不再写入到本地数据库 2022-09-21 17:06:25 +08:00
刘祥超
3eb2ed9897 IP名单数据库增加定时清理 2022-09-21 16:49:48 +08:00
刘祥超
3a913d98c7 优化服务相关错误和警告日志 2022-09-20 14:58:55 +08:00
刘祥超
9bfcd79e36 缩小日志文件尺寸 2022-09-18 16:42:11 +08:00
刘祥超
a81d610302 优化代码 2022-09-18 16:18:31 +08:00
刘祥超
64b1753c4d 自动尝试安装nftables 2022-09-17 21:08:56 +08:00
刘祥超
afcb5c2957 集群设置中增加服务设置/可以设置不记录服务错误日志 2022-09-16 18:41:47 +08:00
刘祥超
7d0b9208a3 访问日志中增加源站状态码 2022-09-16 10:07:40 +08:00
刘祥超
c0f0ec43bb Websocket也支持失败自动重试 2022-09-16 09:37:49 +08:00
刘祥超
30bd66958c 优化NTP时钟自动同步 2022-09-15 15:59:29 +08:00
刘祥超
fafac1a038 优化系统服务管理 2022-09-15 15:01:19 +08:00
刘祥超
f979f9503e 优化服务安装相关代码 2022-09-15 11:45:29 +08:00
刘祥超
b233c3cc7a 优化命令执行相关代码 2022-09-15 11:14:33 +08:00
刘祥超
597ac936f7 检查synflood时忽略IP白名单和局域网连接 2022-09-14 18:52:26 +08:00
刘祥超
a590254eb3 修复有多个网络出口时,可能无法正确转发UDP消息的问题 2022-09-14 17:18:00 +08:00
刘祥超
0498bcf30c 调整GRPC参数 2022-09-12 21:59:52 +08:00
刘祥超
59f9b5c724 优化netdns设置 2022-09-12 17:43:48 +08:00
刘祥超
80729935b6 优化代码 2022-09-11 19:03:53 +08:00
刘祥超
4ca57fb99c 重启时保留-shm,-wal文件 2022-09-09 09:34:00 +08:00
刘祥超
9b35902ad4 访问日志因尺寸过大无法提交到API节点时,自动去除requestBody后再次尝试 2022-09-07 17:34:57 +08:00
刘祥超
3b8bd09190 优化代码 2022-09-07 14:54:36 +08:00
刘祥超
71a5bc0652 可以使用EdgeRecover环境变量指示恢复数据库 2022-09-07 14:44:36 +08:00
刘祥超
ac6a8c4e85 启动时尝试自动修复损坏的缓存索引数据库 2022-09-07 13:55:36 +08:00
刘祥超
f58a808c3a 改进缓存LFU算法 2022-09-07 11:34:26 +08:00
刘祥超
51037be772 将版本修改为0.5.3 2022-09-05 11:04:46 +08:00
刘祥超
443ff9aff7 改进Captcha认证计数 2022-09-05 10:59:02 +08:00
79 changed files with 1636 additions and 780 deletions

View File

@@ -41,7 +41,7 @@ func (this *LogWriter) Init() {
this.c = make(chan string, 1024)
// 异步写入文件
var maxFileSize = 2 * sizes.G // 文件最大尺寸,超出此尺寸则清空
var maxFileSize = 128 * sizes.M // 文件最大尺寸,超出此尺寸则清空
if fp != nil {
goman.New(func() {
var totalSize int64 = 0

View File

@@ -10,8 +10,10 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/dbs"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
"os"
"runtime"
"strings"
"time"
@@ -48,16 +50,16 @@ type FileListDB struct {
deleteByHashStmt *dbs.Stmt // 根据hash删除数据
deleteByHashSQL string
statStmt *dbs.Stmt // 统计
purgeStmt *dbs.Stmt // 清理
deleteAllStmt *dbs.Stmt // 删除所有数据
listOlderItemsStmt *dbs.Stmt // 读取较早存储的缓存
statStmt *dbs.Stmt // 统计
purgeStmt *dbs.Stmt // 清理
deleteAllStmt *dbs.Stmt // 删除所有数据
listOlderItemsStmt *dbs.Stmt // 读取较早存储的缓存
updateAccessWeekSQL string // 修改访问日期
// hits
insertHitSQL string // 写入数据
increaseHitSQL string // 增加点击量
deleteHitByHashSQL string // 根据hash删除数据
lfuHitsStmt *dbs.Stmt // 读取老的数据
insertHitSQL string // 写入数据
increaseHitSQL string // 增加点击量
deleteHitByHashSQL string // 根据hash删除数据
}
func NewFileListDB() *FileListDB {
@@ -84,14 +86,26 @@ func (this *FileListDB) Open(dbPath string) error {
writeDB.SetMaxOpenConns(1)
this.writeDB = dbs.NewDB(writeDB)
// TODO 耗时过长,暂时不整理数据库
// TODO 需要根据行数来判断是否VACUUM
// TODO 注意VACUUM反而可能让数据库文件变大
/**_, err = db.Exec("VACUUM")
if err != nil {
return err
}**/
this.writeDB = dbs.NewDB(writeDB)
// 检查是否损坏
// TODO 暂时屏蔽,因为用时过长
var recoverEnv, _ = os.LookupEnv("EdgeRecover")
if len(recoverEnv) > 0 && this.shouldRecover() {
for _, indexName := range []string{"staleAt", "hash"} {
_, _ = this.writeDB.Exec(`REINDEX "` + indexName + `"`)
}
}
this.writeBatch = dbs.NewBatch(writeDB, 4)
this.writeBatch.OnFail(func(err error) {
remotelogs.Warn("LIST_FILE_DB", "run batch failed: "+err.Error())
@@ -151,7 +165,7 @@ func (this *FileListDB) Init() error {
return err
}
this.insertSQL = `INSERT INTO "` + this.itemsTableName + `" ("hash", "key", "headerSize", "bodySize", "metaSize", "expiredAt", "staleAt", "host", "serverId", "createdAt") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
this.insertSQL = `INSERT INTO "` + this.itemsTableName + `" ("hash", "key", "headerSize", "bodySize", "metaSize", "expiredAt", "staleAt", "host", "serverId", "createdAt", "accessWeek") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
this.insertStmt, err = this.writeDB.Prepare(this.insertSQL)
if err != nil {
return err
@@ -185,7 +199,12 @@ func (this *FileListDB) Init() error {
return err
}
this.listOlderItemsStmt, err = this.readDB.Prepare(`SELECT "hash" FROM "` + this.itemsTableName + `" ORDER BY "id" ASC LIMIT ?`)
this.listOlderItemsStmt, err = this.readDB.Prepare(`SELECT "hash" FROM "` + this.itemsTableName + `" ORDER BY "accessWeek" ASC, "id" ASC LIMIT ?`)
if err != nil {
return err
}
this.updateAccessWeekSQL = `UPDATE "` + this.itemsTableName + `" SET "accessWeek"=? WHERE "hash"=?`
this.insertHitSQL = `INSERT INTO "` + this.hitsTableName + `" ("hash", "week2Hits", "week") VALUES (?, 1, ?)`
@@ -193,11 +212,6 @@ func (this *FileListDB) Init() error {
this.deleteHitByHashSQL = `DELETE FROM "` + this.hitsTableName + `" WHERE "hash"=?`
this.lfuHitsStmt, err = this.readDB.Prepare(`SELECT "hash" FROM "` + this.hitsTableName + `" ORDER BY "week" ASC, "week1Hits"+"week2Hits" ASC LIMIT ?`)
if err != nil {
return err
}
this.isReady = true
// 加载HashMap
@@ -226,7 +240,7 @@ func (this *FileListDB) AddAsync(hash string, item *Item) error {
item.StaleAt = item.ExpiredAt
}
this.writeBatch.Add(this.insertSQL, hash, item.Key, item.HeaderSize, item.BodySize, item.MetaSize, item.ExpiredAt, item.StaleAt, item.Host, item.ServerId, utils.UnixTime())
this.writeBatch.Add(this.insertSQL, hash, item.Key, item.HeaderSize, item.BodySize, item.MetaSize, item.ExpiredAt, item.StaleAt, item.Host, item.ServerId, utils.UnixTime(), timeutil.Format("YW"))
return nil
}
@@ -238,7 +252,7 @@ func (this *FileListDB) AddSync(hash string, item *Item) error {
item.StaleAt = item.ExpiredAt
}
_, err := this.insertStmt.Exec(hash, item.Key, item.HeaderSize, item.BodySize, item.MetaSize, item.ExpiredAt, item.StaleAt, item.Host, item.ServerId, utils.UnixTime())
_, err := this.insertStmt.Exec(hash, item.Key, item.HeaderSize, item.BodySize, item.MetaSize, item.ExpiredAt, item.StaleAt, item.Host, item.ServerId, utils.UnixTime(), timeutil.Format("YW"))
if err != nil {
return this.WrapError(err)
}
@@ -300,22 +314,23 @@ func (this *FileListDB) ListLFUItems(count int) (hashList []string, err error) {
count = 100
}
hashList, err = this.listLFUItems(count)
// 先找过期的
hashList, err = this.ListExpiredItems(count)
if err != nil {
return
}
var l = len(hashList)
if len(hashList) > count/2 {
return
// 从旧缓存中补充
if l < count {
oldHashList, err := this.listOlderItems(count - l)
if err != nil {
return nil, err
}
hashList = append(hashList, oldHashList...)
}
// 不足补齐
olderHashList, err := this.listOlderItems(count - len(hashList))
if err != nil {
return nil, err
}
hashList = append(hashList, olderHashList...)
return
return hashList, nil
}
func (this *FileListDB) ListHashes(lastId int64) (hashList []string, maxId int64, err error) {
@@ -342,6 +357,7 @@ func (this *FileListDB) ListHashes(lastId int64) (hashList []string, maxId int64
func (this *FileListDB) IncreaseHitAsync(hash string) error {
var week = timeutil.Format("YW")
this.writeBatch.Add(this.increaseHitSQL, hash, week, week, week, week)
this.writeBatch.Add(this.updateAccessWeekSQL, week, hash)
return nil
}
@@ -418,8 +434,9 @@ func (this *FileListDB) Close() error {
if this.listOlderItemsStmt != nil {
_ = this.listOlderItemsStmt.Close()
}
if this.lfuHitsStmt != nil {
_ = this.lfuHitsStmt.Close()
if this.writeBatch != nil {
this.writeBatch.Close()
}
var errStrings []string
@@ -438,10 +455,6 @@ func (this *FileListDB) Close() error {
}
}
if this.writeBatch != nil {
this.writeBatch.Close()
}
if len(errStrings) == 0 {
return nil
}
@@ -472,7 +485,8 @@ func (this *FileListDB) initTables(times int) error {
"staleAt" integer DEFAULT 0,
"createdAt" integer DEFAULT 0,
"host" varchar(128),
"serverId" integer
"serverId" integer,
"accessWeek" varchar(6)
);
DROP INDEX IF EXISTS "createdAt";
@@ -488,19 +502,28 @@ CREATE UNIQUE INDEX IF NOT EXISTS "hash"
ON "` + this.itemsTableName + `" (
"hash" ASC
);
ALTER TABLE "cacheItems" ADD "accessWeek" varchar(6);
`)
if err != nil {
// 尝试删除重建
if times < 3 {
_, dropErr := this.writeDB.Exec(`DROP TABLE "` + this.itemsTableName + `"`)
if dropErr == nil {
return this.initTables(times + 1)
}
return this.WrapError(err)
// 忽略可以预期的错误
if strings.Contains(err.Error(), "duplicate column name") {
err = nil
}
return this.WrapError(err)
// 尝试删除重建
if err != nil {
if times < 3 {
_, dropErr := this.writeDB.Exec(`DROP TABLE "` + this.itemsTableName + `"`)
if dropErr == nil {
return this.initTables(times + 1)
}
return this.WrapError(err)
}
return this.WrapError(err)
}
}
}
@@ -535,27 +558,6 @@ ON "` + this.hitsTableName + `" (
return nil
}
func (this *FileListDB) listLFUItems(count int) (hashList []string, err error) {
rows, err := this.lfuHitsStmt.Query(count)
if err != nil {
return nil, err
}
defer func() {
_ = rows.Close()
}()
for rows.Next() {
var hash string
err = rows.Scan(&hash)
if err != nil {
return nil, err
}
hashList = append(hashList, hash)
}
return hashList, nil
}
func (this *FileListDB) listOlderItems(count int) (hashList []string, err error) {
rows, err := this.listOlderItemsStmt.Query(count)
if err != nil {
@@ -576,3 +578,21 @@ func (this *FileListDB) listOlderItems(count int) (hashList []string, err error)
return hashList, nil
}
func (this *FileListDB) shouldRecover() bool {
result, err := this.writeDB.Query("pragma integrity_check;")
if err != nil {
logs.Println(result)
}
var errString = ""
var shouldRecover = false
for result.Next() {
err = result.Scan(&errString)
if strings.TrimSpace(errString) != "ok" {
shouldRecover = true
}
break
}
_ = result.Close()
return shouldRecover
}

View File

@@ -0,0 +1,49 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package caches_test
import (
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/iwind/TeaGo/Tea"
_ "github.com/iwind/TeaGo/bootstrap"
"testing"
"time"
)
func TestFileListDB_ListLFUItems(t *testing.T) {
var db = caches.NewFileListDB()
err := db.Open(Tea.Root + "/data/cache-db-large.db")
//err := db.Open(Tea.Root + "/data/cache-index/p1/db-0.db")
if err != nil {
t.Fatal(err)
}
err = db.Init()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
hashList, err := db.ListLFUItems(100)
if err != nil {
t.Fatal(err)
}
t.Log("[", len(hashList), "]", hashList)
}
func TestFileListDB_IncreaseHitAsync(t *testing.T) {
var db = caches.NewFileListDB()
err := db.Open(Tea.Root + "/data/cache-db-large.db")
if err != nil {
t.Fatal(err)
}
err = db.Init()
err = db.IncreaseHitAsync("4598e5231ba47d6ec7aa9ea640ff2eaf")
if err != nil {
t.Fatal(err)
}
// wait transaction
time.Sleep(1 * time.Second)
}

View File

@@ -107,6 +107,12 @@ func (this *FileListHashMap) IsReady() bool {
return this.isReady
}
func (this *FileListHashMap) Len() int {
this.locker.Lock()
defer this.locker.Unlock()
return len(this.m)
}
func (this *FileListHashMap) bigInt(hash string) uint64 {
var bigInt = big.NewInt(0)
bigInt.SetString(hash, 16)

View File

@@ -12,6 +12,7 @@ import (
"runtime"
"strconv"
"testing"
"time"
)
func TestFileListHashMap_Memory(t *testing.T) {
@@ -68,10 +69,14 @@ func TestFileListHashMap_Load(t *testing.T) {
}()
var m = caches.NewFileListHashMap()
err = m.Load(list.GetDB("abc"))
var before = time.Now()
var db = list.GetDB("abc")
err = m.Load(db)
if err != nil {
t.Fatal(err)
}
t.Log(time.Since(before).Seconds()*1000, "ms")
t.Log("count:", m.Len())
m.Add("abc")
for _, hash := range []string{"33347bb4441265405347816cad36a0f8", "a", "abc", "123"} {

View File

@@ -179,7 +179,7 @@ func (this *FileStorage) UpdatePolicy(newPolicy *serverconfigs.HTTPCachePolicy)
// open cache
oldOpenFileCacheJSON, _ := json.Marshal(oldOpenFileCache)
newOpenFileCacheJSON, _ := json.Marshal(this.options.OpenFileCache)
if bytes.Compare(oldOpenFileCacheJSON, newOpenFileCacheJSON) != 0 {
if !bytes.Equal(oldOpenFileCacheJSON, newOpenFileCacheJSON) {
this.initOpenFileCache()
}

View File

@@ -520,8 +520,6 @@ func (this *MemoryStorage) flushItem(key string) {
// 从内存中移除
_ = this.Delete(key)
return
}
func (this *MemoryStorage) memoryCapacityBytes() int64 {

View File

@@ -1,5 +0,0 @@
package configs
import "sync"
var sharedLocker = &sync.RWMutex{}

View File

@@ -1,7 +1,7 @@
package teaconst
const (
Version = "0.5.2"
Version = "0.5.3"
ProductName = "Edge Node"
ProcessName = "edge-node"

View File

@@ -1,6 +1,5 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package firewalls
@@ -14,6 +13,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"github.com/TeaOSLab/EdgeNode/internal/zero"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
@@ -21,6 +21,7 @@ import (
"net"
"os/exec"
"strings"
"time"
)
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
@@ -53,19 +54,13 @@ func init() {
// DDoSProtectionManager DDoS防护
type DDoSProtectionManager struct {
nftPath string
lastAllowIPList []string
lastConfig []byte
}
// NewDDoSProtectionManager 获取新对象
func NewDDoSProtectionManager() *DDoSProtectionManager {
nftPath, _ := exec.LookPath("nft")
return &DDoSProtectionManager{
nftPath: nftPath,
}
return &DDoSProtectionManager{}
}
// Apply 应用配置
@@ -91,7 +86,7 @@ func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) e
}
remotelogs.Println("FIREWALL", "change DDoS protection config")
if len(this.nftPath) == 0 {
if len(this.nftExe()) == 0 {
return errors.New("can not find nft command")
}
@@ -154,6 +149,11 @@ func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) e
// 添加TCP规则
func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig) error {
var nftExe = this.nftExe()
if len(nftExe) == 0 {
return nil
}
// 检查nft版本不能小于0.9
if len(nftablesInstance.version) > 0 && stringutil.VersionCompare("0.9", nftablesInstance.version) > 0 {
return nil
@@ -263,23 +263,21 @@ func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig)
// 添加新规则
for _, port := range ports {
if maxConnections > 0 {
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "count", "over", types.String(maxConnections), "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}))
var stderr = &bytes.Buffer{}
cmd.Stderr = stderr
var cmd = executils.NewTimeoutCmd(10*time.Second, nftExe, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "count", "over", types.String(maxConnections), "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}))
cmd.WithStderr()
err := cmd.Run()
if err != nil {
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + cmd.Stderr() + ")")
}
}
// TODO 让用户选择是drop还是reject
if maxConnectionsPerIP > 0 {
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "meter", "meter-"+protocol+"-"+types.String(port)+"-max-connections", "{ "+protocol+" saddr ct count over "+types.String(maxConnectionsPerIP)+" }", "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}))
var stderr = &bytes.Buffer{}
cmd.Stderr = stderr
var cmd = executils.NewTimeoutCmd(10*time.Second, nftExe, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "meter", "meter-"+protocol+"-"+types.String(port)+"-max-connections", "{ "+protocol+" saddr ct count over "+types.String(maxConnectionsPerIP)+" }", "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}))
cmd.WithStderr()
err := cmd.Run()
if err != nil {
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + cmd.Stderr() + ")")
}
}
@@ -287,20 +285,18 @@ func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig)
// TODO 让用户选择是drop还是reject
if newConnectionsMinutelyRate > 0 {
if newConnectionsMinutelyRateBlockTimeout > 0 {
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsMinutelyRate)+"/minute burst "+types.String(newConnectionsMinutelyRate+3)+" packets }", "add", "@deny_set", "{"+protocol+" saddr timeout "+types.String(newConnectionsMinutelyRateBlockTimeout)+"s}", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsMinutelyRate), types.String(newConnectionsMinutelyRateBlockTimeout)}))
var stderr = &bytes.Buffer{}
cmd.Stderr = stderr
var cmd = executils.NewTimeoutCmd(10*time.Second, nftExe, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsMinutelyRate)+"/minute burst "+types.String(newConnectionsMinutelyRate+3)+" packets }", "add", "@deny_set", "{"+protocol+" saddr timeout "+types.String(newConnectionsMinutelyRateBlockTimeout)+"s}", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsMinutelyRate), types.String(newConnectionsMinutelyRateBlockTimeout)}))
cmd.WithStderr()
err := cmd.Run()
if err != nil {
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + cmd.Stderr() + ")")
}
} else {
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsMinutelyRate)+"/minute burst "+types.String(newConnectionsMinutelyRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", "0"}))
var stderr = &bytes.Buffer{}
cmd.Stderr = stderr
var cmd = executils.NewTimeoutCmd(10*time.Second, nftExe, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsMinutelyRate)+"/minute burst "+types.String(newConnectionsMinutelyRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", "0"}))
cmd.WithStderr()
err := cmd.Run()
if err != nil {
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + cmd.Stderr() + ")")
}
}
}
@@ -309,20 +305,18 @@ func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig)
// TODO 让用户选择是drop还是reject
if newConnectionsSecondlyRate > 0 {
if newConnectionsSecondlyRateBlockTimeout > 0 {
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-secondly-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsSecondlyRate)+"/second burst "+types.String(newConnectionsSecondlyRate+3)+" packets }", "add", "@deny_set", "{"+protocol+" saddr timeout "+types.String(newConnectionsSecondlyRateBlockTimeout)+"s}", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsSecondlyRate", types.String(newConnectionsSecondlyRate), types.String(newConnectionsSecondlyRateBlockTimeout)}))
var stderr = &bytes.Buffer{}
cmd.Stderr = stderr
var cmd = executils.NewTimeoutCmd(10*time.Second, nftExe, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-secondly-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsSecondlyRate)+"/second burst "+types.String(newConnectionsSecondlyRate+3)+" packets }", "add", "@deny_set", "{"+protocol+" saddr timeout "+types.String(newConnectionsSecondlyRateBlockTimeout)+"s}", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsSecondlyRate", types.String(newConnectionsSecondlyRate), types.String(newConnectionsSecondlyRateBlockTimeout)}))
cmd.WithStderr()
err := cmd.Run()
if err != nil {
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + cmd.Stderr() + ")")
}
} else {
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-secondly-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsSecondlyRate)+"/second burst "+types.String(newConnectionsSecondlyRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsSecondlyRate", "0"}))
var stderr = &bytes.Buffer{}
cmd.Stderr = stderr
var cmd = executils.NewTimeoutCmd(10*time.Second, nftExe, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-secondly-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsSecondlyRate)+"/second burst "+types.String(newConnectionsSecondlyRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsSecondlyRate", "0"}))
cmd.WithStderr()
err := cmd.Run()
if err != nil {
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + cmd.Stderr() + ")")
}
}
}
@@ -555,3 +549,8 @@ func (this *DDoSProtectionManager) updateAllowIPList(allIPList []string) error {
return nil
}
func (this *DDoSProtectionManager) nftExe() string {
path, _ := exec.LookPath("nft")
return path
}

View File

@@ -11,7 +11,6 @@ import (
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
type DDoSProtectionManager struct {
nftPath string
}
func NewDDoSProtectionManager() *DDoSProtectionManager {

View File

@@ -7,13 +7,15 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/conns"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"github.com/iwind/TeaGo/types"
"os/exec"
"strings"
"time"
)
type firewalldCmd struct {
cmd *exec.Cmd
cmd *executils.Cmd
denyIP string
}
@@ -32,7 +34,7 @@ func NewFirewalld() *Firewalld {
path, err := exec.LookPath("firewall-cmd")
if err == nil && len(path) > 0 {
var cmd = exec.Command(path, "--state")
var cmd = executils.NewTimeoutCmd(3*time.Second, path, "--state")
err := cmd.Run()
if err == nil {
firewalld.exe = path
@@ -85,7 +87,7 @@ func (this *Firewalld) AllowPort(port int, protocol string) error {
if !this.isReady {
return nil
}
var cmd = exec.Command(this.exe, "--add-port="+types.String(port)+"/"+protocol)
var cmd = executils.NewTimeoutCmd(10*time.Second, this.exe, "--add-port="+types.String(port)+"/"+protocol)
this.pushCmd(cmd, "")
return nil
}
@@ -95,12 +97,12 @@ func (this *Firewalld) AllowPortRangesPermanently(portRanges [][2]int, protocol
var port = this.PortRangeString(portRange, protocol)
{
var cmd = exec.Command(this.exe, "--add-port="+port, "--permanent")
var cmd = executils.NewTimeoutCmd(10*time.Second, this.exe, "--add-port="+port, "--permanent")
this.pushCmd(cmd, "")
}
{
var cmd = exec.Command(this.exe, "--add-port="+port)
var cmd = executils.NewTimeoutCmd(10*time.Second, this.exe, "--add-port="+port)
this.pushCmd(cmd, "")
}
}
@@ -112,7 +114,7 @@ func (this *Firewalld) RemovePort(port int, protocol string) error {
if !this.isReady {
return nil
}
var cmd = exec.Command(this.exe, "--remove-port="+types.String(port)+"/"+protocol)
var cmd = executils.NewTimeoutCmd(10*time.Second, this.exe, "--remove-port="+types.String(port)+"/"+protocol)
this.pushCmd(cmd, "")
return nil
}
@@ -121,12 +123,12 @@ func (this *Firewalld) RemovePortRangePermanently(portRange [2]int, protocol str
var port = this.PortRangeString(portRange, protocol)
{
var cmd = exec.Command(this.exe, "--remove-port="+port, "--permanent")
var cmd = executils.NewTimeoutCmd(10*time.Second, this.exe, "--remove-port="+port, "--permanent")
this.pushCmd(cmd, "")
}
{
var cmd = exec.Command(this.exe, "--remove-port="+port)
var cmd = executils.NewTimeoutCmd(10*time.Second, this.exe, "--remove-port="+port)
this.pushCmd(cmd, "")
}
@@ -159,7 +161,7 @@ func (this *Firewalld) RejectSourceIP(ip string, timeoutSeconds int) error {
if timeoutSeconds > 0 {
args = append(args, "--timeout="+types.String(timeoutSeconds)+"s")
}
var cmd = exec.Command(this.exe, args...)
var cmd = executils.NewTimeoutCmd(10*time.Second, this.exe, args...)
this.pushCmd(cmd, ip)
return nil
}
@@ -182,7 +184,7 @@ func (this *Firewalld) DropSourceIP(ip string, timeoutSeconds int, async bool) e
if timeoutSeconds > 0 {
args = append(args, "--timeout="+types.String(timeoutSeconds)+"s")
}
var cmd = exec.Command(this.exe, args...)
var cmd = executils.NewTimeoutCmd(10*time.Second, this.exe, args...)
if async {
this.pushCmd(cmd, ip)
return nil
@@ -209,13 +211,13 @@ func (this *Firewalld) RemoveSourceIP(ip string) error {
}
for _, action := range []string{"reject", "drop"} {
var args = []string{"--remove-rich-rule=rule family='" + family + "' source address='" + ip + "' " + action}
var cmd = exec.Command(this.exe, args...)
var cmd = executils.NewTimeoutCmd(10*time.Second, this.exe, args...)
this.pushCmd(cmd, "")
}
return nil
}
func (this *Firewalld) pushCmd(cmd *exec.Cmd, denyIP string) {
func (this *Firewalld) pushCmd(cmd *executils.Cmd, denyIP string) {
select {
case this.cmdQueue <- &firewalldCmd{cmd: cmd, denyIP: denyIP}:
default:

View File

@@ -5,7 +5,6 @@
package firewalls
import (
"bytes"
"errors"
"github.com/TeaOSLab/EdgeNode/internal/conns"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
@@ -13,6 +12,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"github.com/iwind/TeaGo/types"
"net"
"os/exec"
@@ -432,15 +432,14 @@ func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
// 读取版本号
func (this *NFTablesFirewall) readVersion(nftPath string) string {
var cmd = exec.Command(nftPath, "--version")
var output = &bytes.Buffer{}
cmd.Stdout = output
var cmd = executils.NewTimeoutCmd(10*time.Second, nftPath, "--version")
cmd.WithStdout()
err := cmd.Run()
if err != nil {
return ""
}
var outputString = output.String()
var outputString = cmd.Stdout()
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
if len(versionMatches) <= 1 {
return ""

View File

@@ -1,6 +1,5 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables

View File

@@ -0,0 +1,108 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nftables
import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"github.com/iwind/TeaGo/logs"
"os"
"os/exec"
"runtime"
"time"
)
func init() {
events.On(events.EventReload, func() {
// linux only
if runtime.GOOS != "linux" {
return
}
nodeConfig, err := nodeconfigs.SharedNodeConfig()
if err != nil {
return
}
if nodeConfig == nil || !nodeConfig.AutoInstallNftables {
return
}
if os.Getgid() == 0 { // root user only
_, err := exec.LookPath("nft")
if err == nil {
return
}
goman.New(func() {
err := NewInstaller().Install()
if err != nil {
// 不需要传到API节点
logs.Println("[NFTABLES]install nftables failed: " + err.Error())
}
})
}
})
}
type Installer struct {
}
func NewInstaller() *Installer {
return &Installer{}
}
func (this *Installer) Install() error {
// linux only
if runtime.GOOS != "linux" {
return nil
}
// 检查是否已经存在
_, err := exec.LookPath("nft")
if err == nil {
return nil
}
var cmd *executils.Cmd
// check dnf
dnfExe, err := exec.LookPath("dnf")
if err == nil {
cmd = executils.NewCmd(dnfExe, "-y", "install", "nftables")
}
// check apt
if cmd == nil {
aptExe, err := exec.LookPath("apt")
if err == nil {
cmd = executils.NewCmd(aptExe, "install", "nftables")
}
}
// check yum
if cmd == nil {
yumExe, err := exec.LookPath("yum")
if err == nil {
cmd = executils.NewCmd(yumExe, "-y", "install", "nftables")
}
}
if cmd == nil {
return nil
}
cmd.WithTimeout(10 * time.Minute)
cmd.WithStderr()
err = cmd.Run()
if err != nil {
return errors.New(err.Error() + ": " + cmd.Stderr())
}
remotelogs.Println("NFTABLES", "installed nftables with command '"+cmd.String()+"' successfully")
return nil
}

View File

@@ -1,4 +1,5 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables

View File

@@ -1,4 +1,5 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables_test

View File

@@ -1,11 +1,11 @@
package iplibrary
import (
"bytes"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"os/exec"
"runtime"
"time"
@@ -13,9 +13,9 @@ import (
// FirewalldAction Firewalld动作管理
// 常用命令:
// - 查询列表: firewall-cmd --list-all
// - 添加IPfirewall-cmd --add-rich-rule="rule family='ipv4' source address='192.168.2.32' reject" --timeout=30s
// - 删除IPfirewall-cmd --remove-rich-rule="rule family='ipv4' source address='192.168.2.32' reject" --timeout=30s
// - 查询列表: firewall-cmd --list-all
// - 添加IPfirewall-cmd --add-rich-rule="rule family='ipv4' source address='192.168.2.32' reject" --timeout=30s
// - 删除IPfirewall-cmd --remove-rich-rule="rule family='ipv4' source address='192.168.2.32' reject" --timeout=30s
type FirewalldAction struct {
BaseAction
@@ -144,12 +144,11 @@ func (this *FirewalldAction) runActionSingleIP(action string, listType IPListTyp
// MAC OS直接返回
return nil
}
cmd := exec.Command(path, args...)
stderr := bytes.NewBuffer([]byte{})
cmd.Stderr = stderr
cmd := executils.NewTimeoutCmd(30*time.Second, path, args...)
cmd.WithStderr()
err = cmd.Run()
if err != nil {
return errors.New(err.Error() + ", output: " + string(stderr.Bytes()))
return errors.New(err.Error() + ", output: " + cmd.Stderr())
}
return nil
}

View File

@@ -1,10 +1,10 @@
package iplibrary
import (
"bytes"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"github.com/iwind/TeaGo/types"
"os/exec"
"runtime"
@@ -16,12 +16,12 @@ import (
// IPSetAction IPSet动作
// 相关命令:
// - 利用Firewalld管理set
// - 添加firewall-cmd --permanent --new-ipset=edge_ip_list --type=hash:ip --option="timeout=0"
// - 删除firewall-cmd --permanent --delete-ipset=edge_ip_list
// - 重载firewall-cmd --reload
// - firewalld+ipset: firewall-cmd --permanent --add-rich-rule="rule source ipset='edge_ip_list' reject"
// - 添加firewall-cmd --permanent --new-ipset=edge_ip_list --type=hash:ip --option="timeout=0"
// - 删除firewall-cmd --permanent --delete-ipset=edge_ip_list
// - 重载firewall-cmd --reload
// - firewalld+ipset: firewall-cmd --permanent --add-rich-rule="rule source ipset='edge_ip_list' reject"
// - 利用IPTables管理set
// - 添加iptables -A INPUT -m set --match-set edge_ip_list src -j REJECT
// - 添加iptables -A INPUT -m set --match-set edge_ip_list src -j REJECT
// - 添加Itemipset add edge_ip_list 192.168.2.32 timeout 30
// - 删除Item: ipset del edge_ip_list 192.168.2.32
// - 创建setipset create edge_ip_list hash:ip timeout 0
@@ -30,16 +30,13 @@ import (
type IPSetAction struct {
BaseAction
config *firewallconfigs.FirewallActionIPSetConfig
errorBuf *bytes.Buffer
config *firewallconfigs.FirewallActionIPSetConfig
ipsetNotfound bool
}
func NewIPSetAction() *IPSetAction {
return &IPSetAction{
errorBuf: &bytes.Buffer{},
}
return &IPSetAction{}
}
func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) error {
@@ -68,14 +65,13 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
if len(listName) == 0 {
continue
}
var cmd = exec.Command(path, "create", listName, "hash:ip", "timeout", "0", "maxelem", "1000000")
var stderr = bytes.NewBuffer([]byte{})
cmd.Stderr = stderr
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "create", listName, "hash:ip", "timeout", "0", "maxelem", "1000000")
cmd.WithStderr()
err := cmd.Run()
if err != nil {
var output = stderr.Bytes()
if !bytes.Contains(output, []byte("already exists")) {
return errors.New("create ipset '" + listName + "': " + err.Error() + ", output: " + string(output))
var output = cmd.Stderr()
if !strings.Contains(output, "already exists") {
return errors.New("create ipset '" + listName + "': " + err.Error() + ", output: " + output)
} else {
err = nil
}
@@ -87,14 +83,13 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
if len(listName) == 0 {
continue
}
var cmd = exec.Command(path, "create", listName, "hash:ip", "family", "inet6", "timeout", "0", "maxelem", "1000000")
var stderr = bytes.NewBuffer([]byte{})
cmd.Stderr = stderr
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "create", listName, "hash:ip", "family", "inet6", "timeout", "0", "maxelem", "1000000")
cmd.WithStderr()
err := cmd.Run()
if err != nil {
var output = stderr.Bytes()
if !bytes.Contains(output, []byte("already exists")) {
return errors.New("create ipset '" + listName + "': " + err.Error() + ", output: " + string(output))
var output = cmd.Stderr()
if !strings.Contains(output, "already exists") {
return errors.New("create ipset '" + listName + "': " + err.Error() + ", output: " + output)
} else {
err = nil
}
@@ -114,16 +109,15 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
if len(listName) == 0 {
continue
}
cmd := exec.Command(path, "--permanent", "--new-ipset="+listName, "--type=hash:ip", "--option=timeout=0", "--option=maxelem=1000000")
stderr := bytes.NewBuffer([]byte{})
cmd.Stderr = stderr
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "--permanent", "--new-ipset="+listName, "--type=hash:ip", "--option=timeout=0", "--option=maxelem=1000000")
cmd.WithStderr()
err := cmd.Run()
if err != nil {
output := stderr.Bytes()
if bytes.Contains(output, []byte("NAME_CONFLICT")) {
var output = cmd.Stderr()
if strings.Contains(output, "NAME_CONFLICT") {
err = nil
} else {
return errors.New("firewall-cmd add ipset '" + listName + "': " + err.Error() + ", output: " + string(output))
return errors.New("firewall-cmd add ipset '" + listName + "': " + err.Error() + ", output: " + output)
}
}
}
@@ -133,16 +127,15 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
if len(listName) == 0 {
continue
}
cmd := exec.Command(path, "--permanent", "--new-ipset="+listName, "--type=hash:ip", "--option=family=inet6", "--option=timeout=0", "--option=maxelem=1000000")
stderr := bytes.NewBuffer([]byte{})
cmd.Stderr = stderr
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "--permanent", "--new-ipset="+listName, "--type=hash:ip", "--option=family=inet6", "--option=timeout=0", "--option=maxelem=1000000")
cmd.WithStderr()
err := cmd.Run()
if err != nil {
var output = stderr.Bytes()
if bytes.Contains(output, []byte("NAME_CONFLICT")) {
var output = cmd.Stderr()
if strings.Contains(output, "NAME_CONFLICT") {
err = nil
} else {
return errors.New("firewall-cmd add ipset '" + listName + "': " + err.Error() + ", output: " + string(output))
return errors.New("firewall-cmd add ipset '" + listName + "': " + err.Error() + ", output: " + output)
}
}
}
@@ -152,13 +145,11 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
if len(listName) == 0 {
continue
}
var cmd = exec.Command(path, "--permanent", "--add-rich-rule=rule source ipset='"+listName+"' accept")
var stderr = bytes.NewBuffer([]byte{})
cmd.Stderr = stderr
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "--permanent", "--add-rich-rule=rule source ipset='"+listName+"' accept")
cmd.WithStderr()
err := cmd.Run()
if err != nil {
var output = stderr.Bytes()
return errors.New("firewall-cmd add rich rule '" + listName + "': " + err.Error() + ", output: " + string(output))
return errors.New("firewall-cmd add rich rule '" + listName + "': " + err.Error() + ", output: " + cmd.Stderr())
}
}
@@ -167,25 +158,21 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
if len(listName) == 0 {
continue
}
var cmd = exec.Command(path, "--permanent", "--add-rich-rule=rule source ipset='"+listName+"' reject")
var stderr = bytes.NewBuffer([]byte{})
cmd.Stderr = stderr
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "--permanent", "--add-rich-rule=rule source ipset='"+listName+"' reject")
cmd.WithStderr()
err := cmd.Run()
if err != nil {
var output = stderr.Bytes()
return errors.New("firewall-cmd add rich rule '" + listName + "': " + err.Error() + ", output: " + string(output))
return errors.New("firewall-cmd add rich rule '" + listName + "': " + err.Error() + ", output: " + cmd.Stderr())
}
}
// reload
{
cmd := exec.Command(path, "--reload")
stderr := bytes.NewBuffer([]byte{})
cmd.Stderr = stderr
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "--reload")
cmd.WithStderr()
err := cmd.Run()
if err != nil {
var output = stderr.Bytes()
return errors.New("firewall-cmd reload: " + err.Error() + ", output: " + string(output))
return errors.New("firewall-cmd reload: " + err.Error() + ", output: " + cmd.Stderr())
}
}
}
@@ -204,19 +191,17 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
}
// 检查规则是否存在
var cmd = exec.Command(path, "-C", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "ACCEPT")
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "-C", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "ACCEPT")
err := cmd.Run()
var exists = err == nil
// 添加规则
if !exists {
var cmd = exec.Command(path, "-A", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "ACCEPT")
var stderr = bytes.NewBuffer([]byte{})
cmd.Stderr = stderr
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "-A", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "ACCEPT")
cmd.WithStderr()
err := cmd.Run()
if err != nil {
var output = stderr.Bytes()
return errors.New("iptables add rule: " + err.Error() + ", output: " + string(output))
return errors.New("iptables add rule: " + err.Error() + ", output: " + cmd.Stderr())
}
}
}
@@ -228,18 +213,16 @@ func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) erro
}
// 检查规则是否存在
var cmd = exec.Command(path, "-C", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "REJECT")
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "-C", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "REJECT")
err := cmd.Run()
var exists = err == nil
if !exists {
var cmd = exec.Command(path, "-A", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "REJECT")
var stderr = bytes.NewBuffer([]byte{})
cmd.Stderr = stderr
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "-A", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "REJECT")
cmd.WithStderr()
err := cmd.Run()
if err != nil {
var output = stderr.Bytes()
return errors.New("iptables add rule: " + err.Error() + ", output: " + string(output))
return errors.New("iptables add rule: " + err.Error() + ", output: " + cmd.Stderr())
}
}
}
@@ -361,12 +344,11 @@ func (this *IPSetAction) runActionSingleIP(action string, listType IPListType, i
return nil
}
this.errorBuf.Reset()
var cmd = exec.Command(path, args...)
cmd.Stderr = this.errorBuf
var cmd = executils.NewTimeoutCmd(30*time.Second, path, args...)
cmd.WithStderr()
err = cmd.Run()
if err != nil {
var errString = this.errorBuf.String()
var errString = cmd.Stderr()
if action == "deleteItem" && strings.Contains(errString, "not added") {
return nil
}

View File

@@ -1,20 +1,23 @@
package iplibrary
import (
"bytes"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"os/exec"
"runtime"
"strings"
"time"
)
// IPTablesAction IPTables动作
// 相关命令:
// iptables -A INPUT -s "192.168.2.32" -j ACCEPT
// iptables -A INPUT -s "192.168.2.32" -j REJECT
// iptables -D INPUT ...
// iptables -F INPUT
//
// iptables -A INPUT -s "192.168.2.32" -j ACCEPT
// iptables -A INPUT -s "192.168.2.32" -j REJECT
// iptables -D INPUT ...
// iptables -F INPUT
type IPTablesAction struct {
BaseAction
@@ -110,16 +113,15 @@ func (this *IPTablesAction) runActionSingleIP(action string, listType IPListType
return nil
}
cmd := exec.Command(path, args...)
stderr := bytes.NewBuffer([]byte{})
cmd.Stderr = stderr
var cmd = executils.NewTimeoutCmd(30*time.Second, path, args...)
cmd.WithStderr()
err = cmd.Run()
if err != nil {
output := stderr.Bytes()
if bytes.Contains(output, []byte("No chain/target/match")) {
var output = cmd.Stderr()
if strings.Contains(output, "No chain/target/match") {
err = nil
} else {
return errors.New(err.Error() + ", output: " + string(output))
return errors.New(err.Error() + ", output: " + output)
}
}
return nil

View File

@@ -68,7 +68,7 @@ func (this *ActionManager) UpdateActions(actions []*firewallconfigs.FirewallActi
remotelogs.Error("IPLIBRARY/ACTION_MANAGER", "action "+strconv.FormatInt(newAction.Id, 10)+", type:"+newAction.Type+": "+err.Error())
continue
}
if bytes.Compare(newConfigJSON, oldConfigJSON) != 0 {
if !bytes.Equal(newConfigJSON, oldConfigJSON) {
_ = oldInstance.Close()
// 重新创建

View File

@@ -1,13 +1,13 @@
package iplibrary
import (
"bytes"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"os/exec"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"path/filepath"
"time"
)
// ScriptAction 脚本命令动作
@@ -45,25 +45,24 @@ func (this *ScriptAction) DeleteItem(listType IPListType, item *pb.IPItem) error
func (this *ScriptAction) runAction(action string, listType IPListType, item *pb.IPItem) error {
// TODO 智能支持 .sh 脚本文件
cmd := exec.Command(this.config.Path)
cmd.Env = []string{
var cmd = executils.NewTimeoutCmd(30*time.Second, this.config.Path)
cmd.WithEnv([]string{
"ACTION=" + action,
"TYPE=" + item.Type,
"IP_FROM=" + item.IpFrom,
"IP_TO=" + item.IpTo,
"EXPIRED_AT=" + fmt.Sprintf("%d", item.ExpiredAt),
"LIST_TYPE=" + listType,
}
})
if len(this.config.Cwd) > 0 {
cmd.Dir = this.config.Cwd
cmd.WithDir(this.config.Cwd)
} else {
cmd.Dir = filepath.Dir(this.config.Path)
cmd.WithDir(filepath.Dir(this.config.Path))
}
stderr := bytes.NewBuffer([]byte{})
cmd.Stderr = stderr
cmd.WithStderr()
err := cmd.Run()
if err != nil {
return errors.New(err.Error() + ", output: " + string(stderr.Bytes()))
return errors.New(err.Error() + ", output: " + cmd.Stderr())
}
return nil
}

View File

@@ -55,8 +55,6 @@ func (this *IPListDB) init() error {
}
var path = this.dir + "/ip_list.db"
_ = os.Remove(path + "-shm")
_ = os.Remove(path + "-wal")
db, err := sql.Open("sqlite3", "file:"+path+"?cache=shared&mode=rwc&_journal_mode=WAL&_sync=OFF")
if err != nil {
@@ -71,6 +69,14 @@ func (this *IPListDB) init() error {
this.db = db
// 恢复数据库
var recoverEnv, _ = os.LookupEnv("EdgeRecover")
if len(recoverEnv) > 0 {
for _, indexName := range []string{"ip_list_itemId", "ip_list_expiredAt"} {
_, _ = db.Exec(`REINDEX "` + indexName + `"`)
}
}
// 初始化数据库
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemTableName + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
@@ -164,6 +170,12 @@ func (this *IPListDB) AddItem(item *pb.IPItem) error {
if err != nil {
return err
}
// 如果是删除,则不再创建新记录
if item.IsDeleted {
return nil
}
_, err = this.insertItemStmt.Exec(item.ListId, item.ListType, item.IsGlobal, item.Type, item.Id, item.IpFrom, item.IpTo, item.ExpiredAt, item.EventLevel, item.IsDeleted, item.Version, item.NodeId, item.ServerId)
return err
}

View File

@@ -2,6 +2,7 @@ package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
@@ -18,6 +19,10 @@ var SharedIPListManager = NewIPListManager()
var IPListUpdateNotify = make(chan bool, 1)
func init() {
if teaconst.IsDaemon {
return
}
events.On(events.EventLoaded, func() {
goman.New(func() {
SharedIPListManager.Start()
@@ -26,6 +31,13 @@ func init() {
events.On(events.EventQuit, func() {
SharedIPListManager.Stop()
})
var ticker = time.NewTicker(24 * time.Hour)
goman.New(func() {
for range ticker.C {
SharedIPListManager.DeleteExpiredItems()
}
})
}
// IPListManager IP名单管理
@@ -186,6 +198,12 @@ func (this *IPListManager) FindList(listId int64) *IPList {
return list
}
func (this *IPListManager) DeleteExpiredItems() {
if this.db != nil {
_ = this.db.DeleteExpiredItems()
}
}
func (this *IPListManager) processItems(items []*pb.IPItem, fromRemote bool) {
var changedLists = map[*IPList]zero.Zero{}
for _, item := range items {

View File

@@ -91,8 +91,6 @@ func (this *Task) Init() error {
}
var path = dir + "/metric." + types.String(this.item.Id) + ".db"
_ = os.Remove(path + "-shm")
_ = os.Remove(path + "-wal")
db, err := sql.Open("sqlite3", "file:"+path+"?cache=shared&mode=rwc&_journal_mode=WAL&_sync=OFF")
if err != nil {
@@ -101,6 +99,14 @@ func (this *Task) Init() error {
db.SetMaxOpenConns(1)
this.db = dbs.NewDB(db)
// 恢复数据库
var recoverEnv, _ = os.LookupEnv("EdgeRecover")
if len(recoverEnv) > 0 {
for _, indexName := range []string{"serverId", "hash"} {
_, _ = db.Exec(`REINDEX "` + indexName + `"`)
}
}
if teaconst.EnableDBStat {
this.db.EnableStat(true)
}

View File

@@ -1,7 +1,6 @@
package nodes
import (
"bytes"
"context"
"encoding/json"
"fmt"
@@ -17,7 +16,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/maps"
"net/url"
@@ -347,15 +346,15 @@ func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage)
return nil
}
var cmd = utils.NewCommandExecutor()
shortName := teaconst.SystemdServiceName
cmd.Add(systemctl, "is-enabled", shortName)
output, err := cmd.Run()
var shortName = teaconst.SystemdServiceName
var cmd = executils.NewTimeoutCmd(10*time.Second, systemctl, "is-enabled", shortName)
cmd.WithStdout()
err = cmd.Run()
if err != nil {
this.replyFail(message.RequestId, "'systemctl' command error: "+err.Error())
return nil
}
if output == "enabled" {
if cmd.Stdout() == "enabled" {
this.replyOk(message.RequestId, "ok")
} else {
this.replyFail(message.RequestId, "not installed")
@@ -385,16 +384,15 @@ func (this *APIStream) handleCheckLocalFirewall(message *pb.NodeStreamMessage) e
return nil
}
var cmd = exec.Command(nft, "--version")
var output = &bytes.Buffer{}
cmd.Stdout = output
var cmd = executils.NewTimeoutCmd(10*time.Second, nft, "--version")
cmd.WithStdout()
err = cmd.Run()
if err != nil {
this.replyFail(message.RequestId, "get version failed: "+err.Error())
return nil
}
var outputString = output.String()
var outputString = cmd.Stdout()
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
if len(versionMatches) <= 1 {
this.replyFail(message.RequestId, "can not get nft version")

View File

@@ -17,7 +17,6 @@ import (
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"
)
@@ -26,18 +25,17 @@ import (
type ClientConn struct {
BaseClientConn
once sync.Once
isTLS bool
hasDeadline bool
hasRead bool
isLO bool // 是否为环路
isLO bool // 是否为环路
isInAllowList bool
hasResetSYNFlood bool
}
func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool) net.Conn {
func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList bool) net.Conn {
// 是否为环路
var remoteAddr = rawConn.RemoteAddr().String()
var isLO = strings.HasPrefix(remoteAddr, "127.0.0.1:") || strings.HasPrefix(remoteAddr, "[::1]:")
@@ -46,6 +44,7 @@ func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool) net.Conn {
BaseClientConn: BaseClientConn{rawConn: rawConn},
isTLS: isTLS,
isLO: isLO,
isInAllowList: isInAllowList,
}
if quickClose {
@@ -89,20 +88,24 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
}
}
// 检测是否为握手错误
var isHandshakeError = err != nil && os.IsTimeout(err) && !this.hasRead
if isHandshakeError {
_ = this.SetLinger(0)
}
// SYN Flood检测
if this.serverId == 0 || !this.hasResetSYNFlood {
var synFloodConfig = sharedNodeConfig.SYNFloodConfig()
if synFloodConfig != nil && synFloodConfig.IsOn {
if isHandshakeError {
this.increaseSYNFlood(synFloodConfig)
} else if err == nil && !this.hasResetSYNFlood {
this.hasResetSYNFlood = true
this.resetSYNFlood()
// 忽略白名单和局域网
if !this.isInAllowList && !utils.IsLocalIP(this.RawIP()) {
// SYN Flood检测
if this.serverId == 0 || !this.hasResetSYNFlood {
var synFloodConfig = sharedNodeConfig.SYNFloodConfig()
if synFloodConfig != nil && synFloodConfig.IsOn {
if isHandshakeError {
this.increaseSYNFlood(synFloodConfig)
} else if err == nil && !this.hasResetSYNFlood {
this.hasResetSYNFlood = true
this.resetSYNFlood()
}
}
}
}

View File

@@ -17,6 +17,8 @@ type BaseClientConn struct {
hasLimit bool
isClosed bool
rawIP string
}
func (this *BaseClientConn) IsClosed() bool {
@@ -86,7 +88,12 @@ func (this *BaseClientConn) UserId() int64 {
// RawIP 原本IP
func (this *BaseClientConn) RawIP() string {
if len(this.rawIP) > 0 {
return this.rawIP
}
ip, _, _ := net.SplitHostPort(this.rawConn.RemoteAddr().String())
this.rawIP = ip
return ip
}

View File

@@ -41,8 +41,10 @@ func (this *ClientListener) Accept() (net.Conn, error) {
// 是否在WAF名单中
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
var isInAllowList = false
if err == nil {
canGoNext, _ := iplibrary.AllowIP(ip, 0)
canGoNext, inAllowList := iplibrary.AllowIP(ip, 0)
isInAllowList = inAllowList
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
expiresAt, ok := waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
if ok {
@@ -76,7 +78,7 @@ func (this *ClientListener) Accept() (net.Conn, error) {
}
}
return NewClientConn(conn, this.isTLS, this.quickClose), nil
return NewClientConn(conn, this.isTLS, this.quickClose, isInAllowList), nil
}
func (this *ClientListener) Close() error {

View File

@@ -7,6 +7,8 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"strings"
"time"
)
@@ -109,6 +111,19 @@ Loop:
return err
}
// 是否请求内容过大
statusCode, ok := status.FromError(err)
if ok && statusCode.Code() == codes.ResourceExhausted {
// 去除Body
for _, accessLog := range accessLogs {
accessLog.RequestBody = nil
}
// 重新提交
_, err = this.rpcClient.HTTPAccessLogRPC.CreateHTTPAccessLogs(this.rpcClient.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: accessLogs})
return err
}
return err
}

View File

@@ -68,6 +68,7 @@ type HTTPRequest struct {
filePath string // 请求的文件名仅在读取Root目录下的内容时不为空
origin *serverconfigs.OriginConfig // 源站
originAddr string // 源站实际地址
originStatus int32 // 源站响应代码
errors []string // 错误信息
rewriteRule *serverconfigs.HTTPRewriteRule // 匹配到的重写规则
rewriteReplace string // 重写规则的目标
@@ -228,6 +229,14 @@ func (this *HTTPRequest) Do() {
}
}
// 防盗链
if !this.isSubRequest && this.web.Referers != nil && this.web.Referers.IsOn {
if this.doCheckReferers() {
this.doEnd()
return
}
}
// 访问控制
if !this.isSubRequest && this.web.Auth != nil && this.web.Auth.IsOn {
if this.doAuth() {
@@ -512,6 +521,11 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
this.web.Auth = web.Auth
}
// referers
if web.Referers != nil && (web.Referers.IsPrior || isTop) {
this.web.Referers = web.Referers
}
// request limit
if web.RequestLimit != nil && (web.RequestLimit.IsPrior || isTop) {
this.web.RequestLimit = web.RequestLimit
@@ -1443,8 +1457,6 @@ func (this *HTTPRequest) Close() {
_ = conn.Close()
return
}
return
}
// Allow 放行
@@ -1599,9 +1611,7 @@ func (this *HTTPRequest) fixRequestHeader(header http.Header) {
}
// 处理自定义Response Header
func (this *HTTPRequest) processResponseHeaders(statusCode int) {
var responseHeader = this.writer.Header()
func (this *HTTPRequest) processResponseHeaders(responseHeader http.Header, statusCode int) {
// 删除/添加/替换Header
// TODO 实现AddTrailers
if this.web.ResponseHeaderPolicy != nil && this.web.ResponseHeaderPolicy.IsOn {

View File

@@ -155,7 +155,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
for _, subKey := range subKeys {
err := storage.Delete(subKey)
if err != nil {
remotelogs.Error("HTTP_REQUEST_CACHE", "purge failed: "+err.Error())
remotelogs.ErrorServer("HTTP_REQUEST_CACHE", "purge failed: "+err.Error())
}
}
@@ -260,7 +260,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
}
if !this.canIgnore(err) {
remotelogs.Warn("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: open cache failed: "+err.Error())
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: open cache failed: "+err.Error())
}
return
}
@@ -320,7 +320,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
})
if err != nil {
if !this.canIgnore(err) {
remotelogs.Warn("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: read header failed: "+err.Error())
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: read header failed: "+err.Error())
}
return
}
@@ -372,7 +372,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
// 支持 If-None-Match
if !this.isLnRequest && !isPartialCache && len(eTag) > 0 && this.requestHeader("If-None-Match") == eTag {
// 自定义Header
this.processResponseHeaders(http.StatusNotModified)
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.addExpiresHeader(reader.ExpiresAt())
this.writer.WriteHeader(http.StatusNotModified)
this.isCached = true
@@ -384,7 +384,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
// 支持 If-Modified-Since
if !this.isLnRequest && !isPartialCache && len(modifiedTime) > 0 && this.requestHeader("If-Modified-Since") == modifiedTime {
// 自定义Header
this.processResponseHeaders(http.StatusNotModified)
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.addExpiresHeader(reader.ExpiresAt())
this.writer.WriteHeader(http.StatusNotModified)
this.isCached = true
@@ -393,7 +393,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
return true
}
this.processResponseHeaders(reader.Status())
this.processResponseHeaders(this.writer.Header(), reader.Status())
this.addExpiresHeader(reader.ExpiresAt())
// 返回上级节点过期时间
@@ -422,7 +422,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
if supportRange {
if len(rangeHeader) > 0 {
if fileSize == 0 {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -430,7 +430,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
if len(ranges) == 0 {
ranges, ok = httpRequestParseRangeHeader(rangeHeader)
if !ok {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -439,7 +439,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
for k, r := range ranges {
r2, ok := r.Convert(fileSize)
if !ok {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -466,12 +466,12 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
this.varMapping["cache.status"] = "MISS"
if err == caches.ErrInvalidRange {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
if !this.canIgnore(err) {
remotelogs.Warn("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: "+err.Error())
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: "+err.Error())
}
return
}
@@ -517,7 +517,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
})
if err != nil {
if !this.canIgnore(err) {
remotelogs.Warn("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: "+err.Error())
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: "+err.Error())
}
return true
}
@@ -554,7 +554,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
this.varMapping["cache.status"] = "MISS"
if !this.canIgnore(err) {
remotelogs.Warn("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: read body failed: "+err.Error())
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: read body failed: "+err.Error())
}
return
}

View File

@@ -25,10 +25,10 @@ const httpStatusPageTemplate = `<!DOCTYPE html>
</html>`
func (this *HTTPRequest) write404() {
this.writeCode(http.StatusNotFound)
this.writeCode(http.StatusNotFound, "", "")
}
func (this *HTTPRequest) writeCode(statusCode int) {
func (this *HTTPRequest) writeCode(statusCode int, enMessage string, zhMessage string) {
if this.doPage(statusCode) {
return
}
@@ -42,12 +42,22 @@ func (this *HTTPRequest) writeCode(statusCode int) {
case "requestId":
return this.requestId
case "message":
return "" // 空
var acceptLanguages = this.RawReq.Header.Get("Accept-Language")
if len(acceptLanguages) > 0 {
var index = strings.Index(acceptLanguages, ",")
if index > 0 {
var firstLanguage = acceptLanguages[:index]
if firstLanguage == "zh-CN" {
return zhMessage
}
}
}
return enMessage
}
return "${" + varName + "}"
})
this.processResponseHeaders(statusCode)
this.processResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
_, _ = this.writer.Write([]byte(pageContent))
@@ -100,7 +110,7 @@ func (this *HTTPRequest) write50x(err error, statusCode int, enMessage string, z
return "${" + varName + "}"
})
this.processResponseHeaders(statusCode)
this.processResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
_, _ = this.writer.Write([]byte(pageContent))

View File

@@ -187,7 +187,7 @@ func (this *HTTPRequest) doFastcgi() (shouldStop bool) {
// 响应Header
this.writer.AddHeaders(resp.Header)
this.processResponseHeaders(resp.StatusCode)
this.processResponseHeaders(this.writer.Header(), resp.StatusCode)
// 准备
this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true)

View File

@@ -34,10 +34,10 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
}
if u.Status <= 0 {
this.processResponseHeaders(http.StatusTemporaryRedirect)
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
} else {
this.processResponseHeaders(u.Status)
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
}
return true
@@ -81,10 +81,10 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
}
if u.Status <= 0 {
this.processResponseHeaders(http.StatusTemporaryRedirect)
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
} else {
this.processResponseHeaders(u.Status)
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
}
return true
@@ -104,10 +104,10 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
}
if u.Status <= 0 {
this.processResponseHeaders(http.StatusTemporaryRedirect)
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
} else {
this.processResponseHeaders(u.Status)
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
}
return true

View File

@@ -18,7 +18,7 @@ func (this *HTTPRequest) doRequestLimit() (shouldStop bool) {
// TODO 处理分片提交的内容
if this.web.RequestLimit.MaxBodyBytes() > 0 &&
this.RawReq.ContentLength > this.web.RequestLimit.MaxBodyBytes() {
this.writeCode(http.StatusRequestEntityTooLarge)
this.writeCode(http.StatusRequestEntityTooLarge, "", "")
return true
}
@@ -29,7 +29,7 @@ func (this *HTTPRequest) doRequestLimit() (shouldStop bool) {
clientConn, ok := requestConn.(ClientConnInterface)
if ok && !clientConn.IsBound() {
if !clientConn.Bind(this.ReqServer.Id, this.requestRemoteAddr(true), this.web.RequestLimit.MaxConns, this.web.RequestLimit.MaxConnsPerIP) {
this.writeCode(http.StatusTooManyRequests)
this.writeCode(http.StatusTooManyRequests, "", "")
this.Close()
return true
}

View File

@@ -146,6 +146,7 @@ func (this *HTTPRequest) log() {
if this.origin != nil {
accessLog.OriginId = this.origin.Id
accessLog.OriginAddress = this.originAddr
accessLog.OriginStatus = this.originStatus
}
// 请求Body

View File

@@ -32,7 +32,7 @@ func (this *HTTPRequest) doMismatch() {
}
// 根据配置进行相应的处理
if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {
if sharedNodeConfig.GlobalServerConfig != nil && sharedNodeConfig.GlobalServerConfig.HTTPAll.MatchDomainStrictly {
// 检查cc
// TODO 可以在管理端配置是否开启以及最多尝试次数
if len(remoteIP) > 0 {
@@ -46,7 +46,7 @@ func (this *HTTPRequest) doMismatch() {
}
// 处理当前连接
var httpAllConfig = sharedNodeConfig.GlobalConfig.HTTPAll
var httpAllConfig = sharedNodeConfig.GlobalServerConfig.HTTPAll
var mismatchAction = httpAllConfig.DomainMismatchAction
if mismatchAction != nil && mismatchAction.Code == "page" {
if mismatchAction.Options != nil {

View File

@@ -60,11 +60,11 @@ func (this *HTTPRequest) doPage(status int) (shouldStop bool) {
// 修改状态码
if page.NewStatus > 0 {
// 自定义响应Headers
this.processResponseHeaders(page.NewStatus)
this.processResponseHeaders(this.writer.Header(), page.NewStatus)
this.writer.Prepare(nil, stat.Size(), page.NewStatus, true)
this.writer.WriteHeader(page.NewStatus)
} else {
this.processResponseHeaders(status)
this.processResponseHeaders(this.writer.Header(), status)
this.writer.Prepare(nil, stat.Size(), status, true)
this.writer.WriteHeader(status)
}
@@ -99,11 +99,11 @@ func (this *HTTPRequest) doPage(status int) (shouldStop bool) {
// 修改状态码
if page.NewStatus > 0 {
// 自定义响应Headers
this.processResponseHeaders(page.NewStatus)
this.processResponseHeaders(this.writer.Header(), page.NewStatus)
this.writer.Prepare(nil, int64(len(content)), page.NewStatus, true)
this.writer.WriteHeader(page.NewStatus)
} else {
this.processResponseHeaders(status)
this.processResponseHeaders(this.writer.Header(), status)
this.writer.Prepare(nil, int64(len(content)), status, true)
this.writer.WriteHeader(status)
}

View File

@@ -12,7 +12,7 @@ func (this *HTTPRequest) doPlanExpires() {
this.tags = append(this.tags, "plan")
var statusCode = http.StatusNotFound
this.processResponseHeaders(statusCode)
this.processResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
_, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultPlanExpireNoticePageBody))

View File

@@ -42,7 +42,7 @@ func (this *HTTPRequest) doRedirectToHTTPS(redirectToHTTPSConfig *serverconfigs.
}
newURL := "https://" + host + this.RawReq.RequestURI
this.processResponseHeaders(statusCode)
this.processResponseHeaders(this.writer.Header(), statusCode)
http.Redirect(this.writer, this.RawReq, newURL, statusCode)
return true

View File

@@ -0,0 +1,45 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"net/http"
"net/url"
)
func (this *HTTPRequest) doCheckReferers() (shouldStop bool) {
if this.web.Referers == nil {
return
}
var refererURL = this.RawReq.Header.Get("Referer")
if len(refererURL) == 0 {
if this.web.Referers.MatchDomain(this.ReqHost, "") {
return
}
this.tags = append(this.tags, "refererCheck")
this.writeCode(http.StatusForbidden, "The referer has been blocked.", "当前访问已被防盗链系统拦截。")
return true
}
u, err := url.Parse(refererURL)
if err != nil {
if this.web.Referers.MatchDomain(this.ReqHost, "") {
return
}
this.tags = append(this.tags, "refererCheck")
this.writeCode(http.StatusForbidden, "The referer has been blocked.", "当前访问已被防盗链系统拦截。")
return true
}
if !this.web.Referers.MatchDomain(this.ReqHost, u.Host) {
this.tags = append(this.tags, "refererCheck")
this.writeCode(http.StatusForbidden, "The referer has been blocked.", "当前访问已被防盗链系统拦截。")
return true
}
return
}

View File

@@ -116,7 +116,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
// 处理Scheme
if origin.Addr == nil {
err := errors.New(this.URL() + ": Origin '" + strconv.FormatInt(origin.Id, 10) + "' does not has a address")
remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", err.Error())
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", err.Error())
this.write50x(err, http.StatusBadGateway, "Origin site did not has a valid address", "源站尚未配置地址", true)
return
}
@@ -168,7 +168,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
var originHostIndex = strings.Index(originAddr, ":")
if originHostIndex < 0 {
var originErr = errors.New(this.URL() + ": Invalid origin address '" + originAddr + "', lacking port")
remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", originErr.Error())
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", originErr.Error())
this.write50x(originErr, http.StatusBadGateway, "No port in origin site address", "源站地址中没有配置端口", true)
return
}
@@ -240,14 +240,14 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
// 判断是否为Websocket请求
if this.RawReq.Header.Get("Upgrade") == "websocket" {
this.doWebsocket(requestHost)
shouldRetry = this.doWebsocket(requestHost, isLastRetry)
return
}
// 获取请求客户端
client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol, this.reverseProxy.FollowRedirects)
if err != nil {
remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Create client failed: "+err.Error())
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Create client failed: "+err.Error())
this.write50x(err, http.StatusBadGateway, "Failed to create origin site client", "构造源站客户端失败", true)
return
}
@@ -262,7 +262,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
this.reverseProxy.ResetScheduling()
})
this.write50x(err, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.RawReq.URL.String()+": Request origin server failed: "+err.Error())
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.RawReq.URL.String()+": Request origin server failed: "+err.Error())
} else if httpErr.Err != context.Canceled {
SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() {
this.reverseProxy.ResetScheduling()
@@ -278,7 +278,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
}
if httpErr.Err != io.EOF {
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Request origin server failed: "+err.Error())
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Request origin server failed: "+err.Error())
}
return
@@ -292,7 +292,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
this.write50x(err, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
}
if httpErr.Err != io.EOF {
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Request origin server failed: "+err.Error())
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Request origin server failed: "+err.Error())
}
} else {
// 是否为客户端方面的错误
@@ -320,6 +320,11 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
}
return
}
// 记录相关数据
this.originStatus = int32(resp.StatusCode)
// 恢复源站状态
if !origin.IsOk {
SharedOriginStateManager.Success(origin, func() {
this.reverseProxy.ResetScheduling()
@@ -331,7 +336,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
if this.doWAFResponse(resp) {
err = resp.Body.Close()
if err != nil {
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Closing Error (WAF): "+err.Error())
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Closing Error (WAF): "+err.Error())
}
return
}
@@ -341,7 +346,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
if len(this.web.Pages) > 0 && this.doPage(resp.StatusCode) {
err = resp.Body.Close()
if err != nil {
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Closing error (Page): "+err.Error())
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Closing error (Page): "+err.Error())
}
return
}
@@ -392,7 +397,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
// 响应Header
this.writer.AddHeaders(resp.Header)
this.processResponseHeaders(resp.StatusCode)
this.processResponseHeaders(this.writer.Header(), resp.StatusCode)
// 是否需要刷新
var shouldAutoFlush = this.reverseProxy.AutoFlush || this.RawReq.Header.Get("Accept") == "text/event-stream"
@@ -443,13 +448,13 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
var closeErr = resp.Body.Close()
if closeErr != nil {
if !this.canIgnore(closeErr) {
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Closing error: "+closeErr.Error())
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Closing error: "+closeErr.Error())
}
}
if err != nil && err != io.EOF {
if !this.canIgnore(err) {
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Writing error: "+err.Error())
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Writing error: "+err.Error())
this.addError(err)
}
}

View File

@@ -30,10 +30,10 @@ func (this *HTTPRequest) doRewrite() (shouldShop bool) {
// 跳转
if this.rewriteRule.Mode == serverconfigs.HTTPRewriteModeRedirect {
if this.rewriteRule.RedirectStatus > 0 {
this.processResponseHeaders(this.rewriteRule.RedirectStatus)
this.processResponseHeaders(this.writer.Header(), this.rewriteRule.RedirectStatus)
http.Redirect(this.writer, this.RawReq, this.rewriteReplace, this.rewriteRule.RedirectStatus)
} else {
this.processResponseHeaders(http.StatusTemporaryRedirect)
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.writer, this.RawReq, this.rewriteReplace, http.StatusTemporaryRedirect)
}
return true

View File

@@ -217,7 +217,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
// 支持 If-None-Match
if this.requestHeader("If-None-Match") == eTag {
// 自定义Header
this.processResponseHeaders(http.StatusNotModified)
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.writer.WriteHeader(http.StatusNotModified)
return true
}
@@ -225,7 +225,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
// 支持 If-Modified-Since
if this.requestHeader("If-Modified-Since") == modifiedTime {
// 自定义Header
this.processResponseHeaders(http.StatusNotModified)
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.writer.WriteHeader(http.StatusNotModified)
return true
}
@@ -253,14 +253,14 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
var contentRange = this.RawReq.Header.Get("Range")
if len(contentRange) > 0 {
if fileSize == 0 {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
set, ok := httpRequestParseRangeHeader(contentRange)
if !ok {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -269,7 +269,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
for k, r := range ranges {
r2, ok := r.Convert(fileSize)
if !ok {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -290,7 +290,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
}
// 自定义Header
this.processResponseHeaders(http.StatusOK)
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
// 在Range请求中不能缓存
if len(ranges) > 0 {
@@ -325,7 +325,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
return true
}
if !ok {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -377,7 +377,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
return true
}
if !ok {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}

View File

@@ -28,10 +28,10 @@ func (this *HTTPRequest) doShutdown() {
if len(shutdown.URL) == 0 {
// 自定义响应Headers
if shutdown.Status > 0 {
this.processResponseHeaders(shutdown.Status)
this.processResponseHeaders(this.writer.Header(), shutdown.Status)
this.writer.WriteHeader(shutdown.Status)
} else {
this.processResponseHeaders(http.StatusOK)
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
this.writer.WriteHeader(http.StatusOK)
}
_, err := this.writer.WriteString("The site have been shutdown.")
@@ -59,10 +59,10 @@ func (this *HTTPRequest) doShutdown() {
// 自定义响应Headers
if shutdown.Status > 0 {
this.processResponseHeaders(shutdown.Status)
this.processResponseHeaders(this.writer.Header(), shutdown.Status)
this.writer.WriteHeader(shutdown.Status)
} else {
this.processResponseHeaders(http.StatusOK)
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
this.writer.WriteHeader(http.StatusOK)
}
buf := utils.BytePool1k.Get()
@@ -85,10 +85,10 @@ func (this *HTTPRequest) doShutdown() {
} else if shutdown.BodyType == shared.BodyTypeHTML {
// 自定义响应Headers
if shutdown.Status > 0 {
this.processResponseHeaders(shutdown.Status)
this.processResponseHeaders(this.writer.Header(), shutdown.Status)
this.writer.WriteHeader(shutdown.Status)
} else {
this.processResponseHeaders(http.StatusOK)
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
this.writer.WriteHeader(http.StatusOK)
}

View File

@@ -13,7 +13,7 @@ func (this *HTTPRequest) doTrafficLimit() {
this.tags = append(this.tags, "bandwidth")
var statusCode = 509
this.processResponseHeaders(statusCode)
this.processResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
if len(config.NoticePageBody) != 0 {

View File

@@ -44,9 +44,9 @@ func (this *HTTPRequest) doURL(method string, url string, host string, statusCod
// Header
if statusCode <= 0 {
this.processResponseHeaders(resp.StatusCode)
this.processResponseHeaders(this.writer.Header(), resp.StatusCode)
} else {
this.processResponseHeaders(statusCode)
this.processResponseHeaders(this.writer.Header(), statusCode)
}
if supportVariables {

View File

@@ -173,7 +173,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
if countryId > 0 && lists.ContainsInt64(regionConfig.DenyCountryIds, countryId) {
this.firewallPolicyId = firewallPolicy.Id
this.writeCode(http.StatusForbidden)
this.writeCode(http.StatusForbidden, "", "")
this.writer.Flush()
this.writer.Close()
@@ -192,7 +192,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
if provinceId > 0 && lists.ContainsInt64(regionConfig.DenyProvinceIds, provinceId) {
this.firewallPolicyId = firewallPolicy.Id
this.writeCode(http.StatusForbidden)
this.writeCode(http.StatusForbidden, "", "")
this.writer.Flush()
this.writer.Close()

View File

@@ -1,6 +1,7 @@
package nodes
import (
"bufio"
"errors"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"io"
@@ -9,7 +10,7 @@ import (
)
// 处理Websocket请求
func (this *HTTPRequest) doWebsocket(requestHost string) {
func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shouldRetry bool) {
if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn {
this.writer.WriteHeader(http.StatusForbidden)
this.addError(errors.New("websocket have not been enabled yet"))
@@ -43,13 +44,16 @@ func (this *HTTPRequest) doWebsocket(requestHost string) {
// TODO 增加N次错误重试重试的时候需要尝试不同的源站
originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost)
if err != nil {
this.write50x(err, http.StatusBadGateway, "Failed to connect origin site", "源站连接失败", false)
if isLastRetry {
this.write50x(err, http.StatusBadGateway, "Failed to connect origin site", "源站连接失败", false)
}
// 增加失败次数
SharedOriginStateManager.Fail(this.origin, requestHost, this.reverseProxy, func() {
this.reverseProxy.ResetScheduling()
})
shouldRetry = true
return
}
@@ -79,6 +83,33 @@ func (this *HTTPRequest) doWebsocket(requestHost string) {
}()
go func() {
// 读取第一个响应
resp, err := http.ReadResponse(bufio.NewReader(originConn), this.RawReq)
if err != nil {
_ = clientConn.Close()
_ = originConn.Close()
return
}
this.processResponseHeaders(resp.Header, resp.StatusCode)
// 将响应写回客户端
err = resp.Write(clientConn)
if err != nil {
if resp.Body != nil {
_ = resp.Body.Close()
}
_ = clientConn.Close()
_ = originConn.Close()
return
}
if resp.Body != nil {
_ = resp.Body.Close()
}
// 复制剩余的数据
var buf = utils.BytePool4k.Get()
defer utils.BytePool4k.Put(buf)
for {
@@ -98,4 +129,6 @@ func (this *HTTPRequest) doWebsocket(requestHost string) {
_ = originConn.Close()
}()
_, _ = io.Copy(originConn, clientConn)
return
}

View File

@@ -7,14 +7,16 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"net"
"strings"
"sync"
)
type Listener struct {
group *serverconfigs.ServerAddressGroup
isListening bool
listener ListenerInterface // 监听器
group *serverconfigs.ServerAddressGroup
listener ListenerInterface // 监听器
locker sync.RWMutex
}
@@ -118,18 +120,64 @@ func (this *Listener) listenTCP() error {
}
func (this *Listener) listenUDP() error {
listener, err := this.createUDPListener()
var addr = this.group.Addr()
var ipv4PacketListener *ipv4.PacketConn
var ipv6PacketListener *ipv6.PacketConn
host, _, err := net.SplitHostPort(addr)
if err != nil {
return err
}
if len(host) == 0 {
// ipv4
ipv4Listener, err := this.createUDPIPv4Listener()
if err == nil {
ipv4PacketListener = ipv4.NewPacketConn(ipv4Listener)
} else {
remotelogs.Error("LISTENER", "create udp ipv4 listener '"+addr+"': "+err.Error())
}
// ipv6
ipv6Listener, err := this.createUDPIPv6Listener()
if err == nil {
ipv6PacketListener = ipv6.NewPacketConn(ipv6Listener)
} else {
remotelogs.Error("LISTENER", "create udp ipv6 listener '"+addr+"': "+err.Error())
}
} else if strings.Contains(host, ":") { // ipv6
ipv6Listener, err := this.createUDPIPv6Listener()
if err == nil {
ipv6PacketListener = ipv6.NewPacketConn(ipv6Listener)
} else {
remotelogs.Error("LISTENER", "create udp ipv6 listener '"+addr+"': "+err.Error())
}
} else { // ipv4
ipv4Listener, err := this.createUDPIPv4Listener()
if err == nil {
ipv4PacketListener = ipv4.NewPacketConn(ipv4Listener)
} else {
remotelogs.Error("LISTENER", "create udp ipv4 listener '"+addr+"': "+err.Error())
}
}
events.OnKey(events.EventQuit, this, func() {
remotelogs.Println("LISTENER", "quit "+this.group.FullAddr())
_ = listener.Close()
if ipv4PacketListener != nil {
_ = ipv4PacketListener.Close()
}
if ipv6PacketListener != nil {
_ = ipv6PacketListener.Close()
}
})
this.listener = &UDPListener{
BaseListener: BaseListener{Group: this.group},
Listener: listener,
IPv4Listener: ipv4PacketListener,
IPv6Listener: ipv6PacketListener,
}
goman.New(func() {
@@ -168,12 +216,20 @@ func (this *Listener) createTCPListener() (net.Listener, error) {
return listenConfig.Listen(context.Background(), "tcp", this.group.Addr())
}
// 创建UDP监听器
func (this *Listener) createUDPListener() (*net.UDPConn, error) {
// TODO 将来支持udp4/udp6
// 创建UDP IPv4监听器
func (this *Listener) createUDPIPv4Listener() (*net.UDPConn, error) {
addr, err := net.ResolveUDPAddr("udp", this.group.Addr())
if err != nil {
return nil, err
}
return net.ListenUDP("udp", addr)
return net.ListenUDP("udp4", addr)
}
// 创建UDP监听器
func (this *Listener) createUDPIPv6Listener() (*net.UDPConn, error) {
addr, err := net.ResolveUDPAddr("udp", this.group.Addr())
if err != nil {
return nil, err
}
return net.ListenUDP("udp6", addr)
}

View File

@@ -3,11 +3,12 @@ package nodes
import (
"crypto/tls"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
"net"
)
type BaseListener struct {
@@ -75,7 +76,7 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
// 如果域名为空,则取第一个
// 通常域名为空是因为是直接通过IP访问的
if len(domain) == 0 {
if group.IsHTTPS() && sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {
if group.IsHTTPS() && sharedNodeConfig.GlobalServerConfig != nil && sharedNodeConfig.GlobalServerConfig.HTTPAll.MatchDomainStrictly {
return nil, nil, errors.New("no tls server name matched")
}
@@ -131,19 +132,19 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf
return
}
var matchDomainStrictly = sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly
var matchDomainStrictly = sharedNodeConfig.GlobalServerConfig != nil && sharedNodeConfig.GlobalServerConfig.HTTPAll.MatchDomainStrictly
if sharedNodeConfig.GlobalConfig != nil &&
len(sharedNodeConfig.GlobalConfig.HTTPAll.DefaultDomain) > 0 &&
(!matchDomainStrictly || lists.ContainsString(sharedNodeConfig.GlobalConfig.HTTPAll.AllowMismatchDomains, name)) {
defaultDomain := sharedNodeConfig.GlobalConfig.HTTPAll.DefaultDomain
if sharedNodeConfig.GlobalServerConfig != nil &&
len(sharedNodeConfig.GlobalServerConfig.HTTPAll.DefaultDomain) > 0 &&
(!matchDomainStrictly || configutils.MatchDomains(sharedNodeConfig.GlobalServerConfig.HTTPAll.AllowMismatchDomains, name) || (sharedNodeConfig.GlobalServerConfig.HTTPAll.AllowNodeIP && net.ParseIP(name) != nil)) {
var defaultDomain = sharedNodeConfig.GlobalServerConfig.HTTPAll.DefaultDomain
serverConfig, serverName = this.findNamedServerMatched(defaultDomain)
if serverConfig != nil {
return
}
}
if matchDomainStrictly && !lists.ContainsString(sharedNodeConfig.GlobalConfig.HTTPAll.AllowMismatchDomains, name) {
if matchDomainStrictly && !configutils.MatchDomains(sharedNodeConfig.GlobalServerConfig.HTTPAll.AllowMismatchDomains, name) && (!sharedNodeConfig.GlobalServerConfig.HTTPAll.AllowNodeIP || net.ParseIP(name) == nil) {
return
}
@@ -170,7 +171,7 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser
}
// 是否严格匹配域名
matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly
var matchDomainStrictly = sharedNodeConfig.GlobalServerConfig != nil && sharedNodeConfig.GlobalServerConfig.HTTPAll.MatchDomainStrictly
// 如果只有一个server则默认为这个
var currentServers = group.Servers()
@@ -181,23 +182,3 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser
return nil, name
}
// 使用CNAME来查找服务
// TODO 防止单IP随机生成域名攻击
func (this *BaseListener) findServerWithCNAME(domain string) *serverconfigs.ServerConfig {
if !sharedNodeConfig.SupportCNAME {
return nil
}
var realName = sharedCNAMEManager.Lookup(domain)
if len(realName) == 0 {
return nil
}
group := this.Group
if group == nil {
return nil
}
return group.MatchServerCNAME(realName)
}

View File

@@ -1,7 +1,6 @@
package nodes
import (
"bytes"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
@@ -9,6 +8,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
@@ -213,15 +213,14 @@ func (this *ListenerManager) findProcessNameWithPort(isUdp bool, port string) st
option = "u"
}
var cmd = exec.Command(path, "-"+option+"lpn", "sport = :"+port)
var output = &bytes.Buffer{}
cmd.Stdout = output
var cmd = executils.NewTimeoutCmd(10*time.Second, path, "-"+option+"lpn", "sport = :"+port)
cmd.WithStdout()
err = cmd.Run()
if err != nil {
return ""
}
var matches = regexp.MustCompile(`(?U)\(\("(.+)",pid=\d+,fd=\d+\)\)`).FindStringSubmatch(output.String())
var matches = regexp.MustCompile(`(?U)\(\("(.+)",pid=\d+,fd=\d+\)\)`).FindStringSubmatch(cmd.Stdout())
if len(matches) > 1 {
return matches[1]
}

View File

@@ -9,6 +9,8 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/types"
"github.com/pires/go-proxyproto"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"net"
"strings"
"sync"
@@ -19,10 +21,57 @@ const (
UDPConnLifeSeconds = 30
)
type UDPPacketListener interface {
ReadFrom(b []byte) (n int, cm any, src net.Addr, err error)
WriteTo(b []byte, cm any, dst net.Addr) (n int, err error)
LocalAddr() net.Addr
}
type UDPIPv4Listener struct {
rawListener *ipv4.PacketConn
}
func NewUDPIPv4Listener(rawListener *ipv4.PacketConn) *UDPIPv4Listener {
return &UDPIPv4Listener{rawListener: rawListener}
}
func (this *UDPIPv4Listener) ReadFrom(b []byte) (n int, cm any, src net.Addr, err error) {
return this.rawListener.ReadFrom(b)
}
func (this *UDPIPv4Listener) WriteTo(b []byte, cm any, dst net.Addr) (n int, err error) {
return this.rawListener.WriteTo(b, cm.(*ipv4.ControlMessage), dst)
}
func (this *UDPIPv4Listener) LocalAddr() net.Addr {
return this.rawListener.LocalAddr()
}
type UDPIPv6Listener struct {
rawListener *ipv6.PacketConn
}
func NewUDPIPv6Listener(rawListener *ipv6.PacketConn) *UDPIPv6Listener {
return &UDPIPv6Listener{rawListener: rawListener}
}
func (this *UDPIPv6Listener) ReadFrom(b []byte) (n int, cm any, src net.Addr, err error) {
return this.rawListener.ReadFrom(b)
}
func (this *UDPIPv6Listener) WriteTo(b []byte, cm any, dst net.Addr) (n int, err error) {
return this.rawListener.WriteTo(b, cm.(*ipv6.ControlMessage), dst)
}
func (this *UDPIPv6Listener) LocalAddr() net.Addr {
return this.rawListener.LocalAddr()
}
type UDPListener struct {
BaseListener
Listener *net.UDPConn
IPv4Listener *ipv4.PacketConn
IPv6Listener *ipv6.PacketConn
connMap map[string]*UDPConn
connLocker sync.Mutex
@@ -36,6 +85,60 @@ type UDPListener struct {
}
func (this *UDPListener) Serve() error {
if this.Group == nil {
return nil
}
var server = this.Group.FirstServer()
if server == nil {
return nil
}
var serverId = server.Id
var wg = &sync.WaitGroup{}
wg.Add(2) // 2 = ipv4 + ipv6
go func() {
defer wg.Done()
if this.IPv4Listener != nil {
err := this.IPv4Listener.SetControlMessage(ipv4.FlagDst, true)
if err != nil {
remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv4 listener: "+err.Error(), "", nil)
return
}
err = this.servePacketListener(NewUDPIPv4Listener(this.IPv4Listener))
if err != nil {
remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv4 listener: "+err.Error(), "", nil)
return
}
}
}()
go func() {
defer wg.Done()
if this.IPv6Listener != nil {
err := this.IPv6Listener.SetControlMessage(ipv6.FlagDst, true)
if err != nil {
remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv6 listener: "+err.Error(), "", nil)
return
}
err = this.servePacketListener(NewUDPIPv6Listener(this.IPv6Listener))
if err != nil {
remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv6 listener: "+err.Error(), "", nil)
return
}
}
}()
wg.Wait()
return nil
}
func (this *UDPListener) servePacketListener(listener UDPPacketListener) error {
// 获取分组端口
var groupAddr = this.Group.Addr()
var portIndex = strings.LastIndex(groupAddr, ":")
@@ -67,7 +170,7 @@ func (this *UDPListener) Serve() error {
return nil
}
n, addr, err := this.Listener.ReadFrom(buffer)
n, cm, clientAddr, err := listener.ReadFrom(buffer)
if err != nil {
if this.isClosed {
return nil
@@ -77,14 +180,14 @@ func (this *UDPListener) Serve() error {
if n > 0 {
this.connLocker.Lock()
conn, ok := this.connMap[addr.String()]
conn, ok := this.connMap[clientAddr.String()]
this.connLocker.Unlock()
if ok && !conn.IsOk() {
_ = conn.Close()
ok = false
}
if !ok {
originConn, err := this.connectOrigin(firstServer.Id, this.reverseProxy, addr)
originConn, err := this.connectOrigin(firstServer.Id, this.reverseProxy, listener.LocalAddr(), clientAddr)
if err != nil {
remotelogs.Error("UDP_LISTENER", "unable to connect to origin server: "+err.Error())
continue
@@ -93,9 +196,9 @@ func (this *UDPListener) Serve() error {
remotelogs.Error("UDP_LISTENER", "unable to find a origin server")
continue
}
conn = NewUDPConn(firstServer, addr, this.Listener, originConn.(*net.UDPConn))
conn = NewUDPConn(firstServer, clientAddr, listener, cm, originConn.(*net.UDPConn))
this.connLocker.Lock()
this.connMap[addr.String()] = conn
this.connMap[clientAddr.String()] = conn
this.connLocker.Unlock()
}
_, _ = conn.Write(buffer[:n])
@@ -117,7 +220,26 @@ func (this *UDPListener) Close() error {
}
this.connLocker.Unlock()
return this.Listener.Close()
var errorStrings = []string{}
if this.IPv4Listener != nil {
err := this.IPv4Listener.Close()
if err != nil {
errorStrings = append(errorStrings, err.Error())
}
}
if this.IPv6Listener != nil {
err := this.IPv6Listener.Close()
if err != nil {
errorStrings = append(errorStrings, err.Error())
}
}
if len(errorStrings) > 0 {
return errors.New(errorStrings[0])
}
return nil
}
func (this *UDPListener) Reload(group *serverconfigs.ServerAddressGroup) {
@@ -132,7 +254,7 @@ func (this *UDPListener) Reload(group *serverconfigs.ServerAddressGroup) {
this.reverseProxy = firstServer.ReverseProxy
}
func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr net.Addr) (conn net.Conn, err error) {
func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, localAddr net.Addr, remoteAddr net.Addr) (conn net.Conn, err error) {
if reverseProxy == nil {
return nil, errors.New("no reverse proxy config")
}
@@ -181,12 +303,12 @@ func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfi
if strings.Contains(remoteAddr.String(), "[") {
transportProtocol = proxyproto.UDPv6
}
header := proxyproto.Header{
var header = proxyproto.Header{
Version: byte(reverseProxy.ProxyProtocol.Version),
Command: proxyproto.PROXY,
TransportProtocol: transportProtocol,
SourceAddr: remoteAddr,
DestinationAddr: this.Listener.LocalAddr(),
DestinationAddr: localAddr,
}
_, err = header.WriteTo(conn)
if err != nil {
@@ -224,21 +346,21 @@ func (this *UDPListener) gcConns() {
// UDPConn 自定义的UDP连接管理
type UDPConn struct {
addr net.Addr
proxyConn net.Conn
serverConn net.Conn
activatedAt int64
isOk bool
isClosed bool
addr net.Addr
proxyListener UDPPacketListener
serverConn net.Conn
activatedAt int64
isOk bool
isClosed bool
}
func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyConn *net.UDPConn, serverConn *net.UDPConn) *UDPConn {
func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyListener UDPPacketListener, cm any, serverConn *net.UDPConn) *UDPConn {
var conn = &UDPConn{
addr: addr,
proxyConn: proxyConn,
serverConn: serverConn,
activatedAt: time.Now().Unix(),
isOk: true,
addr: addr,
proxyListener: proxyListener,
serverConn: serverConn,
activatedAt: time.Now().Unix(),
isOk: true,
}
// 统计
@@ -246,6 +368,14 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyConn *ne
stats.SharedTrafficStatManager.Add(server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
}
// 处理ControlMessage
switch controlMessage := cm.(type) {
case *ipv4.ControlMessage:
controlMessage.Src = controlMessage.Dst
case *ipv6.ControlMessage:
controlMessage.Src = controlMessage.Dst
}
goman.New(func() {
var buffer = utils.BytePool4k.Get()
defer func() {
@@ -256,7 +386,8 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyConn *ne
n, err := serverConn.Read(buffer)
if n > 0 {
conn.activatedAt = time.Now().Unix()
_, writingErr := proxyConn.WriteTo(buffer[:n], addr)
_, writingErr := proxyListener.WriteTo(buffer[:n], cm, addr)
if writingErr != nil {
conn.isOk = false
break

View File

@@ -2,6 +2,7 @@ package nodes
import (
"bytes"
"context"
"encoding/json"
"errors"
iplib "github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
@@ -23,7 +24,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/trackers"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/clock"
_ "github.com/TeaOSLab/EdgeNode/internal/utils/clock" // 触发时钟更新
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/andybalholm/brotli"
"github.com/iwind/TeaGo/Tea"
@@ -91,6 +92,10 @@ func (this *Node) Test() error {
// Start 启动
func (this *Node) Start() {
// 设置netdns
// 这个需要放在所有网络访问的最前面
_ = os.Setenv("GODEBUG", "netdns=go")
_, ok := os.LookupEnv("EdgeDaemon")
if ok {
remotelogs.Println("NODE", "start from daemon")
@@ -166,11 +171,6 @@ func (this *Node) Start() {
NewNodeStatusExecutor().Listen()
})
// 同步时间
goman.New(func() {
clock.Start()
})
// 读取配置
nodeConfig, err := nodeconfigs.SharedNodeConfig()
if err != nil {
@@ -322,139 +322,157 @@ func (this *Node) loop() error {
return errors.New("read node tasks failed: " + err.Error())
}
for _, task := range tasksResp.NodeTasks {
switch task.Type {
case "ipItemChanged":
// 防止阻塞
select {
case iplibrary.IPListUpdateNotify <- true:
default:
}
// 修改为已同步
_, err = rpcClient.NodeTaskRPC.ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
NodeTaskId: task.Id,
IsOk: true,
Error: "",
})
if err != nil {
return err
}
case "configChanged":
if task.ServerId > 0 {
err = this.syncServerConfig(task.ServerId)
} else {
if !task.IsPrimary {
// 我们等等主节点配置准备完毕
time.Sleep(2 * time.Second)
}
err = this.syncConfig(task.Version)
}
if err != nil {
_, err = rpcClient.NodeTaskRPC.ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
NodeTaskId: task.Id,
IsOk: false,
Error: err.Error(),
})
} else {
_, err = rpcClient.NodeTaskRPC.ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
NodeTaskId: task.Id,
IsOk: true,
Error: "",
})
}
if err != nil {
return err
}
case "nodeVersionChanged":
if !sharedUpgradeManager.IsInstalling() {
goman.New(func() {
sharedUpgradeManager.Start()
})
}
case "scriptsChanged":
err = this.reloadCommonScripts()
if err != nil {
return errors.New("reload common scripts failed: " + err.Error())
}
// 修改为已同步
_, err = rpcClient.NodeTaskRPC.ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
NodeTaskId: task.Id,
IsOk: true,
Error: "",
})
if err != nil {
return err
}
case "nodeLevelChanged":
levelInfoResp, err := rpcClient.NodeRPC.FindNodeLevelInfo(nodeCtx, &pb.FindNodeLevelInfoRequest{})
if err != nil {
return err
}
sharedNodeConfig.Level = levelInfoResp.Level
var parentNodes = map[int64][]*nodeconfigs.ParentNodeConfig{}
if len(levelInfoResp.ParentNodesMapJSON) > 0 {
err = json.Unmarshal(levelInfoResp.ParentNodesMapJSON, &parentNodes)
if err != nil {
return errors.New("decode level info failed: " + err.Error())
}
}
sharedNodeConfig.ParentNodes = parentNodes
// 修改为已同步
_, err = rpcClient.NodeTaskRPC.ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
NodeTaskId: task.Id,
IsOk: true,
Error: "",
})
if err != nil {
return err
}
case "ddosProtectionChanged":
resp, err := rpcClient.NodeRPC.FindNodeDDoSProtection(nodeCtx, &pb.FindNodeDDoSProtectionRequest{})
if err != nil {
return err
}
if len(resp.DdosProtectionJSON) == 0 {
if sharedNodeConfig != nil {
sharedNodeConfig.DDoSProtection = nil
}
} else {
var ddosProtectionConfig = &ddosconfigs.ProtectionConfig{}
err = json.Unmarshal(resp.DdosProtectionJSON, ddosProtectionConfig)
if err != nil {
return errors.New("decode DDoS protection config failed: " + err.Error())
}
if sharedNodeConfig != nil {
sharedNodeConfig.DDoSProtection = ddosProtectionConfig
}
err = firewalls.SharedDDoSProtectionManager.Apply(ddosProtectionConfig)
if err != nil {
// 不阻塞
remotelogs.Error("NODE", "apply DDoS protection failed: "+err.Error())
}
}
// 修改为已同步
_, err = rpcClient.NodeTaskRPC.ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
NodeTaskId: task.Id,
IsOk: true,
Error: "",
})
if err != nil {
return err
}
}
err := this.execTask(rpcClient, nodeCtx, task)
this.finishTask(task.Id, err)
}
return nil
}
// 执行任务
func (this *Node) execTask(rpcClient *rpc.RPCClient, nodeCtx context.Context, task *pb.NodeTask) error {
switch task.Type {
case "ipItemChanged":
// 防止阻塞
select {
case iplibrary.IPListUpdateNotify <- true:
default:
}
case "configChanged":
if task.ServerId > 0 {
return this.syncServerConfig(task.ServerId)
}
if !task.IsPrimary {
// 我们等等主节点配置准备完毕
time.Sleep(2 * time.Second)
}
return this.syncConfig(task.Version)
case "nodeVersionChanged":
if !sharedUpgradeManager.IsInstalling() {
goman.New(func() {
sharedUpgradeManager.Start()
})
}
case "scriptsChanged":
err := this.reloadCommonScripts()
if err != nil {
return errors.New("reload common scripts failed: " + err.Error())
}
case "nodeLevelChanged":
levelInfoResp, err := rpcClient.NodeRPC.FindNodeLevelInfo(nodeCtx, &pb.FindNodeLevelInfoRequest{})
if err != nil {
return err
}
if sharedNodeConfig != nil {
sharedNodeConfig.Level = levelInfoResp.Level
}
var parentNodes = map[int64][]*nodeconfigs.ParentNodeConfig{}
if len(levelInfoResp.ParentNodesMapJSON) > 0 {
err = json.Unmarshal(levelInfoResp.ParentNodesMapJSON, &parentNodes)
if err != nil {
return errors.New("decode level info failed: " + err.Error())
}
}
if sharedNodeConfig != nil {
sharedNodeConfig.ParentNodes = parentNodes
}
case "ddosProtectionChanged":
resp, err := rpcClient.NodeRPC.FindNodeDDoSProtection(nodeCtx, &pb.FindNodeDDoSProtectionRequest{})
if err != nil {
return err
}
if len(resp.DdosProtectionJSON) == 0 {
if sharedNodeConfig != nil {
sharedNodeConfig.DDoSProtection = nil
}
return nil
}
var ddosProtectionConfig = &ddosconfigs.ProtectionConfig{}
err = json.Unmarshal(resp.DdosProtectionJSON, ddosProtectionConfig)
if err != nil {
return errors.New("decode DDoS protection config failed: " + err.Error())
}
if ddosProtectionConfig != nil && sharedNodeConfig != nil {
sharedNodeConfig.DDoSProtection = ddosProtectionConfig
}
err = firewalls.SharedDDoSProtectionManager.Apply(ddosProtectionConfig)
if err != nil {
// 不阻塞
remotelogs.Warn("NODE", "apply DDoS protection failed: "+err.Error())
return nil
}
case "globalServerConfigChanged":
resp, err := rpcClient.NodeRPC.FindNodeGlobalServerConfig(nodeCtx, &pb.FindNodeGlobalServerConfigRequest{})
if err != nil {
return err
}
if len(resp.GlobalServerConfigJSON) > 0 {
var globalServerConfig = serverconfigs.DefaultGlobalServerConfig()
err = json.Unmarshal(resp.GlobalServerConfigJSON, globalServerConfig)
if err != nil {
return errors.New("decode global server config failed: " + err.Error())
}
if globalServerConfig != nil {
err = globalServerConfig.Init()
if err != nil {
return errors.New("validate global server config failed: " + err.Error())
}
if sharedNodeConfig != nil {
sharedNodeConfig.GlobalServerConfig = globalServerConfig
}
}
}
default:
remotelogs.Error("NODE", "task '"+types.String(task.Id)+"', type '"+task.Type+"' has not been handled")
}
return nil
}
// 标记任务完成
func (this *Node) finishTask(taskId int64, err error) {
if taskId <= 0 {
return
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
logs.Println("[NODE]", "create rpc client failed: "+err.Error())
return
}
var nodeCtx = rpcClient.Context()
var isOk = err == nil
var errMsg = ""
if err != nil {
errMsg = err.Error()
}
_, err = rpcClient.NodeTaskRPC.ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
NodeTaskId: taskId,
IsOk: isOk,
Error: errMsg,
})
if err != nil {
// 不需要上报到服务中心
if rpc.IsConnError(err) {
logs.Println("[NODE]", "report task done failed: "+err.Error())
} else {
remotelogs.Error("NODE", "report task done failed: "+err.Error())
}
}
}
// 读取API配置
func (this *Node) syncConfig(taskVersion int64) error {
this.locker.Lock()
@@ -483,7 +501,7 @@ func (this *Node) syncConfig(taskVersion int64) error {
}
// 获取同步任务
nodeCtx := rpcClient.Context()
var nodeCtx = rpcClient.Context()
// TODO 这里考虑只同步版本号有变更的
configResp, err := rpcClient.NodeRPC.FindCurrentNodeConfig(nodeCtx, &pb.FindCurrentNodeConfigRequest{

View File

@@ -1,48 +0,0 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/types"
"strings"
"sync"
"time"
)
var sharedCNAMEManager = NewServerCNAMEManager()
// ServerCNAMEManager 服务CNAME管理
// TODO 需要自动更新缓存里的记录
type ServerCNAMEManager struct {
ttlCache *ttlcache.Cache
locker sync.Mutex
}
func NewServerCNAMEManager() *ServerCNAMEManager {
return &ServerCNAMEManager{
ttlCache: ttlcache.NewCache(),
}
}
func (this *ServerCNAMEManager) Lookup(domain string) string {
if len(domain) == 0 {
return ""
}
var item = this.ttlCache.Read(domain)
if item != nil {
return types.String(item.Value)
}
cname, _ := utils.LookupCNAME(domain)
if len(cname) > 0 {
cname = strings.TrimSuffix(cname, ".")
}
this.ttlCache.Write(domain, cname, time.Now().Unix()+600)
return cname
}

View File

@@ -1,19 +0,0 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"testing"
"time"
)
func TestServerCNameManager_Lookup(t *testing.T) {
var cnameManager = NewServerCNAMEManager()
t.Log(cnameManager.Lookup("www.yun4s.cn"))
var before = time.Now()
defer func() {
t.Log(time.Since(before).Seconds()*1000, "ms")
}()
t.Log(cnameManager.Lookup("www.yun4s.cn"))
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"github.com/iwind/TeaGo/maps"
"os"
"os/exec"
@@ -20,15 +21,18 @@ import (
func init() {
var manager = NewSystemServiceManager()
events.On(events.EventReload, func() {
err := manager.Setup()
if err != nil {
remotelogs.Error("SYSTEM_SERVICE", "setup system services failed: "+err.Error())
}
goman.New(func() {
err := manager.Setup()
if err != nil {
remotelogs.Error("SYSTEM_SERVICE", "setup system services failed: "+err.Error())
}
})
})
}
// SystemServiceManager 系统服务管理
type SystemServiceManager struct {
lastIsOn int // -1, 0, 1
}
func NewSystemServiceManager() *SystemServiceManager {
@@ -68,7 +72,8 @@ func (this *SystemServiceManager) setupSystemd(params maps.Map) error {
if err != nil {
return err
}
config := &nodeconfigs.SystemdServiceConfig{}
var config = &nodeconfigs.SystemdServiceConfig{}
err = json.Unmarshal(data, config)
if err != nil {
return err
@@ -82,42 +87,60 @@ func (this *SystemServiceManager) setupSystemd(params maps.Map) error {
if len(systemctl) == 0 {
return errors.New("can not find 'systemctl' on the system")
}
cmd := utils.NewCommandExecutor()
shortName := teaconst.SystemdServiceName
cmd.Add(systemctl, "is-enabled", shortName)
output, err := cmd.Run()
if err != nil {
return err
// 记录上次状态
var isOnInt int
if config.IsOn {
isOnInt = 1
} else {
isOnInt = 0
}
if this.lastIsOn == isOnInt {
return nil
}
defer func() {
this.lastIsOn = isOnInt
}()
var shortName = teaconst.SystemdServiceName
var cmd = executils.NewTimeoutCmd(10*time.Second, systemctl, "is-enabled", shortName)
cmd.WithStdout()
err = cmd.Run()
var hasInstalled = err == nil
if config.IsOn {
exe, err := os.Executable()
if err != nil {
return err
}
// 启动Service
goman.New(func() {
time.Sleep(5 * time.Second)
_ = exec.Command(systemctl, "start", teaconst.SystemdServiceName).Start()
})
if output == "enabled" {
// 检查文件路径是否变化
// 检查文件路径是否变化
if hasInstalled && cmd.Stdout() == "enabled" {
data, err := os.ReadFile("/etc/systemd/system/" + teaconst.SystemdServiceName + ".service")
if err == nil && bytes.Index(data, []byte(exe)) > 0 {
return nil
}
}
manager := utils.NewServiceManager(shortName, teaconst.ProductName)
// 安装服务
var manager = utils.NewServiceManager(shortName, teaconst.ProductName)
err = manager.Install(exe, []string{})
if err != nil {
return err
}
// 启动服务
goman.New(func() {
time.Sleep(5 * time.Second)
_ = executils.NewTimeoutCmd(30*time.Second, systemctl, "start", teaconst.SystemdServiceName).Start()
})
} else {
manager := utils.NewServiceManager(shortName, teaconst.ProductName)
err = manager.Uninstall()
if err != nil {
return err
if hasInstalled {
var manager = utils.NewServiceManager(shortName, teaconst.ProductName)
err = manager.Uninstall()
if err != nil {
return err
}
}
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/iwind/TeaGo/logs"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"net/url"
"sort"
"strings"
@@ -152,7 +153,7 @@ func (this *SyncAPINodesTask) testEndpoints(endpoints []string) bool {
}()
var conn *grpc.ClientConn
if u.Scheme == "http" {
conn, err = grpc.DialContext(ctx, u.Host, grpc.WithInsecure(), grpc.WithBlock())
conn, err = grpc.DialContext(ctx, u.Host, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock())
} else if u.Scheme == "https" {
conn, err = grpc.DialContext(ctx, u.Host, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
InsecureSkipVerify: true,

View File

@@ -1,14 +1,15 @@
package nodes
import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"github.com/iwind/TeaGo/Tea"
"net"
"os"
"os/exec"
"strings"
"time"
)
@@ -61,12 +62,16 @@ func (this *TOAManager) Run(config *nodeconfigs.TOAConfig) error {
}
remotelogs.Println("TOA", "starting ...")
remotelogs.Println("TOA", "args: "+strings.Join(config.AsArgs(), " "))
cmd := exec.Command(binPath, config.AsArgs()...)
cmd := executils.NewCmd(binPath, config.AsArgs()...)
err = cmd.Start()
if err != nil {
return err
}
this.pid = cmd.Process.Pid
var process = cmd.Process()
if process == nil {
return errors.New("start failed")
}
this.pid = process.Pid
goman.New(func() {
_ = cmd.Wait()

View File

@@ -12,11 +12,11 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"github.com/iwind/TeaGo/Tea"
stringutil "github.com/iwind/TeaGo/utils/string"
"github.com/iwind/gosock/pkg/gosock"
"os"
"os/exec"
"path/filepath"
"runtime"
"time"
@@ -252,7 +252,7 @@ func (this *UpgradeManager) restart() error {
// 启动
exe = filepath.Dir(exe) + "/" + teaconst.ProcessName
var cmd = exec.Command(exe, "start")
var cmd = executils.NewCmd(exe, "start")
err = cmd.Start()
if err != nil {
return err

View File

@@ -11,7 +11,6 @@ import (
)
var prefixReg = regexp.MustCompile(`^\(\?([\w\s]+)\)`) // (?x)
var prefixReg2 = regexp.MustCompile(`^\(\?([\w\s]*:)`) // (?x: ...
var braceZeroReg = regexp.MustCompile(`^{\s*0*\s*}`) // {0}
var braceZeroReg2 = regexp.MustCompile(`^{\s*0*\s*,`) // {0, x}

View File

@@ -21,7 +21,7 @@ var logChan = make(chan *pb.NodeLog, 1024)
func init() {
// 定期上传日志
ticker := time.NewTicker(60 * time.Second)
var ticker = time.NewTicker(60 * time.Second)
if Tea.IsTesting() {
ticker = time.NewTicker(10 * time.Second)
}
@@ -37,6 +37,11 @@ func init() {
})
}
// Debug 打印调试信息
func Debug(tag string, description string) {
logs.Println("[" + tag + "]" + description)
}
// Println 打印普通信息
func Println(tag string, description string) {
logs.Println("[" + tag + "]" + description)
@@ -73,6 +78,31 @@ func Warn(tag string, description string) {
}
}
// WarnServer 打印服务相关警告
func WarnServer(tag string, description string) {
if Tea.IsTesting() {
logs.Println("[" + tag + "]" + description)
}
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig != nil && nodeConfig.GlobalServerConfig != nil && !nodeConfig.GlobalServerConfig.Log.RecordServerError {
return
}
select {
case logChan <- &pb.NodeLog{
Role: teaconst.Role,
Tag: tag,
Description: description,
Level: "warning",
NodeId: teaconst.NodeId,
CreatedAt: time.Now().Unix(),
}:
default:
}
}
// Error 打印错误信息
func Error(tag string, description string) {
logs.Println("[" + tag + "]" + description)
@@ -97,6 +127,37 @@ func Error(tag string, description string) {
}
}
// ErrorServer 打印服务相关错误信息
func ErrorServer(tag string, description string) {
if Tea.IsTesting() {
logs.Println("[" + tag + "]" + description)
}
// 忽略RPC连接错误
var level = "error"
if strings.Contains(description, "code = Unavailable desc") {
level = "warning"
}
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig != nil && nodeConfig.GlobalServerConfig != nil && !nodeConfig.GlobalServerConfig.Log.RecordServerError {
return
}
select {
case logChan <- &pb.NodeLog{
Role: teaconst.Role,
Tag: tag,
Description: description,
Level: level,
NodeId: teaconst.NodeId,
CreatedAt: time.Now().Unix(),
}:
default:
}
}
// ErrorObject 打印错误对象
func ErrorObject(tag string, err error) {
if err == nil {
@@ -111,7 +172,15 @@ func ErrorObject(tag string, err error) {
// ServerError 打印服务相关错误信息
func ServerError(serverId int64, tag string, description string, logType nodeconfigs.NodeLogType, params maps.Map) {
logs.Println("[" + tag + "]" + description)
if Tea.IsTesting() {
logs.Println("[" + tag + "]" + description)
}
// 是否记录服务相关错误
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig != nil && nodeConfig.GlobalServerConfig != nil && !nodeConfig.GlobalServerConfig.Log.RecordServerError {
return
}
// 参数
var paramsJSON []byte
@@ -207,7 +276,7 @@ func ServerLog(serverId int64, tag string, description string, logType nodeconfi
// 上传日志
func uploadLogs() error {
logList := []*pb.NodeLog{}
var logList = []*pb.NodeLog{}
const hashSize = 5
var hashList = []uint64{}
@@ -242,6 +311,7 @@ Loop:
if len(logList) == 0 {
return nil
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err

View File

@@ -167,10 +167,13 @@ func (this *RPCClient) init() error {
return errors.New("parse endpoint failed: " + err.Error())
}
var conn *grpc.ClientConn
var callOptions = grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(128*1024*1024),
grpc.UseCompressor(gzip.Name))
var callOptions = grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(128*1024*1024),
grpc.MaxCallSendMsgSize(128*1024*1024),
grpc.UseCompressor(gzip.Name),
)
if u.Scheme == "http" {
conn, err = grpc.Dial(u.Host, grpc.WithTransportCredentials(insecure.NewCredentials()))
conn, err = grpc.Dial(u.Host, grpc.WithTransportCredentials(insecure.NewCredentials()), callOptions)
} else if u.Scheme == "https" {
conn, err = grpc.Dial(u.Host, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
InsecureSkipVerify: true,

View File

@@ -0,0 +1,158 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package clock
import (
"encoding/binary"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
timeutil "github.com/iwind/TeaGo/utils/time"
"net"
"os/exec"
"runtime"
"time"
)
var hasSynced = false
var sharedClockManager = NewClockManager()
func init() {
events.On(events.EventLoaded, func() {
goman.New(sharedClockManager.Start)
})
events.On(events.EventReload, func() {
if !hasSynced {
hasSynced = true
goman.New(func() {
err := sharedClockManager.Sync()
if err != nil {
remotelogs.Warn("CLOCK", "sync clock failed: "+err.Error())
}
})
}
})
}
type ClockManager struct {
}
func NewClockManager() *ClockManager {
return &ClockManager{}
}
// Start 启动
func (this *ClockManager) Start() {
var ticker = time.NewTicker(1 * time.Hour)
for range ticker.C {
err := this.Sync()
if err != nil {
remotelogs.Warn("CLOCK", "sync clock failed: "+err.Error())
}
}
}
// Sync 自动校对时间
func (this *ClockManager) Sync() error {
if runtime.GOOS != "linux" {
return nil
}
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig == nil {
return nil
}
var config = nodeConfig.Clock
if config == nil || !config.AutoSync {
return nil
}
var server = config.Server
if len(server) == 0 {
server = "pool.ntp.org"
}
ntpdate, err := exec.LookPath("ntpdate")
if err != nil {
// 使用 date 命令设置
// date --set TIME
dateExe, err := exec.LookPath("date")
if err == nil {
currentTime, err := this.ReadServer(server)
if err != nil {
return errors.New("read server failed: " + err.Error())
}
var delta = time.Now().Unix() - currentTime.Unix()
if delta > 1 || delta < -1 { // 相差比较大的时候才会同步
var err = executils.NewTimeoutCmd(3*time.Second, dateExe, "--set", timeutil.Format("Y-m-d H:i:s+P", currentTime)).
Run()
if err != nil {
return err
}
}
}
return nil
}
if len(ntpdate) > 0 {
return this.syncNtpdate(ntpdate, server)
}
return nil
}
func (this *ClockManager) syncNtpdate(ntpdate string, server string) error {
var cmd = executils.NewTimeoutCmd(30*time.Second, ntpdate, server)
cmd.WithStderr()
err := cmd.Run()
if err != nil {
return errors.New(err.Error() + ": " + cmd.Stderr())
}
return nil
}
// 参考自https://medium.com/learning-the-go-programming-language/lets-make-an-ntp-client-in-go-287c4b9a969f
func (this *ClockManager) ReadServer(server string) (time.Time, error) {
conn, err := net.Dial("udp", server+":123")
if err != nil {
return time.Time{}, errors.New("connect to server failed: " + err.Error())
}
defer func() {
_ = conn.Close()
}()
err = conn.SetDeadline(time.Now().Add(5 * time.Second))
if err != nil {
return time.Time{}, err
}
// configure request settings by specifying the first byte as
// 00 011 011 (or 0x1B)
// | | +-- client mode (3)
// | + ----- version (3)
// + -------- leap year indicator, 0 no warning
var req = &NTPPacket{Settings: 0x1B}
err = binary.Write(conn, binary.BigEndian, req)
if err != nil {
return time.Time{}, errors.New("write request failed: " + err.Error())
}
var resp = &NTPPacket{}
err = binary.Read(conn, binary.BigEndian, resp)
if err != nil {
return time.Time{}, errors.New("write server response failed: " + err.Error())
}
const ntpEpochOffset = 2208988800
var secs = float64(resp.TxTimeSec) - ntpEpochOffset
var nanos = (int64(resp.TxTimeFrac) * 1e9) >> 32
return time.Unix(int64(secs), nanos), nil
}

View File

@@ -0,0 +1,12 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package clock_test
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/clock"
"testing"
)
func TestReadServer(t *testing.T) {
t.Log(clock.NewClockManager().ReadServer("pool.ntp.org"))
}

View File

@@ -0,0 +1,21 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package clock
type NTPPacket struct {
Settings uint8 // leap yr indicator, ver number, and mode
Stratum uint8 // stratum of local clock
Poll int8 // poll exponent
Precision int8 // precision exponent
RootDelay uint32 // root delay
RootDispersion uint32 // root dispersion
ReferenceID uint32 // reference id
RefTimeSec uint32 // reference timestamp sec
RefTimeFrac uint32 // reference timestamp fractional
OrigTimeSec uint32 // origin time secs
OrigTimeFrac uint32 // origin time fractional
RxTimeSec uint32 // receive time secs
RxTimeFrac uint32 // receive time frac
TxTimeSec uint32 // transmit time secs
TxTimeFrac uint32 // transmit time frac
}

View File

@@ -1,58 +0,0 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package clock
import (
"bytes"
"errors"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"os/exec"
"runtime"
"time"
)
// Start TODO 需要可以在集群中配置
func Start() {
// sync once
err := Sync()
if err != nil {
remotelogs.Warn("CLOCK", "sync time clock failed: "+err.Error())
}
var ticker = time.NewTicker(1 * time.Hour)
for range ticker.C {
err := Sync()
if err != nil {
// ignore error
}
}
}
// Sync 自动校对时间
func Sync() error {
if runtime.GOOS != "linux" {
return nil
}
ntpdate, err := exec.LookPath("ntpdate")
if err != nil {
return nil
}
if len(ntpdate) > 0 {
return syncNtpdate(ntpdate)
}
return nil
}
func syncNtpdate(ntpdate string) error {
var cmd = exec.Command(ntpdate, "pool.ntp.org")
var stderr = &bytes.Buffer{}
cmd.Stderr = stderr
err := cmd.Run()
if err != nil {
return errors.New(err.Error() + ": " + stderr.String())
}
return nil
}

View File

@@ -1,7 +0,0 @@
package utils
// 命令定义
type Command struct {
Name string
Args []string
}

View File

@@ -1,61 +0,0 @@
package utils
import (
"bytes"
"errors"
"os/exec"
)
// 命令执行器
type CommandExecutor struct {
commands []*Command
}
// 获取新对象
func NewCommandExecutor() *CommandExecutor {
return &CommandExecutor{}
}
// 添加命令
func (this *CommandExecutor) Add(command string, arg ...string) {
this.commands = append(this.commands, &Command{
Name: command,
Args: arg,
})
}
// 执行命令
func (this *CommandExecutor) Run() (output string, err error) {
if len(this.commands) == 0 {
return "", errors.New("no commands no run")
}
var lastCmd *exec.Cmd = nil
var lastData []byte = nil
for _, command := range this.commands {
cmd := exec.Command(command.Name, command.Args...)
stdout := bytes.NewBuffer([]byte{})
cmd.Stdout = stdout
if lastCmd != nil {
cmd.Stdin = bytes.NewBuffer(lastData)
}
err = cmd.Start()
if err != nil {
return "", err
}
err = cmd.Wait()
if err != nil {
_, ok := err.(*exec.ExitError)
if ok {
return "", nil
}
return "", err
}
lastData = stdout.Bytes()
lastCmd = cmd
}
return string(bytes.TrimSpace(lastData)), nil
}

View File

@@ -64,6 +64,11 @@ For:
for {
// closed
if this.isClosed {
if lastTx != nil {
_ = lastTx.Commit()
lastTx = nil
}
return
}

162
internal/utils/exec/cmd.go Normal file
View File

@@ -0,0 +1,162 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package executils
import (
"bytes"
"context"
"os"
"os/exec"
"strings"
"time"
)
type Cmd struct {
name string
args []string
env []string
dir string
ctx context.Context
timeout time.Duration
cancelFunc func()
captureStdout bool
captureStderr bool
stdout *bytes.Buffer
stderr *bytes.Buffer
rawCmd *exec.Cmd
}
func NewCmd(name string, args ...string) *Cmd {
return &Cmd{
name: name,
args: args,
}
}
func NewTimeoutCmd(timeout time.Duration, name string, args ...string) *Cmd {
return (&Cmd{
name: name,
args: args,
}).WithTimeout(timeout)
}
func (this *Cmd) WithTimeout(timeout time.Duration) *Cmd {
this.timeout = timeout
ctx, cancelFunc := context.WithTimeout(context.Background(), timeout)
this.ctx = ctx
this.cancelFunc = cancelFunc
return this
}
func (this *Cmd) WithStdout() *Cmd {
this.captureStdout = true
return this
}
func (this *Cmd) WithStderr() *Cmd {
this.captureStderr = true
return this
}
func (this *Cmd) WithEnv(env []string) *Cmd {
this.env = env
return this
}
func (this *Cmd) WithDir(dir string) *Cmd {
this.dir = dir
return this
}
func (this *Cmd) Start() error {
var cmd = this.compose()
return cmd.Start()
}
func (this *Cmd) Wait() error {
var cmd = this.compose()
return cmd.Wait()
}
func (this *Cmd) Run() error {
if this.cancelFunc != nil {
defer this.cancelFunc()
}
var cmd = this.compose()
return cmd.Run()
}
func (this *Cmd) RawStdout() string {
if this.stdout != nil {
return this.stdout.String()
}
return ""
}
func (this *Cmd) Stdout() string {
return strings.TrimSpace(this.RawStdout())
}
func (this *Cmd) RawStderr() string {
if this.stderr != nil {
return this.stderr.String()
}
return ""
}
func (this *Cmd) Stderr() string {
return strings.TrimSpace(this.RawStderr())
}
func (this *Cmd) String() string {
if this.rawCmd != nil {
return this.rawCmd.String()
}
var newCmd = exec.Command(this.name, this.args...)
return newCmd.String()
}
func (this *Cmd) Process() *os.Process {
if this.rawCmd != nil {
return this.rawCmd.Process
}
return nil
}
func (this *Cmd) compose() *exec.Cmd {
if this.rawCmd != nil {
return this.rawCmd
}
if this.ctx != nil {
this.rawCmd = exec.CommandContext(this.ctx, this.name, this.args...)
} else {
this.rawCmd = exec.Command(this.name, this.args...)
}
if this.env != nil {
this.rawCmd.Env = this.env
}
if len(this.dir) > 0 {
this.rawCmd.Dir = this.dir
}
if this.captureStdout {
this.stdout = &bytes.Buffer{}
this.rawCmd.Stdout = this.stdout
}
if this.captureStderr {
this.stderr = &bytes.Buffer{}
this.rawCmd.Stderr = this.stderr
}
return this.rawCmd
}

View File

@@ -0,0 +1,61 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package executils_test
import (
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"testing"
"time"
)
func TestNewTimeoutCmd_Sleep(t *testing.T) {
var cmd = executils.NewTimeoutCmd(1*time.Second, "sleep", "3")
cmd.WithStdout()
cmd.WithStderr()
err := cmd.Run()
t.Log("error:", err)
t.Log("stdout:", cmd.Stdout())
t.Log("stderr:", cmd.Stderr())
}
func TestNewTimeoutCmd_Echo(t *testing.T) {
var cmd = executils.NewTimeoutCmd(10*time.Second, "echo", "-n", "hello")
cmd.WithStdout()
cmd.WithStderr()
err := cmd.Run()
t.Log("error:", err)
t.Log("stdout:", cmd.Stdout())
t.Log("stderr:", cmd.Stderr())
}
func TestNewTimeoutCmd_Echo2(t *testing.T) {
var cmd = executils.NewCmd("echo", "hello")
cmd.WithStdout()
cmd.WithStderr()
err := cmd.Run()
t.Log("error:", err)
t.Log("stdout:", cmd.Stdout())
t.Log("raw stdout:", cmd.RawStdout())
t.Log("stderr:", cmd.Stderr())
t.Log("raw stderr:", cmd.RawStderr())
}
func TestNewTimeoutCmd_Echo3(t *testing.T) {
var cmd = executils.NewCmd("echo", "-n", "hello")
err := cmd.Run()
t.Log("error:", err)
t.Log("stdout:", cmd.Stdout())
t.Log("stderr:", cmd.Stderr())
}
func TestCmd_Process(t *testing.T) {
var cmd = executils.NewCmd("echo", "-n", "hello")
err := cmd.Run()
t.Log("error:", err)
t.Log(cmd.Process())
}
func TestNewTimeoutCmd_String(t *testing.T) {
var cmd = executils.NewCmd("echo", "-n", "hello")
t.Log("stdout:", cmd.String())
}

View File

@@ -2,13 +2,9 @@
package expires
import "sync"
type IdKeyMap struct {
idKeys map[int64]string // id => key
keyIds map[string]int64 // key => id
locker sync.Mutex
}
func NewIdKeyMap() *IdKeyMap {

View File

@@ -5,12 +5,9 @@ import (
"github.com/cespare/xxhash"
"math"
"net"
"regexp"
"strings"
)
var ipv4Reg = regexp.MustCompile(`\d+\.`)
// IP2Long 将IP转换为整型
// 注意IPv6没有顺序
func IP2Long(ip string) uint64 {

View File

@@ -1,3 +1,4 @@
//go:build linux
// +build linux
package utils
@@ -5,17 +6,19 @@ package utils
import (
"errors"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/files"
"os"
"os/exec"
"regexp"
"time"
)
var systemdServiceFile = "/etc/systemd/system/edge-node.service"
var initServiceFile = "/etc/init.d/" + teaconst.SystemdServiceName
// 安装服务
// Install 安装服务
func (this *ServiceManager) Install(exePath string, args []string) error {
if os.Getgid() != 0 {
return errors.New("only root users can install the service")
@@ -29,7 +32,7 @@ func (this *ServiceManager) Install(exePath string, args []string) error {
return this.installSystemdService(systemd, exePath, args)
}
// 启动服务
// Start 启动服务
func (this *ServiceManager) Start() error {
if os.Getgid() != 0 {
return errors.New("only root users can start the service")
@@ -46,7 +49,7 @@ func (this *ServiceManager) Start() error {
return exec.Command("service", teaconst.ProcessName, "start").Start()
}
// 删除服务
// Uninstall 删除服务
func (this *ServiceManager) Uninstall() error {
if os.Getgid() != 0 {
return errors.New("only root users can uninstall the service")
@@ -108,10 +111,10 @@ func (this *ServiceManager) installInitService(exePath string, args []string) er
// install systemd service
func (this *ServiceManager) installSystemdService(systemd, exePath string, args []string) error {
shortName := teaconst.SystemdServiceName
longName := "GoEdge Node" // TODO 将来可以修改
var shortName = teaconst.SystemdServiceName
var longName = "GoEdge Node" // TODO 将来可以修改
desc := `# Provides: ` + shortName + `
var desc = `# Provides: ` + shortName + `
# Required-Start: $all
# Required-Stop:
# Default-Start: 2 3 4 5
@@ -142,12 +145,17 @@ WantedBy=multi-user.target`
}
// stop current systemd service if running
exec.Command(systemd, "stop", shortName+".service")
executils.NewTimeoutCmd(30*time.Second, systemd, "stop", shortName+".service")
// reload
exec.Command(systemd, "daemon-reload")
executils.NewTimeoutCmd(30*time.Second, systemd, "daemon-reload")
// enable
cmd := exec.Command(systemd, "enable", shortName+".service")
return cmd.Run()
var cmd = executils.NewTimeoutCmd(30*time.Second, systemd, "enable", shortName+".service")
cmd.WithStderr()
err = cmd.Run()
if err != nil {
return errors.New(err.Error() + ": " + cmd.Stderr())
}
return nil
}

View File

@@ -5,6 +5,7 @@ package waf
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/types"
"time"
@@ -50,5 +51,14 @@ func CaptchaDeleteCacheKey(req requests.Request) {
// CaptchaCacheKey 获取Captcha缓存Key
func CaptchaCacheKey(req requests.Request, pageCode CaptchaPageCode) string {
return "CAPTCHA:FAILS:" + pageCode + ":" + req.WAFRemoteIP() + ":" + types.String(req.WAFServerId()) + ":" + req.WAFRaw().URL.String()
var requestPath = req.WAFRaw().URL.Path
if req.WAFRaw().URL.Path == CaptchaPath {
m, err := utils.SimpleDecryptMap(req.WAFRaw().URL.Query().Get("info"))
if err == nil && m != nil {
requestPath = m.GetString("url")
}
}
return "CAPTCHA:FAILS:" + pageCode + ":" + req.WAFRemoteIP() + ":" + types.String(req.WAFServerId()) + ":" + requestPath
}

View File

@@ -117,6 +117,11 @@ func (this *CaptchaValidator) show(actionConfig *CaptchaAction, req requests.Req
msgPrompt = "请输入上面的验证码"
msgButtonTitle = "提交验证"
msgRequestId = "请求ID"
case "zh-TW":
msgTitle = "身份驗證"
msgPrompt = "請輸入上面的驗證碼"
msgButtonTitle = "提交驗證"
msgRequestId = "請求ID"
default:
msgTitle = "Verify Yourself"
msgPrompt = "Input verify code above:"

View File

@@ -528,7 +528,7 @@ func (this *Rule) Test(value interface{}) bool {
if ip == nil {
return false
}
return this.isIP && bytes.Compare(this.ipValue, ip) == 0
return this.isIP && bytes.Equal(this.ipValue, ip)
case RuleOperatorGtIP:
ip := net.ParseIP(types.String(value))
if ip == nil {