Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
927425149e | ||
|
|
5ce1aab92c | ||
|
|
195742bb26 | ||
|
|
006cc2912d | ||
|
|
2d4ba90c3b | ||
|
|
a2e6aaaa18 | ||
|
|
8e68da7725 | ||
|
|
7abb84c880 | ||
|
|
a17878f5b2 | ||
|
|
8a8881ac47 | ||
|
|
c567404b7a | ||
|
|
b220b0f48e | ||
|
|
9609c90d75 | ||
|
|
2c3c32af5b | ||
|
|
b4a4b2e9b1 | ||
|
|
c42ff1e1e9 | ||
|
|
9fed1141c2 | ||
|
|
e87f031293 | ||
|
|
c4bac7f43c | ||
|
|
47818f972e | ||
|
|
218a0300c5 | ||
|
|
63f6c4177f | ||
|
|
1830c22a31 | ||
|
|
18611e8a7c | ||
|
|
c45f7adf04 | ||
|
|
1a200918a8 | ||
|
|
b942bb776e | ||
|
|
5cf84efccd | ||
|
|
ebb6ebd10c | ||
|
|
42d0d63cf4 | ||
|
|
96f8f7e925 | ||
|
|
e7e7214d58 | ||
|
|
ade979a725 | ||
|
|
60a8de13e7 | ||
|
|
9fa24bed0a | ||
|
|
87bc1a7e03 | ||
|
|
1a05f56149 | ||
|
|
f88db576e1 | ||
|
|
dc3f26ea1a | ||
|
|
6fc30144f7 | ||
|
|
25b0b98bd4 | ||
|
|
27b5817d5e | ||
|
|
dcb61dfd33 | ||
|
|
bbcfdbbf5e | ||
|
|
b2a1bef08f | ||
|
|
2b18b5c2ca |
@@ -25,7 +25,7 @@ func main() {
|
||||
Product(teaconst.ProductName).
|
||||
Usage(teaconst.ProcessName + " [-v|start|stop|restart|status|quit|test|reload|service|daemon|pprof|accesslog]").
|
||||
Usage(teaconst.ProcessName + " [trackers|goman|conns|gc]").
|
||||
Usage(teaconst.ProcessName + " [ip.drop|ip.reject|ip.remove] IP")
|
||||
Usage(teaconst.ProcessName + " [ip.drop|ip.reject|ip.remove|ip.close] IP")
|
||||
|
||||
app.On("test", func() {
|
||||
err := nodes.NewNode().Test()
|
||||
@@ -241,6 +241,38 @@ func main() {
|
||||
}
|
||||
}
|
||||
})
|
||||
app.On("ip.close", func() {
|
||||
var args = os.Args[2:]
|
||||
if len(args) == 0 {
|
||||
fmt.Println("Usage: edge-node ip.close IP")
|
||||
return
|
||||
}
|
||||
var ip = args[0]
|
||||
if len(net.ParseIP(ip)) == 0 {
|
||||
fmt.Println("IP '" + ip + "' is invalid")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("close ip '" + ip)
|
||||
|
||||
var sock = gosock.NewTmpSock(teaconst.ProcessName)
|
||||
reply, err := sock.Send(&gosock.Command{
|
||||
Code: "closeIP",
|
||||
Params: map[string]any{
|
||||
"ip": ip,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Println("[ERROR]" + err.Error())
|
||||
} else {
|
||||
var errString = maps.NewMap(reply.Params).GetString("error")
|
||||
if len(errString) > 0 {
|
||||
fmt.Println("[ERROR]" + errString)
|
||||
} else {
|
||||
fmt.Println("ok")
|
||||
}
|
||||
}
|
||||
})
|
||||
app.On("ip.remove", func() {
|
||||
var args = os.Args[2:]
|
||||
if len(args) == 0 {
|
||||
|
||||
@@ -180,6 +180,9 @@ func (this *FileListDB) Init() error {
|
||||
}
|
||||
|
||||
this.selectHashListStmt, err = this.readDB.Prepare(`SELECT "id", "hash" FROM "` + this.itemsTableName + `" WHERE id>:id ORDER BY id ASC LIMIT 2000`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.deleteByHashSQL = `DELETE FROM "` + this.itemsTableName + `" WHERE "hash"=?`
|
||||
this.deleteByHashStmt, err = this.writeDB.Prepare(this.deleteByHashSQL)
|
||||
|
||||
@@ -129,6 +129,9 @@ func (this *OpenFileCache) Close(filename string) {
|
||||
|
||||
pool, ok := this.poolMap[filename]
|
||||
if ok {
|
||||
// 设置关闭状态
|
||||
pool.SetClosing()
|
||||
|
||||
delete(this.poolMap, filename)
|
||||
this.poolList.Remove(pool.linkItem)
|
||||
_ = this.watcher.Remove(filename)
|
||||
|
||||
@@ -12,6 +12,7 @@ type OpenFilePool struct {
|
||||
linkItem *linkedlist.Item
|
||||
filename string
|
||||
version int64
|
||||
isClosed bool
|
||||
}
|
||||
|
||||
func NewOpenFilePool(filename string) *OpenFilePool {
|
||||
@@ -29,26 +30,43 @@ func (this *OpenFilePool) Filename() string {
|
||||
}
|
||||
|
||||
func (this *OpenFilePool) Get() (*OpenFile, bool) {
|
||||
// 如果已经关闭,直接返回
|
||||
if this.isClosed {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
select {
|
||||
case file := <-this.c:
|
||||
err := file.SeekStart()
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return nil, true
|
||||
}
|
||||
file.version = this.version
|
||||
if file != nil {
|
||||
err := file.SeekStart()
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return nil, true
|
||||
}
|
||||
file.version = this.version
|
||||
|
||||
return file, true
|
||||
return file, true
|
||||
}
|
||||
return nil, false
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func (this *OpenFilePool) Put(file *OpenFile) bool {
|
||||
// 如果已关闭,则不接受新的文件
|
||||
if this.isClosed {
|
||||
_ = file.Close()
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查文件版本号
|
||||
if this.version > 0 && file.version > 0 && file.version != this.version {
|
||||
_ = file.Close()
|
||||
return false
|
||||
}
|
||||
|
||||
// 加入Pool
|
||||
select {
|
||||
case this.c <- file:
|
||||
return true
|
||||
@@ -63,14 +81,18 @@ func (this *OpenFilePool) Len() int {
|
||||
return len(this.c)
|
||||
}
|
||||
|
||||
func (this *OpenFilePool) SetClosing() {
|
||||
this.isClosed = true
|
||||
}
|
||||
|
||||
func (this *OpenFilePool) Close() {
|
||||
Loop:
|
||||
this.isClosed = true
|
||||
for {
|
||||
select {
|
||||
case file := <-this.c:
|
||||
_ = file.Close()
|
||||
default:
|
||||
break Loop
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -215,6 +215,10 @@ func (this *FileReader) ReadHeader(buf []byte, callback ReaderFunc) error {
|
||||
}
|
||||
|
||||
func (this *FileReader) ReadBody(buf []byte, callback ReaderFunc) error {
|
||||
if this.bodySize == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var isOk = false
|
||||
|
||||
defer func() {
|
||||
@@ -257,6 +261,12 @@ func (this *FileReader) ReadBody(buf []byte, callback ReaderFunc) error {
|
||||
}
|
||||
|
||||
func (this *FileReader) Read(buf []byte) (n int, err error) {
|
||||
if this.bodySize == 0 {
|
||||
n = 0
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
|
||||
n, err = this.fp.Read(buf)
|
||||
if err != nil && err != io.EOF {
|
||||
_ = this.discard()
|
||||
|
||||
@@ -710,9 +710,6 @@ func (this *FileStorage) Delete(key string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
// 先尝试内存缓存
|
||||
this.runMemoryStorageSafety(func(memoryStorage *MemoryStorage) {
|
||||
_ = memoryStorage.Delete(key)
|
||||
@@ -733,9 +730,6 @@ func (this *FileStorage) Delete(key string) error {
|
||||
|
||||
// Stat 统计
|
||||
func (this *FileStorage) Stat() (*Stat, error) {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
|
||||
return this.list.Stat(func(hash string) bool {
|
||||
return true
|
||||
})
|
||||
@@ -767,57 +761,61 @@ func (this *FileStorage) CleanAll() error {
|
||||
}
|
||||
}
|
||||
|
||||
var dirNameReg = regexp.MustCompile(`^[0-9a-f]{2}$`)
|
||||
for _, rootDir := range rootDirs {
|
||||
var dir = rootDir + "/p" + types.String(this.policy.Id)
|
||||
fp, err := os.Open(dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = fp.Close()
|
||||
}()
|
||||
err = func(dir string) error {
|
||||
fp, err := os.Open(dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = fp.Close()
|
||||
}()
|
||||
|
||||
stat, err := fp.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stat, err := fp.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !stat.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 改成待删除
|
||||
subDirs, err := fp.Readdir(-1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, info := range subDirs {
|
||||
subDir := info.Name()
|
||||
|
||||
// 检查目录名
|
||||
if !dirNameReg.MatchString(subDir) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 修改目录名
|
||||
tmpDir := dir + "/" + subDir + "-deleted"
|
||||
err = os.Rename(dir+"/"+subDir, tmpDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 重新遍历待删除
|
||||
goman.New(func() {
|
||||
err = this.cleanDeletedDirs(dir)
|
||||
if err != nil {
|
||||
remotelogs.Warn("CACHE", "delete '*-deleted' dirs failed: "+err.Error())
|
||||
}
|
||||
})
|
||||
|
||||
if !stat.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 改成待删除
|
||||
subDirs, err := fp.Readdir(-1)
|
||||
}(dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, info := range subDirs {
|
||||
subDir := info.Name()
|
||||
|
||||
// 检查目录名
|
||||
ok, err := regexp.MatchString(`^[0-9a-f]{2}$`, subDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// 修改目录名
|
||||
tmpDir := dir + "/" + subDir + "-deleted"
|
||||
err = os.Rename(dir+"/"+subDir, tmpDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 重新遍历待删除
|
||||
goman.New(func() {
|
||||
err = this.cleanDeletedDirs(dir)
|
||||
if err != nil {
|
||||
remotelogs.Warn("CACHE", "delete '*-deleted' dirs failed: "+err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -830,9 +828,6 @@ func (this *FileStorage) Purge(keys []string, urlType string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
// 先尝试内存缓存
|
||||
this.runMemoryStorageSafety(func(memoryStorage *MemoryStorage) {
|
||||
_ = memoryStorage.Purge(keys, urlType)
|
||||
@@ -1218,9 +1213,12 @@ func (this *FileStorage) hotLoop() {
|
||||
}
|
||||
|
||||
err = reader.ReadBody(buf, func(n int) (goNext bool, err error) {
|
||||
_, err = writer.Write(buf[:n])
|
||||
if err == nil {
|
||||
goNext = true
|
||||
goNext = true
|
||||
if n > 0 {
|
||||
_, err = writer.Write(buf[:n])
|
||||
if err != nil {
|
||||
goNext = false
|
||||
}
|
||||
}
|
||||
return
|
||||
})
|
||||
|
||||
@@ -14,15 +14,22 @@ import (
|
||||
)
|
||||
|
||||
func TestMemoryStorage_OpenWriter(t *testing.T) {
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
|
||||
var storage = NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
|
||||
|
||||
writer, err := storage.OpenWriter("abc", time.Now().Unix()+60, 200, -1, -1, -1, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _ = writer.WriteHeader([]byte("Header"))
|
||||
_, _ = writer.Write([]byte("Hello"))
|
||||
_, _ = writer.Write([]byte(", World"))
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(storage.valuesMap)
|
||||
|
||||
{
|
||||
@@ -30,6 +37,7 @@ func TestMemoryStorage_OpenWriter(t *testing.T) {
|
||||
if err != nil {
|
||||
if err == ErrNotFound {
|
||||
t.Log("not found: abc")
|
||||
return
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -102,13 +110,17 @@ func TestMemoryStorage_OpenReaderLock(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMemoryStorage_Delete(t *testing.T) {
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
|
||||
var storage = NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
|
||||
{
|
||||
writer, err := storage.OpenWriter("abc", time.Now().Unix()+60, 200, -1, -1, -1, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _ = writer.Write([]byte("Hello"))
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(len(storage.valuesMap))
|
||||
}
|
||||
{
|
||||
@@ -117,6 +129,10 @@ func TestMemoryStorage_Delete(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _ = writer.Write([]byte("Hello"))
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(len(storage.valuesMap))
|
||||
}
|
||||
_ = storage.Delete("abc1")
|
||||
@@ -124,7 +140,7 @@ func TestMemoryStorage_Delete(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMemoryStorage_Stat(t *testing.T) {
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
|
||||
var storage = NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
|
||||
expiredAt := time.Now().Unix() + 60
|
||||
{
|
||||
writer, err := storage.OpenWriter("abc", expiredAt, 200, -1, -1, -1, false)
|
||||
@@ -132,6 +148,10 @@ func TestMemoryStorage_Stat(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _ = writer.Write([]byte("Hello"))
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(len(storage.valuesMap))
|
||||
storage.AddToList(&Item{
|
||||
Key: "abc",
|
||||
@@ -145,6 +165,10 @@ func TestMemoryStorage_Stat(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _ = writer.Write([]byte("Hello"))
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(len(storage.valuesMap))
|
||||
storage.AddToList(&Item{
|
||||
Key: "abc1",
|
||||
@@ -161,14 +185,18 @@ func TestMemoryStorage_Stat(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMemoryStorage_CleanAll(t *testing.T) {
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
|
||||
expiredAt := time.Now().Unix() + 60
|
||||
var storage = NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
|
||||
var expiredAt = time.Now().Unix() + 60
|
||||
{
|
||||
writer, err := storage.OpenWriter("abc", expiredAt, 200, -1, -1, -1, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _ = writer.Write([]byte("Hello"))
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
storage.AddToList(&Item{
|
||||
Key: "abc",
|
||||
BodySize: 5,
|
||||
@@ -181,6 +209,10 @@ func TestMemoryStorage_CleanAll(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _ = writer.Write([]byte("Hello"))
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
storage.AddToList(&Item{
|
||||
Key: "abc1",
|
||||
BodySize: 5,
|
||||
@@ -204,6 +236,10 @@ func TestMemoryStorage_Purge(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _ = writer.Write([]byte("Hello"))
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
storage.AddToList(&Item{
|
||||
Key: "abc",
|
||||
BodySize: 5,
|
||||
@@ -216,6 +252,10 @@ func TestMemoryStorage_Purge(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _ = writer.Write([]byte("Hello"))
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
storage.AddToList(&Item{
|
||||
Key: "abc1",
|
||||
BodySize: 5,
|
||||
@@ -231,7 +271,7 @@ func TestMemoryStorage_Purge(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMemoryStorage_Expire(t *testing.T) {
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{
|
||||
var storage = NewMemoryStorage(&serverconfigs.HTTPCachePolicy{
|
||||
MemoryAutoPurgeInterval: 5,
|
||||
}, nil)
|
||||
err := storage.Init()
|
||||
@@ -247,6 +287,10 @@ func TestMemoryStorage_Expire(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _ = writer.Write([]byte("Hello"))
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
storage.AddToList(&Item{
|
||||
Key: key,
|
||||
BodySize: 5,
|
||||
@@ -257,7 +301,7 @@ func TestMemoryStorage_Expire(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMemoryStorage_Locker(t *testing.T) {
|
||||
storage := NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
|
||||
var storage = NewMemoryStorage(&serverconfigs.HTTPCachePolicy{}, nil)
|
||||
err := storage.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
7
internal/conns/linger.go
Normal file
7
internal/conns/linger.go
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package conns
|
||||
|
||||
type LingerConn interface {
|
||||
SetLinger(sec int) error
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
var SharedMap = NewMap()
|
||||
|
||||
type Map struct {
|
||||
m map[string]map[int]net.Conn // ip => { port => Conn }
|
||||
m map[string]map[int]net.Conn // ip => { port => Conn }
|
||||
|
||||
locker sync.RWMutex
|
||||
}
|
||||
@@ -37,9 +37,7 @@ func (this *Map) Add(conn net.Conn) {
|
||||
defer this.locker.Unlock()
|
||||
connMap, ok := this.m[ip]
|
||||
if !ok {
|
||||
this.m[ip] = map[int]net.Conn{
|
||||
port: conn,
|
||||
}
|
||||
this.m[ip] = map[int]net.Conn{port: conn}
|
||||
} else {
|
||||
connMap[port] = conn
|
||||
}
|
||||
@@ -96,6 +94,13 @@ func (this *Map) CloseIPConns(ip string) {
|
||||
|
||||
if ok {
|
||||
for _, conn := range conns {
|
||||
// 设置Linger
|
||||
lingerConn, isLingerConn := conn.(LingerConn)
|
||||
if isLingerConn {
|
||||
_ = lingerConn.SetLinger(0)
|
||||
}
|
||||
|
||||
// 关闭
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
@@ -109,9 +114,10 @@ func (this *Map) AllConns() []net.Conn {
|
||||
|
||||
var result = []net.Conn{}
|
||||
for _, m := range this.m {
|
||||
for _, conn := range m {
|
||||
result = append(result, conn)
|
||||
for _, connInfo := range m {
|
||||
result = append(result, connInfo)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package teaconst
|
||||
|
||||
const (
|
||||
Version = "0.5.8"
|
||||
Version = "0.6.1"
|
||||
|
||||
ProductName = "Edge Node"
|
||||
ProcessName = "edge-node"
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/conns"
|
||||
@@ -25,17 +26,30 @@ import (
|
||||
type ClientConn struct {
|
||||
BaseClientConn
|
||||
|
||||
isTLS bool
|
||||
hasDeadline bool
|
||||
hasRead bool
|
||||
createdAt int64
|
||||
|
||||
isTLS bool
|
||||
isHTTP bool
|
||||
hasRead bool
|
||||
|
||||
isLO bool // 是否为环路
|
||||
isInAllowList bool
|
||||
|
||||
hasResetSYNFlood bool
|
||||
|
||||
lastReadAt int64
|
||||
lastWriteAt int64
|
||||
lastErr error
|
||||
|
||||
readDeadlineTime int64
|
||||
isShortReading bool // reading header or tls handshake
|
||||
|
||||
isDebugging bool
|
||||
autoReadTimeout bool
|
||||
autoWriteTimeout bool
|
||||
}
|
||||
|
||||
func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList bool) net.Conn {
|
||||
func NewClientConn(rawConn net.Conn, isHTTP bool, isTLS bool, isInAllowList bool) net.Conn {
|
||||
// 是否为环路
|
||||
var remoteAddr = rawConn.RemoteAddr().String()
|
||||
var isLO = strings.HasPrefix(remoteAddr, "127.0.0.1:") || strings.HasPrefix(remoteAddr, "[::1]:")
|
||||
@@ -43,11 +57,21 @@ func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList
|
||||
var conn = &ClientConn{
|
||||
BaseClientConn: BaseClientConn{rawConn: rawConn},
|
||||
isTLS: isTLS,
|
||||
isHTTP: isHTTP,
|
||||
isLO: isLO,
|
||||
isInAllowList: isInAllowList,
|
||||
createdAt: time.Now().Unix(),
|
||||
}
|
||||
|
||||
if quickClose {
|
||||
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
|
||||
if globalServerConfig != nil {
|
||||
var performanceConfig = globalServerConfig.Performance
|
||||
conn.isDebugging = performanceConfig.Debug
|
||||
conn.autoReadTimeout = performanceConfig.AutoReadTimeout
|
||||
conn.autoWriteTimeout = performanceConfig.AutoWriteTimeout
|
||||
}
|
||||
|
||||
if isHTTP {
|
||||
// TODO 可以在配置中设置此值
|
||||
_ = conn.SetLinger(nodeconfigs.DefaultTCPLinger)
|
||||
}
|
||||
@@ -59,6 +83,16 @@ func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool, isInAllowList
|
||||
}
|
||||
|
||||
func (this *ClientConn) Read(b []byte) (n int, err error) {
|
||||
if this.isDebugging {
|
||||
this.lastReadAt = time.Now().Unix()
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
this.lastErr = errors.New("read error: " + err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 环路直接读取
|
||||
if this.isLO {
|
||||
n, err = this.rawConn.Read(b)
|
||||
@@ -68,34 +102,29 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// TLS
|
||||
if this.isTLS {
|
||||
if !this.hasDeadline {
|
||||
_ = this.rawConn.SetReadDeadline(time.Now().Add(time.Duration(nodeconfigs.DefaultTLSHandshakeTimeout) * time.Second)) // TODO 握手超时时间可以设置
|
||||
this.hasDeadline = true
|
||||
defer func() {
|
||||
_ = this.rawConn.SetReadDeadline(time.Time{})
|
||||
}()
|
||||
}
|
||||
// 设置读超时时间
|
||||
if this.isHTTP && !this.isWebsocket && !this.isShortReading && this.autoReadTimeout {
|
||||
this.setHTTPReadTimeout()
|
||||
}
|
||||
|
||||
// 开始读取
|
||||
n, err = this.rawConn.Read(b)
|
||||
if n > 0 {
|
||||
atomic.AddUint64(&teaconst.InTrafficBytes, uint64(n))
|
||||
if !this.hasRead {
|
||||
this.hasRead = true
|
||||
}
|
||||
this.hasRead = true
|
||||
}
|
||||
|
||||
// 检测是否为握手错误
|
||||
var isHandshakeError = err != nil && os.IsTimeout(err) && !this.hasRead
|
||||
if isHandshakeError {
|
||||
// 检测是否为超时错误
|
||||
var isTimeout = err != nil && os.IsTimeout(err)
|
||||
var isHandshakeError = isTimeout && !this.hasRead
|
||||
if isTimeout {
|
||||
_ = this.SetLinger(0)
|
||||
} else {
|
||||
_ = this.SetLinger(nodeconfigs.DefaultTCPLinger)
|
||||
}
|
||||
|
||||
// 忽略白名单和局域网
|
||||
if !this.isInAllowList && !utils.IsLocalIP(this.RawIP()) {
|
||||
if this.isHTTP && !this.isInAllowList && !utils.IsLocalIP(this.RawIP()) {
|
||||
// SYN Flood检测
|
||||
if this.serverId == 0 || !this.hasResetSYNFlood {
|
||||
var synFloodConfig = sharedNodeConfig.SYNFloodConfig()
|
||||
@@ -114,6 +143,32 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (this *ClientConn) Write(b []byte) (n int, err error) {
|
||||
if this.isDebugging {
|
||||
this.lastWriteAt = time.Now().Unix()
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
this.lastErr = errors.New("write error: " + err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 设置写超时时间
|
||||
if this.autoWriteTimeout {
|
||||
// TODO L2 -> L1 写入时不限制时间
|
||||
var timeoutSeconds = len(b) / 1024
|
||||
if timeoutSeconds < 3 {
|
||||
timeoutSeconds = 3
|
||||
}
|
||||
_ = this.rawConn.SetWriteDeadline(time.Now().Add(time.Duration(timeoutSeconds) * time.Second)) // TODO 时间可以设置
|
||||
}
|
||||
|
||||
// 延长读超时时间
|
||||
if this.isHTTP && !this.isWebsocket && this.autoReadTimeout {
|
||||
this.setHTTPReadTimeout()
|
||||
}
|
||||
|
||||
// 开始写入
|
||||
n, err = this.rawConn.Write(b)
|
||||
if n > 0 {
|
||||
// 统计当前服务带宽
|
||||
@@ -125,6 +180,17 @@ func (this *ClientConn) Write(b []byte) (n int, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// 如果是写入超时,则立即关闭连接
|
||||
if err != nil && os.IsTimeout(err) {
|
||||
// TODO 考虑对多次慢连接的IP做出惩罚
|
||||
conn, ok := this.rawConn.(LingerConn)
|
||||
if ok {
|
||||
_ = conn.SetLinger(0)
|
||||
}
|
||||
|
||||
_ = this.Close()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -156,6 +222,26 @@ func (this *ClientConn) SetDeadline(t time.Time) error {
|
||||
}
|
||||
|
||||
func (this *ClientConn) SetReadDeadline(t time.Time) error {
|
||||
// 如果开启了HTTP自动读超时选项,则自动控制超时时间
|
||||
if this.isHTTP && !this.isWebsocket && this.autoReadTimeout {
|
||||
this.isShortReading = false
|
||||
|
||||
var unixTime = t.Unix()
|
||||
if unixTime < 10 {
|
||||
return nil
|
||||
}
|
||||
if unixTime == this.readDeadlineTime {
|
||||
return nil
|
||||
}
|
||||
this.readDeadlineTime = unixTime
|
||||
var seconds = -time.Since(t)
|
||||
if seconds <= 0 || seconds > HTTPIdleTimeout {
|
||||
return nil
|
||||
}
|
||||
if seconds < HTTPIdleTimeout-1*time.Second {
|
||||
this.isShortReading = true
|
||||
}
|
||||
}
|
||||
return this.rawConn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
@@ -163,6 +249,22 @@ func (this *ClientConn) SetWriteDeadline(t time.Time) error {
|
||||
return this.rawConn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (this *ClientConn) CreatedAt() int64 {
|
||||
return this.createdAt
|
||||
}
|
||||
|
||||
func (this *ClientConn) LastReadAt() int64 {
|
||||
return this.lastReadAt
|
||||
}
|
||||
|
||||
func (this *ClientConn) LastWriteAt() int64 {
|
||||
return this.lastWriteAt
|
||||
}
|
||||
|
||||
func (this *ClientConn) LastErr() error {
|
||||
return this.lastErr
|
||||
}
|
||||
|
||||
func (this *ClientConn) resetSYNFlood() {
|
||||
ttlcache.SharedCache.Delete("SYN_FLOOD:" + this.RawIP())
|
||||
}
|
||||
@@ -194,3 +296,8 @@ func (this *ClientConn) increaseSYNFlood(synFloodConfig *firewallconfigs.SYNFloo
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 设置读超时时间
|
||||
func (this *ClientConn) setHTTPReadTimeout() {
|
||||
_ = this.SetReadDeadline(time.Now().Add(HTTPIdleTimeout))
|
||||
}
|
||||
|
||||
@@ -16,6 +16,8 @@ type BaseClientConn struct {
|
||||
remoteAddr string
|
||||
hasLimit bool
|
||||
|
||||
isWebsocket bool
|
||||
|
||||
isClosed bool
|
||||
|
||||
rawIP string
|
||||
@@ -122,3 +124,7 @@ func (this *BaseClientConn) SetLinger(seconds int) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *BaseClientConn) SetIsWebsocket(isWebsocket bool) {
|
||||
this.isWebsocket = isWebsocket
|
||||
}
|
||||
|
||||
@@ -23,4 +23,7 @@ type ClientConnInterface interface {
|
||||
|
||||
// UserId 获取当前连接所属服务的用户ID
|
||||
UserId() int64
|
||||
|
||||
// SetIsWebsocket 设置是否为Websocket
|
||||
SetIsWebsocket(isWebsocket bool)
|
||||
}
|
||||
|
||||
@@ -14,14 +14,14 @@ import (
|
||||
// ClientListener 客户端网络监听
|
||||
type ClientListener struct {
|
||||
rawListener net.Listener
|
||||
isHTTP bool
|
||||
isTLS bool
|
||||
quickClose bool
|
||||
}
|
||||
|
||||
func NewClientListener(listener net.Listener, quickClose bool) *ClientListener {
|
||||
func NewClientListener(listener net.Listener, isHTTP bool) *ClientListener {
|
||||
return &ClientListener{
|
||||
rawListener: listener,
|
||||
quickClose: quickClose,
|
||||
isHTTP: isHTTP,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ func (this *ClientListener) Accept() (net.Conn, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return NewClientConn(conn, this.isTLS, this.quickClose, isInAllowList), nil
|
||||
return NewClientConn(conn, this.isHTTP, this.isTLS, isInAllowList), nil
|
||||
}
|
||||
|
||||
func (this *ClientListener) Close() error {
|
||||
|
||||
@@ -95,11 +95,11 @@ func (this *HTTPClientPool) Client(req *HTTPRequest,
|
||||
numberCPU = 8
|
||||
}
|
||||
if maxConnections <= 0 {
|
||||
maxConnections = numberCPU * 32
|
||||
maxConnections = numberCPU * 64
|
||||
}
|
||||
|
||||
if idleConns <= 0 {
|
||||
idleConns = numberCPU * 8
|
||||
idleConns = numberCPU * 16
|
||||
}
|
||||
|
||||
// 可以判断为Ln节点请求
|
||||
|
||||
@@ -237,6 +237,14 @@ func (this *HTTPRequest) Do() {
|
||||
}
|
||||
}
|
||||
|
||||
// UA名单
|
||||
if !this.isSubRequest && this.web.UserAgent != nil && this.web.UserAgent.IsOn {
|
||||
if this.doCheckUserAgent() {
|
||||
this.doEnd()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 访问控制
|
||||
if !this.isSubRequest && this.web.Auth != nil && this.web.Auth.IsOn {
|
||||
if this.doAuth() {
|
||||
@@ -526,6 +534,11 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
|
||||
this.web.Referers = web.Referers
|
||||
}
|
||||
|
||||
// user agent
|
||||
if web.UserAgent != nil && (web.UserAgent.IsPrior || isTop) {
|
||||
this.web.UserAgent = web.UserAgent
|
||||
}
|
||||
|
||||
// request limit
|
||||
if web.RequestLimit != nil && (web.RequestLimit.IsPrior || isTop) {
|
||||
this.web.RequestLimit = web.RequestLimit
|
||||
@@ -1133,6 +1146,8 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string {
|
||||
|
||||
// 获取请求的客户端地址列表
|
||||
func (this *HTTPRequest) requestRemoteAddrs() (result []string) {
|
||||
result = append(result, this.requestRemoteAddr(true))
|
||||
|
||||
// X-Forwarded-For
|
||||
var forwardedFor = this.RawReq.Header.Get("X-Forwarded-For")
|
||||
if len(forwardedFor) > 0 {
|
||||
@@ -1554,7 +1569,7 @@ func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) {
|
||||
}
|
||||
|
||||
// 是否已删除
|
||||
if this.web.ResponseHeaderPolicy.ContainsDeletedHeader(header.Name) {
|
||||
if this.web.RequestHeaderPolicy.ContainsDeletedHeader(header.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1692,6 +1707,36 @@ func (this *HTTPRequest) processResponseHeaders(responseHeader http.Header, stat
|
||||
responseHeader[header.Name] = []string{headerValue}
|
||||
}
|
||||
}
|
||||
|
||||
// CORS
|
||||
if this.web.ResponseHeaderPolicy.CORS != nil && this.web.ResponseHeaderPolicy.CORS.IsOn {
|
||||
var corsConfig = this.web.ResponseHeaderPolicy.CORS
|
||||
|
||||
// Allow-Origin
|
||||
if len(corsConfig.AllowOrigin) == 0 {
|
||||
var origin = this.RawReq.Header.Get("Origin")
|
||||
if len(origin) > 0 {
|
||||
responseHeader.Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
} else {
|
||||
responseHeader.Set("Access-Control-Allow-Origin", corsConfig.AllowOrigin)
|
||||
}
|
||||
|
||||
// Allow-Methods
|
||||
if len(corsConfig.AllowMethods) == 0 {
|
||||
responseHeader.Set("Access-Control-Allow-Methods", "PUT, GET, POST, DELETE, HEAD, OPTIONS")
|
||||
} else {
|
||||
responseHeader.Set("Access-Control-Allow-Methods", strings.Join(corsConfig.AllowMethods, ", "))
|
||||
}
|
||||
|
||||
// Max-Age
|
||||
if corsConfig.MaxAge > 0 {
|
||||
responseHeader.Set("Access-Control-Max-Age", types.String(corsConfig.MaxAge))
|
||||
}
|
||||
|
||||
// Allow-Credentials
|
||||
responseHeader.Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
}
|
||||
|
||||
// HSTS
|
||||
|
||||
@@ -146,6 +146,13 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
|
||||
u.Status = http.StatusTemporaryRedirect
|
||||
}
|
||||
this.processResponseHeaders(this.writer.Header(), u.Status)
|
||||
|
||||
// 参数
|
||||
var qIndex = strings.Index(this.uri, "?")
|
||||
if qIndex >= 0 {
|
||||
afterURL += this.uri[qIndex:]
|
||||
}
|
||||
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -12,5 +12,5 @@ func (this *HTTPRequest) doStat() {
|
||||
|
||||
// 内置的统计
|
||||
stats.SharedHTTPRequestStatManager.AddRemoteAddr(this.ReqServer.Id, this.requestRemoteAddr(true), this.writer.SentBodyBytes(), this.isAttack)
|
||||
stats.SharedHTTPRequestStatManager.AddUserAgent(this.ReqServer.Id, this.requestHeader("User-Agent"))
|
||||
stats.SharedHTTPRequestStatManager.AddUserAgent(this.ReqServer.Id, this.requestHeader("User-Agent"), this.remoteAddr)
|
||||
}
|
||||
|
||||
24
internal/nodes/http_request_user_agent.go
Normal file
24
internal/nodes/http_request_user_agent.go
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (this *HTTPRequest) doCheckUserAgent() (shouldStop bool) {
|
||||
if this.web.UserAgent == nil {
|
||||
return
|
||||
}
|
||||
|
||||
const cacheSeconds = "3600" // 时间不能过长,防止修改设置后长期无法生效
|
||||
|
||||
if !this.web.UserAgent.AllowRequest(this.RawReq) {
|
||||
this.tags = append(this.tags, "userAgentCheck")
|
||||
this.writer.Header().Set("Cache-Control", "max-age="+cacheSeconds)
|
||||
this.writeCode(http.StatusForbidden, "The User-Agent has been blocked.", "当前访问已被UA名单拦截。")
|
||||
return true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
@@ -70,6 +70,13 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
|
||||
this.RawReq.Header.Set("Origin", newRequestOrigin)
|
||||
}
|
||||
|
||||
// 获取当前连接
|
||||
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
|
||||
if requestConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 连接源站
|
||||
// TODO 增加N次错误重试,重试的时候需要尝试不同的源站
|
||||
originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost)
|
||||
if err != nil {
|
||||
@@ -102,6 +109,11 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
|
||||
return
|
||||
}
|
||||
|
||||
requestClientConn, ok := requestConn.(ClientConnInterface)
|
||||
if ok {
|
||||
requestClientConn.SetIsWebsocket(true)
|
||||
}
|
||||
|
||||
clientConn, _, err := this.writer.Hijack()
|
||||
if err != nil || clientConn == nil {
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to get origin site connection", "获取源站连接失败", false)
|
||||
|
||||
@@ -132,7 +132,7 @@ func (this *HTTPWriter) Prepare(resp *http.Response, size int64, status int, ena
|
||||
this.req.web.RequestLimit != nil &&
|
||||
this.req.web.RequestLimit.IsOn &&
|
||||
this.req.web.RequestLimit.OutBandwidthPerConnBytes() > 0 {
|
||||
this.writer = writers.NewRateLimitWriter(this.writer, this.req.web.RequestLimit.OutBandwidthPerConnBytes())
|
||||
this.writer = writers.NewRateLimitWriter(this.req.RawReq.Context(), this.writer, this.req.web.RequestLimit.OutBandwidthPerConnBytes())
|
||||
}
|
||||
|
||||
return
|
||||
@@ -584,6 +584,11 @@ func (this *HTTPWriter) PrepareCompression(resp *http.Response, size int64) {
|
||||
return
|
||||
}
|
||||
|
||||
// 分区内容不压缩,防止读取失败
|
||||
if !this.compressionConfig.EnablePartialContent && this.StatusCode() == http.StatusPartialContent {
|
||||
return
|
||||
}
|
||||
|
||||
if this.compressionConfig.Level <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
|
||||
return &tls.Config{
|
||||
Certificates: nil,
|
||||
GetConfigForClient: func(clientInfo *tls.ClientHelloInfo) (config *tls.Config, e error) {
|
||||
tlsPolicy, _, err := this.matchSSL(clientInfo.ServerName)
|
||||
tlsPolicy, _, err := this.matchSSL(this.helloServerName(clientInfo))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -50,7 +50,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
|
||||
return tlsPolicy.TLSConfig(), nil
|
||||
},
|
||||
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
|
||||
tlsPolicy, cert, err := this.matchSSL(clientInfo.ServerName)
|
||||
tlsPolicy, cert, err := this.matchSSL(this.helloServerName(clientInfo))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -182,3 +182,18 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser
|
||||
|
||||
return nil, name
|
||||
}
|
||||
|
||||
// 从Hello信息中获取服务名称
|
||||
func (this *BaseListener) helloServerName(clientInfo *tls.ClientHelloInfo) string {
|
||||
var serverName = clientInfo.ServerName
|
||||
if len(serverName) == 0 {
|
||||
var localAddr = clientInfo.Conn.LocalAddr()
|
||||
if localAddr != nil {
|
||||
tcpAddr, ok := localAddr.(*net.TCPAddr)
|
||||
if ok {
|
||||
serverName = tcpAddr.IP.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
return serverName
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ func TestBaseListener_FindServer(t *testing.T) {
|
||||
sharedNodeConfig = &nodeconfigs.NodeConfig{}
|
||||
|
||||
var listener = &BaseListener{}
|
||||
listener.Group = &serverconfigs.ServerAddressGroup{}
|
||||
listener.Group = serverconfigs.NewServerAddressGroup("https://*:443")
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
var server = &serverconfigs.ServerConfig{
|
||||
IsOn: true,
|
||||
|
||||
@@ -18,6 +18,8 @@ import (
|
||||
|
||||
var httpErrorLogger = log.New(io.Discard, "", 0)
|
||||
|
||||
const HTTPIdleTimeout = 75 * time.Second
|
||||
|
||||
type contextKey struct {
|
||||
key string
|
||||
}
|
||||
@@ -43,16 +45,12 @@ func (this *HTTPListener) Serve() error {
|
||||
this.httpServer = &http.Server{
|
||||
Addr: this.addr,
|
||||
Handler: this,
|
||||
ReadTimeout: 1 * time.Hour, // TODO 改成可以配置
|
||||
ReadHeaderTimeout: 3 * time.Second, // TODO 改成可以配置
|
||||
WriteTimeout: 2 * time.Hour, // TODO 改成可以配置
|
||||
IdleTimeout: 75 * time.Second, // TODO 改成可以配置
|
||||
ReadHeaderTimeout: 3 * time.Second, // TODO 改成可以配置
|
||||
IdleTimeout: HTTPIdleTimeout, // TODO 改成可以配置
|
||||
ConnState: func(conn net.Conn, state http.ConnState) {
|
||||
switch state {
|
||||
case http.StateNew:
|
||||
atomic.AddInt64(&this.countActiveConnections, 1)
|
||||
case http.StateActive, http.StateIdle, http.StateHijacked:
|
||||
// Nothing to do
|
||||
case http.StateClosed:
|
||||
atomic.AddInt64(&this.countActiveConnections, -1)
|
||||
}
|
||||
@@ -116,8 +114,14 @@ func (this *HTTPListener) Reload(group *serverconfigs.ServerAddressGroup) {
|
||||
|
||||
// ServerHTTP 处理HTTP请求
|
||||
func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.Request) {
|
||||
// 不支持Connect
|
||||
if rawReq.Method == http.MethodConnect {
|
||||
http.Error(rawWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 域名
|
||||
var reqHost = strings.TrimRight(rawReq.Host, ".")
|
||||
var reqHost = strings.ToLower(strings.TrimRight(rawReq.Host, "."))
|
||||
|
||||
// TLS域名
|
||||
if this.isIP(reqHost) {
|
||||
|
||||
@@ -25,7 +25,8 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/trackers"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
_ "github.com/TeaOSLab/EdgeNode/internal/utils/clock" // 触发时钟更新
|
||||
_ "github.com/TeaOSLab/EdgeNode/internal/utils/agents" // 引入Agent管理器
|
||||
_ "github.com/TeaOSLab/EdgeNode/internal/utils/clock" // 触发时钟更新
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/jsonutils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/andybalholm/brotli"
|
||||
@@ -879,16 +880,56 @@ func (this *Node) listenSock() error {
|
||||
},
|
||||
})
|
||||
case "conns":
|
||||
var addrs = []string{}
|
||||
var connMaps = []maps.Map{}
|
||||
var connMap = conns.SharedMap.AllConns()
|
||||
for _, conn := range connMap {
|
||||
addrs = append(addrs, conn.RemoteAddr().String())
|
||||
var createdAt int64
|
||||
var lastReadAt int64
|
||||
var lastWriteAt int64
|
||||
var lastErrString = ""
|
||||
clientConn, ok := conn.(*ClientConn)
|
||||
if ok {
|
||||
createdAt = clientConn.CreatedAt()
|
||||
lastReadAt = clientConn.LastReadAt()
|
||||
lastWriteAt = clientConn.LastWriteAt()
|
||||
|
||||
var lastErr = clientConn.LastErr()
|
||||
if lastErr != nil {
|
||||
lastErrString = lastErr.Error()
|
||||
}
|
||||
}
|
||||
var age int64 = -1
|
||||
var lastReadAge int64 = -1
|
||||
var lastWriteAge int64 = -1
|
||||
var currentTime = time.Now().Unix()
|
||||
if createdAt > 0 {
|
||||
age = currentTime - createdAt
|
||||
}
|
||||
if lastReadAt > 0 {
|
||||
lastReadAge = currentTime - lastReadAt
|
||||
}
|
||||
if lastWriteAt > 0 {
|
||||
lastWriteAge = currentTime - lastWriteAt
|
||||
}
|
||||
|
||||
connMaps = append(connMaps, maps.Map{
|
||||
"addr": conn.RemoteAddr().String(),
|
||||
"age": age,
|
||||
"readAge": lastReadAge,
|
||||
"writeAge": lastWriteAge,
|
||||
"lastErr": lastErrString,
|
||||
})
|
||||
}
|
||||
sort.Slice(connMaps, func(i, j int) bool {
|
||||
var m1 = connMaps[i]
|
||||
var m2 = connMaps[j]
|
||||
return m1.GetInt64("age") < m2.GetInt64("age")
|
||||
})
|
||||
|
||||
_ = cmd.Reply(&gosock.Command{
|
||||
Params: map[string]interface{}{
|
||||
"addrs": addrs,
|
||||
"total": len(addrs),
|
||||
"conns": connMaps,
|
||||
"total": len(connMaps),
|
||||
},
|
||||
})
|
||||
case "dropIP":
|
||||
@@ -920,6 +961,11 @@ func (this *Node) listenSock() error {
|
||||
} else {
|
||||
_ = cmd.ReplyOk()
|
||||
}
|
||||
case "closeIP":
|
||||
var m = maps.NewMap(cmd.Params)
|
||||
var ip = m.GetString("ip")
|
||||
conns.SharedMap.CloseIPConns(ip)
|
||||
_ = cmd.ReplyOk()
|
||||
case "removeIP":
|
||||
var m = maps.NewMap(cmd.Params)
|
||||
var ip = m.GetString("ip")
|
||||
@@ -987,50 +1033,48 @@ func (this *Node) onReload(config *nodeconfigs.NodeConfig, reloadAll bool) {
|
||||
nodeconfigs.ResetNodeConfig(config)
|
||||
sharedNodeConfig = config
|
||||
|
||||
// 不需要每次都全部重新加载
|
||||
if !reloadAll {
|
||||
return
|
||||
}
|
||||
|
||||
// 缓存策略
|
||||
var subDirs = config.CacheDiskSubDirs
|
||||
for _, subDir := range subDirs {
|
||||
subDir.Path = filepath.Clean(subDir.Path)
|
||||
}
|
||||
if len(subDirs) > 0 {
|
||||
sort.Slice(subDirs, func(i, j int) bool {
|
||||
return subDirs[i].Path < subDirs[j].Path
|
||||
})
|
||||
}
|
||||
|
||||
var cachePoliciesChanged = !jsonutils.Equal(caches.SharedManager.MaxDiskCapacity, config.MaxCacheDiskCapacity) ||
|
||||
!jsonutils.Equal(caches.SharedManager.MaxMemoryCapacity, config.MaxCacheMemoryCapacity) ||
|
||||
!jsonutils.Equal(caches.SharedManager.MainDiskDir, config.CacheDiskDir) ||
|
||||
!jsonutils.Equal(caches.SharedManager.SubDiskDirs, subDirs) ||
|
||||
!jsonutils.Equal(this.oldHTTPCachePolicies, config.HTTPCachePolicies)
|
||||
|
||||
caches.SharedManager.MaxDiskCapacity = config.MaxCacheDiskCapacity
|
||||
caches.SharedManager.MaxMemoryCapacity = config.MaxCacheMemoryCapacity
|
||||
caches.SharedManager.MainDiskDir = config.CacheDiskDir
|
||||
caches.SharedManager.SubDiskDirs = subDirs
|
||||
|
||||
if cachePoliciesChanged {
|
||||
// copy
|
||||
this.oldHTTPCachePolicies = []*serverconfigs.HTTPCachePolicy{}
|
||||
err := jsonutils.Copy(&this.oldHTTPCachePolicies, config.HTTPCachePolicies)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "onReload: copy HTTPCachePolicies failed: "+err.Error())
|
||||
if reloadAll {
|
||||
// 缓存策略
|
||||
var subDirs = config.CacheDiskSubDirs
|
||||
for _, subDir := range subDirs {
|
||||
subDir.Path = filepath.Clean(subDir.Path)
|
||||
}
|
||||
if len(subDirs) > 0 {
|
||||
sort.Slice(subDirs, func(i, j int) bool {
|
||||
return subDirs[i].Path < subDirs[j].Path
|
||||
})
|
||||
}
|
||||
|
||||
// update
|
||||
if len(config.HTTPCachePolicies) > 0 {
|
||||
caches.SharedManager.UpdatePolicies(config.HTTPCachePolicies)
|
||||
} else {
|
||||
caches.SharedManager.UpdatePolicies([]*serverconfigs.HTTPCachePolicy{})
|
||||
var cachePoliciesChanged = !jsonutils.Equal(caches.SharedManager.MaxDiskCapacity, config.MaxCacheDiskCapacity) ||
|
||||
!jsonutils.Equal(caches.SharedManager.MaxMemoryCapacity, config.MaxCacheMemoryCapacity) ||
|
||||
!jsonutils.Equal(caches.SharedManager.MainDiskDir, config.CacheDiskDir) ||
|
||||
!jsonutils.Equal(caches.SharedManager.SubDiskDirs, subDirs) ||
|
||||
!jsonutils.Equal(this.oldHTTPCachePolicies, config.HTTPCachePolicies)
|
||||
|
||||
caches.SharedManager.MaxDiskCapacity = config.MaxCacheDiskCapacity
|
||||
caches.SharedManager.MaxMemoryCapacity = config.MaxCacheMemoryCapacity
|
||||
caches.SharedManager.MainDiskDir = config.CacheDiskDir
|
||||
caches.SharedManager.SubDiskDirs = subDirs
|
||||
|
||||
if cachePoliciesChanged {
|
||||
// copy
|
||||
this.oldHTTPCachePolicies = []*serverconfigs.HTTPCachePolicy{}
|
||||
err := jsonutils.Copy(&this.oldHTTPCachePolicies, config.HTTPCachePolicies)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "onReload: copy HTTPCachePolicies failed: "+err.Error())
|
||||
}
|
||||
|
||||
// update
|
||||
if len(config.HTTPCachePolicies) > 0 {
|
||||
caches.SharedManager.UpdatePolicies(config.HTTPCachePolicies)
|
||||
} else {
|
||||
caches.SharedManager.UpdatePolicies([]*serverconfigs.HTTPCachePolicy{})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WAF策略
|
||||
// 包含了服务里的WAF策略,所以需要整体更新
|
||||
var allFirewallPolicies = config.FindAllFirewallPolicies()
|
||||
if !jsonutils.Equal(allFirewallPolicies, this.oldHTTPFirewallPolicies) {
|
||||
// copy
|
||||
@@ -1044,105 +1088,107 @@ func (this *Node) onReload(config *nodeconfigs.NodeConfig, reloadAll bool) {
|
||||
waf.SharedWAFManager.UpdatePolicies(allFirewallPolicies)
|
||||
}
|
||||
|
||||
if !jsonutils.Equal(config.FirewallActions, this.oldFirewallActions) {
|
||||
// copy
|
||||
this.oldFirewallActions = []*firewallconfigs.FirewallActionConfig{}
|
||||
err := jsonutils.Copy(&this.oldFirewallActions, config.FirewallActions)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "onReload: copy FirewallActionConfigs failed: "+err.Error())
|
||||
if reloadAll {
|
||||
if !jsonutils.Equal(config.FirewallActions, this.oldFirewallActions) {
|
||||
// copy
|
||||
this.oldFirewallActions = []*firewallconfigs.FirewallActionConfig{}
|
||||
err := jsonutils.Copy(&this.oldFirewallActions, config.FirewallActions)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "onReload: copy FirewallActionConfigs failed: "+err.Error())
|
||||
}
|
||||
|
||||
// update
|
||||
iplibrary.SharedActionManager.UpdateActions(config.FirewallActions)
|
||||
}
|
||||
|
||||
// update
|
||||
iplibrary.SharedActionManager.UpdateActions(config.FirewallActions)
|
||||
}
|
||||
// 统计指标
|
||||
if !jsonutils.Equal(this.oldMetricItems, config.MetricItems) {
|
||||
// copy
|
||||
this.oldMetricItems = []*serverconfigs.MetricItemConfig{}
|
||||
err := jsonutils.Copy(&this.oldMetricItems, config.MetricItems)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "onReload: copy MetricItemConfigs failed: "+err.Error())
|
||||
}
|
||||
|
||||
// 统计指标
|
||||
if !jsonutils.Equal(this.oldMetricItems, config.MetricItems) {
|
||||
// copy
|
||||
this.oldMetricItems = []*serverconfigs.MetricItemConfig{}
|
||||
err := jsonutils.Copy(&this.oldMetricItems, config.MetricItems)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "onReload: copy MetricItemConfigs failed: "+err.Error())
|
||||
// update
|
||||
metrics.SharedManager.Update(config.MetricItems)
|
||||
}
|
||||
|
||||
// update
|
||||
metrics.SharedManager.Update(config.MetricItems)
|
||||
}
|
||||
// max cpu
|
||||
if config.MaxCPU != this.oldMaxCPU {
|
||||
if config.MaxCPU > 0 && config.MaxCPU < int32(runtime.NumCPU()) {
|
||||
runtime.GOMAXPROCS(int(config.MaxCPU))
|
||||
remotelogs.Println("NODE", "[CPU]set max cpu to '"+types.String(config.MaxCPU)+"'")
|
||||
} else {
|
||||
var threads = runtime.NumCPU() * 4
|
||||
runtime.GOMAXPROCS(threads)
|
||||
remotelogs.Println("NODE", "[CPU]set max cpu to '"+types.String(threads)+"'")
|
||||
}
|
||||
|
||||
// max cpu
|
||||
if config.MaxCPU != this.oldMaxCPU {
|
||||
if config.MaxCPU > 0 && config.MaxCPU < int32(runtime.NumCPU()) {
|
||||
runtime.GOMAXPROCS(int(config.MaxCPU))
|
||||
remotelogs.Println("NODE", "[CPU]set max cpu to '"+types.String(config.MaxCPU)+"'")
|
||||
this.oldMaxCPU = config.MaxCPU
|
||||
}
|
||||
|
||||
// max threads
|
||||
if config.MaxThreads != this.oldMaxThreads {
|
||||
if config.MaxThreads > 0 {
|
||||
debug.SetMaxThreads(config.MaxThreads)
|
||||
remotelogs.Println("NODE", "[THREADS]set max threads to '"+types.String(config.MaxThreads)+"'")
|
||||
} else {
|
||||
debug.SetMaxThreads(nodeconfigs.DefaultMaxThreads)
|
||||
remotelogs.Println("NODE", "[THREADS]set max threads to '"+types.String(nodeconfigs.DefaultMaxThreads)+"'")
|
||||
}
|
||||
this.oldMaxThreads = config.MaxThreads
|
||||
}
|
||||
|
||||
// timezone
|
||||
var timeZone = config.TimeZone
|
||||
if len(timeZone) == 0 {
|
||||
timeZone = "Asia/Shanghai"
|
||||
}
|
||||
|
||||
if this.oldTimezone != timeZone {
|
||||
location, err := time.LoadLocation(timeZone)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "[TIMEZONE]change time zone failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
remotelogs.Println("NODE", "[TIMEZONE]change time zone to '"+timeZone+"'")
|
||||
time.Local = location
|
||||
this.oldTimezone = timeZone
|
||||
}
|
||||
|
||||
// product information
|
||||
if config.ProductConfig != nil {
|
||||
teaconst.GlobalProductName = config.ProductConfig.Name
|
||||
}
|
||||
|
||||
// DNS resolver
|
||||
if config.DNSResolver != nil {
|
||||
var err error
|
||||
switch config.DNSResolver.Type {
|
||||
case nodeconfigs.DNSResolverTypeGoNative:
|
||||
err = os.Setenv("GODEBUG", "netdns=go")
|
||||
case nodeconfigs.DNSResolverTypeCGO:
|
||||
err = os.Setenv("GODEBUG", "netdns=cgo")
|
||||
default:
|
||||
// 默认使用go原生
|
||||
err = os.Setenv("GODEBUG", "netdns=go")
|
||||
}
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "[DNS_RESOLVER]set env failed: "+err.Error())
|
||||
}
|
||||
} else {
|
||||
var threads = runtime.NumCPU() * 4
|
||||
runtime.GOMAXPROCS(threads)
|
||||
remotelogs.Println("NODE", "[CPU]set max cpu to '"+types.String(threads)+"'")
|
||||
}
|
||||
|
||||
this.oldMaxCPU = config.MaxCPU
|
||||
}
|
||||
|
||||
// max threads
|
||||
if config.MaxThreads != this.oldMaxThreads {
|
||||
if config.MaxThreads > 0 {
|
||||
debug.SetMaxThreads(config.MaxThreads)
|
||||
remotelogs.Println("NODE", "[THREADS]set max threads to '"+types.String(config.MaxThreads)+"'")
|
||||
} else {
|
||||
debug.SetMaxThreads(nodeconfigs.DefaultMaxThreads)
|
||||
remotelogs.Println("NODE", "[THREADS]set max threads to '"+types.String(nodeconfigs.DefaultMaxThreads)+"'")
|
||||
}
|
||||
this.oldMaxThreads = config.MaxThreads
|
||||
}
|
||||
|
||||
// timezone
|
||||
var timeZone = config.TimeZone
|
||||
if len(timeZone) == 0 {
|
||||
timeZone = "Asia/Shanghai"
|
||||
}
|
||||
|
||||
if this.oldTimezone != timeZone {
|
||||
location, err := time.LoadLocation(timeZone)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "[TIMEZONE]change time zone failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
remotelogs.Println("NODE", "[TIMEZONE]change time zone to '"+timeZone+"'")
|
||||
time.Local = location
|
||||
this.oldTimezone = timeZone
|
||||
}
|
||||
|
||||
// product information
|
||||
if config.ProductConfig != nil {
|
||||
teaconst.GlobalProductName = config.ProductConfig.Name
|
||||
}
|
||||
|
||||
// DNS resolver
|
||||
if config.DNSResolver != nil {
|
||||
var err error
|
||||
switch config.DNSResolver.Type {
|
||||
case nodeconfigs.DNSResolverTypeGoNative:
|
||||
err = os.Setenv("GODEBUG", "netdns=go")
|
||||
case nodeconfigs.DNSResolverTypeCGO:
|
||||
err = os.Setenv("GODEBUG", "netdns=cgo")
|
||||
default:
|
||||
// 默认使用go原生
|
||||
err = os.Setenv("GODEBUG", "netdns=go")
|
||||
err := os.Setenv("GODEBUG", "netdns=go")
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "[DNS_RESOLVER]set env failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "[DNS_RESOLVER]set env failed: "+err.Error())
|
||||
}
|
||||
} else {
|
||||
// 默认使用go原生
|
||||
err := os.Setenv("GODEBUG", "netdns=go")
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "[DNS_RESOLVER]set env failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// API Node地址,这里不限制是否为空,因为在为空时仍然要有对应的处理
|
||||
this.changeAPINodeAddrs(config.APINodeAddrs)
|
||||
// API Node地址,这里不限制是否为空,因为在为空时仍然要有对应的处理
|
||||
this.changeAPINodeAddrs(config.APINodeAddrs)
|
||||
}
|
||||
}
|
||||
|
||||
// reload server config
|
||||
|
||||
@@ -57,6 +57,79 @@ func TestRegexp_ParseKeywords(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegexp_Special(t *testing.T) {
|
||||
var unescape = func(v string) string {
|
||||
//replace urlencoded characters
|
||||
|
||||
var chars = [][2]string{
|
||||
{`\s`, `(\s|%09|%0A|\+)`},
|
||||
{`\(`, `(\(|%28)`},
|
||||
{`=`, `(=|%3D)`},
|
||||
{`<`, `(<|%3C)`},
|
||||
{`\*`, `(\*|%2A)`},
|
||||
{`\\`, `(\\|%2F)`},
|
||||
{`!`, `(!|%21)`},
|
||||
{`/`, `(/|%2F)`},
|
||||
{`;`, `(;|%3B)`},
|
||||
{`\+`, `(\+|%20)`},
|
||||
}
|
||||
|
||||
for _, c := range chars {
|
||||
if !strings.Contains(v, c[0]) {
|
||||
continue
|
||||
}
|
||||
var pieces = strings.Split(v, c[0])
|
||||
|
||||
// 修复piece中错误的\
|
||||
for pieceIndex, piece := range pieces {
|
||||
var l = len(piece)
|
||||
if l == 0 {
|
||||
continue
|
||||
}
|
||||
if piece[l-1] != '\\' {
|
||||
continue
|
||||
}
|
||||
|
||||
// 计算\的数量
|
||||
var countBackSlashes = 0
|
||||
for i := l - 1; i >= 0; i-- {
|
||||
if piece[i] == '\\' {
|
||||
countBackSlashes++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
if countBackSlashes%2 == 1 {
|
||||
// 去掉最后一个
|
||||
pieces[pieceIndex] = piece[:len(piece)-1]
|
||||
}
|
||||
}
|
||||
|
||||
v = strings.Join(pieces, c[1])
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
for _, s := range []string{
|
||||
`\\s`,
|
||||
`\s\W`,
|
||||
`aaaa/\W`,
|
||||
`aaaa\/\W`,
|
||||
`aaaa\=\W`,
|
||||
`aaaa\\=\W`,
|
||||
`aaaa\\\=\W`,
|
||||
`aaaa\\\\=\W`,
|
||||
} {
|
||||
var es = unescape(s)
|
||||
t.Log(s, "=>", es)
|
||||
_, err := re.Compile(es)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegexp_ParseKeywords2(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ type RPCClient struct {
|
||||
SSLCertRPC pb.SSLCertServiceClient
|
||||
ScriptRPC pb.ScriptServiceClient
|
||||
UserRPC pb.UserServiceClient
|
||||
ClientAgentIPRPC pb.ClientAgentIPServiceClient
|
||||
}
|
||||
|
||||
func NewRPCClient(apiConfig *configs.APIConfig) (*RPCClient, error) {
|
||||
@@ -83,6 +84,7 @@ func NewRPCClient(apiConfig *configs.APIConfig) (*RPCClient, error) {
|
||||
client.SSLCertRPC = pb.NewSSLCertServiceClient(client)
|
||||
client.ScriptRPC = pb.NewScriptServiceClient(client)
|
||||
client.UserRPC = pb.NewUserServiceClient(client)
|
||||
client.ClientAgentIPRPC = pb.NewClientAgentIPServiceClient(client)
|
||||
|
||||
err := client.init()
|
||||
if err != nil {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/trackers"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/agents"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
@@ -146,11 +147,16 @@ func (this *HTTPRequestStatManager) AddRemoteAddr(serverId int64, remoteAddr str
|
||||
}
|
||||
|
||||
// AddUserAgent 添加UserAgent
|
||||
func (this *HTTPRequestStatManager) AddUserAgent(serverId int64, userAgent string) {
|
||||
func (this *HTTPRequestStatManager) AddUserAgent(serverId int64, userAgent string, ip string) {
|
||||
if len(userAgent) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 是否包含一些知名Agent
|
||||
if len(userAgent) > 0 && len(ip) > 0 && agents.IsAgentFromUserAgent(userAgent) {
|
||||
agents.SharedQueue.Push(ip)
|
||||
}
|
||||
|
||||
select {
|
||||
case this.userAgentChan <- strconv.FormatInt(serverId, 10) + "@" + userAgent:
|
||||
default:
|
||||
|
||||
@@ -37,11 +37,11 @@ func TestHTTPRequestStatManager_Loop_Region(t *testing.T) {
|
||||
|
||||
func TestHTTPRequestStatManager_Loop_UserAgent(t *testing.T) {
|
||||
var manager = NewHTTPRequestStatManager()
|
||||
manager.AddUserAgent(1, "Mozilla/5.0 (Macintosh; Intel Mac OS X 11_1_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36")
|
||||
manager.AddUserAgent(1, "Mozilla/5.0 (Macintosh; Intel Mac OS X 11_1_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36")
|
||||
manager.AddUserAgent(1, "Mozilla/5.0 (Macintosh; Intel Mac OS X 11) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/76 Safari/537.36")
|
||||
manager.AddUserAgent(1, "Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36")
|
||||
manager.AddUserAgent(1, "Mozilla/5.0 (Windows NT 6.1; WOW64; Trident/7.0; rv:11.0) like Gecko")
|
||||
manager.AddUserAgent(1, "Mozilla/5.0 (Macintosh; Intel Mac OS X 11_1_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36", "")
|
||||
manager.AddUserAgent(1, "Mozilla/5.0 (Macintosh; Intel Mac OS X 11_1_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36", "")
|
||||
manager.AddUserAgent(1, "Mozilla/5.0 (Macintosh; Intel Mac OS X 11) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/76 Safari/537.36", "")
|
||||
manager.AddUserAgent(1, "Mozilla/5.0 (Windows NT 10.0; WOW64; rv:49.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.96 Safari/537.36", "")
|
||||
manager.AddUserAgent(1, "Mozilla/5.0 (Windows NT 6.1; WOW64; Trident/7.0; rv:11.0) like Gecko", "")
|
||||
err := manager.Loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
39
internal/utils/agents/agent.go
Normal file
39
internal/utils/agents/agent.go
Normal file
@@ -0,0 +1,39 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package agents
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Agent struct {
|
||||
Code string
|
||||
Keywords []string // user agent keywords
|
||||
|
||||
suffixes []string // PTR suffixes
|
||||
reg *regexp.Regexp
|
||||
}
|
||||
|
||||
func NewAgent(code string, suffixes []string, reg *regexp.Regexp, keywords []string) *Agent {
|
||||
return &Agent{
|
||||
Code: code,
|
||||
suffixes: suffixes,
|
||||
reg: reg,
|
||||
Keywords: keywords,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Agent) Match(ptr string) bool {
|
||||
if len(this.suffixes) > 0 {
|
||||
for _, suffix := range this.suffixes {
|
||||
if strings.HasSuffix(ptr, suffix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
if this.reg != nil {
|
||||
return this.reg.MatchString(ptr)
|
||||
}
|
||||
return false
|
||||
}
|
||||
9
internal/utils/agents/agent_ip.go
Normal file
9
internal/utils/agents/agent_ip.go
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package agents
|
||||
|
||||
type AgentIP struct {
|
||||
Id int64
|
||||
IP string
|
||||
AgentCode string
|
||||
}
|
||||
31
internal/utils/agents/agents.go
Normal file
31
internal/utils/agents/agents.go
Normal file
@@ -0,0 +1,31 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package agents
|
||||
|
||||
import "strings"
|
||||
|
||||
var AllAgents = []*Agent{
|
||||
NewAgent("baidu", []string{".baidu.com."}, nil, []string{"Baidu"}),
|
||||
NewAgent("google", []string{".googlebot.com."}, nil, []string{"Google"}),
|
||||
NewAgent("bing", []string{".search.msn.com."}, nil, []string{"bingbot"}),
|
||||
NewAgent("sogou", []string{".sogou.com."}, nil, []string{"Sogou"}),
|
||||
NewAgent("youdao", []string{".163.com."}, nil, []string{"Youdao"}),
|
||||
NewAgent("yahoo", []string{".yahoo.com."}, nil, []string{"Yahoo"}),
|
||||
NewAgent("bytedance", []string{".bytedance.com."}, nil, []string{"Bytespider"}),
|
||||
NewAgent("sm", []string{".sm.cn."}, nil, []string{"YisouSpider"}),
|
||||
NewAgent("yandex", []string{".yandex.com.", ".yndx.net."}, nil, []string{"Yandex"}),
|
||||
NewAgent("semrush", []string{".semrush.com."}, nil, []string{"SEMrush"}),
|
||||
}
|
||||
|
||||
func IsAgentFromUserAgent(userAgent string) bool {
|
||||
for _, agent := range AllAgents {
|
||||
if len(agent.Keywords) > 0 {
|
||||
for _, keyword := range agent.Keywords {
|
||||
if strings.Contains(userAgent, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
19
internal/utils/agents/agents_test.go
Normal file
19
internal/utils/agents/agents_test.go
Normal file
@@ -0,0 +1,19 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package agents_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/agents"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsAgentFromUserAgent(t *testing.T) {
|
||||
t.Log(agents.IsAgentFromUserAgent("Mozilla/5.0 (Linux;u;Android 4.2.2;zh-cn;) AppleWebKit/534.46 (KHTML,like Gecko) Version/5.1 Mobile Safari/10600.6.3 (compatible; Baiduspider/2.0; +http://www.baidu.com/search/spider.html)"))
|
||||
t.Log(agents.IsAgentFromUserAgent("Mozilla/5.0 (Linux;u;Android 4.2.2;zh-cn;)"))
|
||||
}
|
||||
|
||||
func BenchmarkIsAgentFromUserAgent(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
agents.IsAgentFromUserAgent("Mozilla/5.0 (Linux;u;Android 4.2.2;zh-cn;) AppleWebKit/534.46 (KHTML,like Gecko) Version/5.1 Mobile Safari/10600.6.3 (compatible; Yaho)")
|
||||
}
|
||||
}
|
||||
156
internal/utils/agents/db.go
Normal file
156
internal/utils/agents/db.go
Normal file
@@ -0,0 +1,156 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package agents
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
const (
|
||||
tableAgentIPs = "agentIPs"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
db *sql.DB
|
||||
path string
|
||||
|
||||
insertAgentIPStmt *sql.Stmt
|
||||
listAgentIPsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewDB(path string) *DB {
|
||||
var db = &DB{path: path}
|
||||
|
||||
events.On(events.EventQuit, func() {
|
||||
_ = db.Close()
|
||||
})
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func (this *DB) Init() error {
|
||||
// 检查目录是否存在
|
||||
var dir = filepath.Dir(this.path)
|
||||
|
||||
_, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
err = os.MkdirAll(dir, 0777)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
remotelogs.Println("DB", "create database dir '"+dir+"'")
|
||||
}
|
||||
|
||||
// TODO 思考 data.db 的数据安全性
|
||||
db, err := sql.Open("sqlite3", "file:"+this.path+"?cache=shared&mode=rwc&_journal_mode=WAL")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.SetMaxOpenConns(1)
|
||||
|
||||
/**_, err = db.Exec("VACUUM")
|
||||
if err != nil {
|
||||
return err
|
||||
}**/
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + tableAgentIPs + `" (
|
||||
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"ip" varchar(64),
|
||||
"agentCode" varchar(128)
|
||||
);`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 预编译语句
|
||||
|
||||
// agent ip record statements
|
||||
this.insertAgentIPStmt, err = db.Prepare(`INSERT INTO "` + tableAgentIPs + `" ("id", "ip", "agentCode") VALUES (?, ?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.listAgentIPsStmt, err = db.Prepare(`SELECT "id", "ip", "agentCode" FROM "` + tableAgentIPs + `" ORDER BY "id" ASC LIMIT ? OFFSET ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.db = db
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DB) InsertAgentIP(ipId int64, ip string, agentCode string) error {
|
||||
if this.db == nil {
|
||||
return errors.New("db should not be nil")
|
||||
}
|
||||
|
||||
this.log("InsertAgentIP", "id:", ipId, "ip:", ip, "agent:", agentCode)
|
||||
_, err := this.insertAgentIPStmt.Exec(ipId, ip, agentCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DB) ListAgentIPs(offset int64, size int64) (agentIPs []*AgentIP, err error) {
|
||||
if this.db == nil {
|
||||
return nil, errors.New("db should not be nil")
|
||||
}
|
||||
rows, err := this.listAgentIPsStmt.Query(size, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
for rows.Next() {
|
||||
var agentIP = &AgentIP{}
|
||||
err = rows.Scan(&agentIP.Id, &agentIP.IP, &agentIP.AgentCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
agentIPs = append(agentIPs, agentIP)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (this *DB) Close() error {
|
||||
if this.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, stmt := range []*sql.Stmt{
|
||||
this.insertAgentIPStmt,
|
||||
this.listAgentIPsStmt,
|
||||
} {
|
||||
if stmt != nil {
|
||||
_ = stmt.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return this.db.Close()
|
||||
}
|
||||
|
||||
// 打印日志
|
||||
func (this *DB) log(args ...any) {
|
||||
if !Tea.IsTesting() {
|
||||
return
|
||||
}
|
||||
if len(args) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
args[0] = "[" + types.String(args[0]) + "]"
|
||||
log.Println(args...)
|
||||
}
|
||||
54
internal/utils/agents/ip_cache_map.go
Normal file
54
internal/utils/agents/ip_cache_map.go
Normal file
@@ -0,0 +1,54 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package agents
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/zero"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type IPCacheMap struct {
|
||||
m map[string]zero.Zero
|
||||
list []string
|
||||
|
||||
locker sync.RWMutex
|
||||
maxLen int
|
||||
}
|
||||
|
||||
func NewIPCacheMap(maxLen int) *IPCacheMap {
|
||||
if maxLen <= 0 {
|
||||
maxLen = 65535
|
||||
}
|
||||
return &IPCacheMap{
|
||||
m: map[string]zero.Zero{},
|
||||
maxLen: maxLen,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *IPCacheMap) Add(ip string) {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
// 是否已经存在
|
||||
_, ok := this.m[ip]
|
||||
if ok {
|
||||
return
|
||||
}
|
||||
|
||||
// 超出长度删除第一个
|
||||
if len(this.list) >= this.maxLen {
|
||||
delete(this.m, this.list[0])
|
||||
this.list = this.list[1:]
|
||||
}
|
||||
|
||||
// 加入新数据
|
||||
this.m[ip] = zero.Zero{}
|
||||
this.list = append(this.list, ip)
|
||||
}
|
||||
|
||||
func (this *IPCacheMap) Contains(ip string) bool {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
_, ok := this.m[ip]
|
||||
return ok
|
||||
}
|
||||
33
internal/utils/agents/ip_cache_map_test.go
Normal file
33
internal/utils/agents/ip_cache_map_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package agents
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewIPCacheMap(t *testing.T) {
|
||||
var cacheMap = NewIPCacheMap(3)
|
||||
|
||||
t.Log("====")
|
||||
cacheMap.Add("1")
|
||||
cacheMap.Add("2")
|
||||
logs.PrintAsJSON(cacheMap.m, t)
|
||||
logs.PrintAsJSON(cacheMap.list, t)
|
||||
|
||||
t.Log("====")
|
||||
cacheMap.Add("3")
|
||||
logs.PrintAsJSON(cacheMap.m, t)
|
||||
logs.PrintAsJSON(cacheMap.list, t)
|
||||
|
||||
t.Log("====")
|
||||
cacheMap.Add("4")
|
||||
logs.PrintAsJSON(cacheMap.m, t)
|
||||
logs.PrintAsJSON(cacheMap.list, t)
|
||||
|
||||
t.Log("====")
|
||||
cacheMap.Add("3")
|
||||
logs.PrintAsJSON(cacheMap.m, t)
|
||||
logs.PrintAsJSON(cacheMap.list, t)
|
||||
}
|
||||
200
internal/utils/agents/manager.go
Normal file
200
internal/utils/agents/manager.go
Normal file
@@ -0,0 +1,200 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package agents
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var SharedManager = NewManager()
|
||||
|
||||
func init() {
|
||||
events.On(events.EventLoaded, func() {
|
||||
goman.New(func() {
|
||||
SharedManager.Start()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Manager Agent管理器
|
||||
type Manager struct {
|
||||
ipMap map[string]string // ip => agentCode
|
||||
locker sync.RWMutex
|
||||
|
||||
db *DB
|
||||
|
||||
lastId int64
|
||||
}
|
||||
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
ipMap: map[string]string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Manager) SetDB(db *DB) {
|
||||
this.db = db
|
||||
}
|
||||
|
||||
func (this *Manager) Start() {
|
||||
remotelogs.Println("AGENT_MANAGER", "starting ...")
|
||||
|
||||
err := this.loadDB()
|
||||
if err != nil {
|
||||
remotelogs.Error("AGENT_MANAGER", "load database failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 从本地数据库中加载
|
||||
err = this.Load()
|
||||
if err != nil {
|
||||
remotelogs.Error("AGENT_MANAGER", "load failed: "+err.Error())
|
||||
}
|
||||
|
||||
// 先从API获取
|
||||
err = this.LoopAll()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("AGENT_MANAGER", "retrieve latest agent ip failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("AGENT_MANAGER", "retrieve latest agent ip failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 定时获取
|
||||
var duration = 30 * time.Minute
|
||||
if Tea.IsTesting() {
|
||||
duration = 30 * time.Second
|
||||
}
|
||||
var ticker = time.NewTicker(duration)
|
||||
for range ticker.C {
|
||||
err = this.LoopAll()
|
||||
if err != nil {
|
||||
remotelogs.Error("AGENT_MANAGER", "retrieve latest agent ip failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Manager) Load() error {
|
||||
var offset int64 = 0
|
||||
var size int64 = 10000
|
||||
for {
|
||||
agentIPs, err := this.db.ListAgentIPs(offset, size)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(agentIPs) == 0 {
|
||||
break
|
||||
}
|
||||
for _, agentIP := range agentIPs {
|
||||
this.locker.Lock()
|
||||
this.ipMap[agentIP.IP] = agentIP.AgentCode
|
||||
this.locker.Unlock()
|
||||
|
||||
if agentIP.Id > this.lastId {
|
||||
this.lastId = agentIP.Id
|
||||
}
|
||||
}
|
||||
offset += size
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Manager) LoopAll() error {
|
||||
for {
|
||||
hasNext, err := this.Loop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasNext {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Loop 单次循环获取数据
|
||||
func (this *Manager) Loop() (hasNext bool, err error) {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
ipsResp, err := rpcClient.ClientAgentIPRPC.ListClientAgentIPsAfterId(rpcClient.Context(), &pb.ListClientAgentIPsAfterIdRequest{
|
||||
Id: this.lastId,
|
||||
Size: 10000,
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if len(ipsResp.ClientAgentIPs) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
for _, agentIP := range ipsResp.ClientAgentIPs {
|
||||
if agentIP.ClientAgent == nil {
|
||||
// 设置ID
|
||||
if agentIP.Id > this.lastId {
|
||||
this.lastId = agentIP.Id
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// 写入到数据库
|
||||
err = this.db.InsertAgentIP(agentIP.Id, agentIP.Ip, agentIP.ClientAgent.Code)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// 写入Map
|
||||
this.locker.Lock()
|
||||
this.ipMap[agentIP.Ip] = agentIP.ClientAgent.Code
|
||||
this.locker.Unlock()
|
||||
|
||||
// 设置ID
|
||||
if agentIP.Id > this.lastId {
|
||||
this.lastId = agentIP.Id
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// AddIP 添加记录
|
||||
func (this *Manager) AddIP(ip string, agentCode string) {
|
||||
this.locker.Lock()
|
||||
this.ipMap[ip] = agentCode
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
// LookupIP 查询IP所属Agent
|
||||
func (this *Manager) LookupIP(ip string) (agentCode string) {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
return this.ipMap[ip]
|
||||
}
|
||||
|
||||
// ContainsIP 检查是否有IP相关数据
|
||||
func (this *Manager) ContainsIP(ip string) bool {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
_, ok := this.ipMap[ip]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (this *Manager) loadDB() error {
|
||||
var db = NewDB(Tea.Root + "/data/agents.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.db = db
|
||||
return nil
|
||||
}
|
||||
32
internal/utils/agents/manager_test.go
Normal file
32
internal/utils/agents/manager_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package agents_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/agents"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
var db = agents.NewDB(Tea.Root + "/data/agents.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var manager = agents.NewManager()
|
||||
manager.SetDB(db)
|
||||
err = manager.Load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = manager.Loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Log(manager.LookupIP("192.168.3.100"))
|
||||
}
|
||||
133
internal/utils/agents/queue.go
Normal file
133
internal/utils/agents/queue.go
Normal file
@@ -0,0 +1,133 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package agents
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
events.On(events.EventLoaded, func() {
|
||||
goman.New(func() {
|
||||
SharedQueue.Start()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
var SharedQueue = NewQueue()
|
||||
|
||||
type Queue struct {
|
||||
c chan string // chan ip
|
||||
cacheMap *IPCacheMap
|
||||
}
|
||||
|
||||
func NewQueue() *Queue {
|
||||
return &Queue{
|
||||
c: make(chan string, 128),
|
||||
cacheMap: NewIPCacheMap(65535),
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Queue) Start() {
|
||||
for ip := range this.c {
|
||||
err := this.Process(ip)
|
||||
if err != nil {
|
||||
// 不需要上报错误
|
||||
if Tea.IsTesting() {
|
||||
remotelogs.Debug("SharedParseQueue", err.Error())
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Push 将IP加入到处理队列
|
||||
func (this *Queue) Push(ip string) {
|
||||
// 是否在处理中
|
||||
if this.cacheMap.Contains(ip) {
|
||||
return
|
||||
}
|
||||
this.cacheMap.Add(ip)
|
||||
|
||||
// 加入到队列
|
||||
select {
|
||||
case this.c <- ip:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Process 处理IP
|
||||
func (this *Queue) Process(ip string) error {
|
||||
// 是否已经在库中
|
||||
if SharedManager.ContainsIP(ip) {
|
||||
return nil
|
||||
}
|
||||
|
||||
ptr, err := this.ParseIP(ip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(ptr) == 0 || ptr == "." {
|
||||
return nil
|
||||
}
|
||||
|
||||
//remotelogs.Debug("AGENT", ip+" => "+ptr)
|
||||
|
||||
var agentCode = this.ParsePtr(ptr)
|
||||
if len(agentCode) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 加入到本地
|
||||
SharedManager.AddIP(ip, agentCode)
|
||||
|
||||
var pbAgentIP = &pb.CreateClientAgentIPsRequest_AgentIPInfo{
|
||||
AgentCode: agentCode,
|
||||
Ip: ip,
|
||||
Ptr: ptr,
|
||||
}
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = rpcClient.ClientAgentIPRPC.CreateClientAgentIPs(rpcClient.Context(), &pb.CreateClientAgentIPsRequest{AgentIPs: []*pb.CreateClientAgentIPsRequest_AgentIPInfo{pbAgentIP}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseIP 分析IP的PTR值
|
||||
func (this *Queue) ParseIP(ip string) (ptr string, err error) {
|
||||
if len(ip) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
names, err := net.LookupAddr(ip)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(names) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return names[0], nil
|
||||
}
|
||||
|
||||
// ParsePtr 分析PTR对应的Agent
|
||||
func (this *Queue) ParsePtr(ptr string) (agentCode string) {
|
||||
for _, agent := range AllAgents {
|
||||
if agent.Match(ptr) {
|
||||
return agent.Code
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
76
internal/utils/agents/queue_test.go
Normal file
76
internal/utils/agents/queue_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package agents_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/agents"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseQueue_Process(t *testing.T) {
|
||||
var queue = agents.NewQueue()
|
||||
go queue.Start()
|
||||
time.Sleep(1 * time.Second)
|
||||
queue.Push("220.181.13.100")
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
func TestParseQueue_ParseIP(t *testing.T) {
|
||||
var queue = agents.NewQueue()
|
||||
for _, ip := range []string{
|
||||
"192.168.1.100",
|
||||
"42.120.160.1",
|
||||
"42.236.10.98",
|
||||
"124.115.0.100",
|
||||
} {
|
||||
ptr, err := queue.ParseIP(ip)
|
||||
if err != nil {
|
||||
t.Log(ip, "=>", err)
|
||||
continue
|
||||
}
|
||||
t.Log(ip, "=>", ptr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseQueue_ParsePtr(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
|
||||
var queue = agents.NewQueue()
|
||||
for _, s := range [][]string{
|
||||
{"baiduspider-220-181-108-101.crawl.baidu.com.", "baidu"},
|
||||
{"crawl-66-249-71-219.googlebot.com.", "google"},
|
||||
{"msnbot-40-77-167-31.search.msn.com.", "bing"},
|
||||
{"sogouspider-49-7-20-129.crawl.sogou.com.", "sogou"},
|
||||
{"m13102.mail.163.com.", "youdao"},
|
||||
{"yeurosport.pat1.tc2.yahoo.com.", "yahoo"},
|
||||
{"shenmaspider-42-120-160-1.crawl.sm.cn.", "sm"},
|
||||
{"93-158-161-39.spider.yandex.com.", "yandex"},
|
||||
{"25.bl.bot.semrush.com.", "semrush"},
|
||||
} {
|
||||
a.IsTrue(queue.ParsePtr(s[0]) == s[1])
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkQueue_ParsePtr(b *testing.B) {
|
||||
var queue = agents.NewQueue()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, s := range [][]string{
|
||||
{"baiduspider-220-181-108-101.crawl.baidu.com.", "baidu"},
|
||||
{"crawl-66-249-71-219.googlebot.com.", "google"},
|
||||
{"msnbot-40-77-167-31.search.msn.com.", "bing"},
|
||||
{"sogouspider-49-7-20-129.crawl.sogou.com.", "sogou"},
|
||||
{"m13102.mail.163.com.", "youdao"},
|
||||
{"yeurosport.pat1.tc2.yahoo.com.", "yahoo"},
|
||||
{"shenmaspider-42-120-160-1.crawl.sm.cn.", "sm"},
|
||||
{"93-158-161-39.spider.yandex.com.", "yandex"},
|
||||
{"93.158.164.218-red.dhcp.yndx.net.", "yandex"},
|
||||
{"25.bl.bot.semrush.com.", "semrush"},
|
||||
} {
|
||||
queue.ParsePtr(s[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var RegexpDigitNumber = regexp.MustCompile("^\\d+$")
|
||||
var RegexpDigitNumber = regexp.MustCompile(`^\d+$`)
|
||||
|
||||
func Get(object interface{}, keys []string) interface{} {
|
||||
if len(keys) == 0 {
|
||||
|
||||
@@ -67,10 +67,3 @@ func TestRange_ComposeContentRangeHeader(t *testing.T) {
|
||||
var r = rangeutils.NewRange(1, 100)
|
||||
t.Log(r.ComposeContentRangeHeader("1000"))
|
||||
}
|
||||
|
||||
func TestRange_SetLength(t *testing.T) {
|
||||
var r = rangeutils.NewRange(1, 100)
|
||||
t.Log(r)
|
||||
|
||||
t.Log(r.SetLength(1024))
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package writers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io"
|
||||
"time"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
// RateLimitWriter 限速写入
|
||||
type RateLimitWriter struct {
|
||||
rawWriter io.WriteCloser
|
||||
ctx context.Context
|
||||
|
||||
rateBytes int
|
||||
|
||||
@@ -18,9 +20,10 @@ type RateLimitWriter struct {
|
||||
before time.Time
|
||||
}
|
||||
|
||||
func NewRateLimitWriter(rawWriter io.WriteCloser, rateBytes int64) io.WriteCloser {
|
||||
func NewRateLimitWriter(ctx context.Context, rawWriter io.WriteCloser, rateBytes int64) io.WriteCloser {
|
||||
return &RateLimitWriter{
|
||||
rawWriter: rawWriter,
|
||||
ctx: ctx,
|
||||
rateBytes: types.Int(rateBytes),
|
||||
before: time.Now(),
|
||||
}
|
||||
@@ -71,6 +74,14 @@ func (this *RateLimitWriter) write(p []byte) (n int, err error) {
|
||||
n, err = this.rawWriter.Write(p)
|
||||
|
||||
if err == nil {
|
||||
select {
|
||||
case <-this.ctx.Done():
|
||||
err = io.EOF
|
||||
return
|
||||
default:
|
||||
|
||||
}
|
||||
|
||||
this.written += n
|
||||
|
||||
if this.written >= this.rateBytes {
|
||||
|
||||
@@ -2,10 +2,6 @@
|
||||
|
||||
package waf
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type BaseAction struct {
|
||||
currentActionId int64
|
||||
}
|
||||
@@ -19,16 +15,3 @@ func (this *BaseAction) ActionId() int64 {
|
||||
func (this *BaseAction) SetActionId(actionId int64) {
|
||||
this.currentActionId = actionId
|
||||
}
|
||||
|
||||
// CloseConn 关闭连接
|
||||
func (this *BaseAction) CloseConn(writer http.ResponseWriter) error {
|
||||
// 断开连接
|
||||
hijack, ok := writer.(http.Hijacker)
|
||||
if ok {
|
||||
conn, _, err := hijack.Hijack()
|
||||
if err == nil && conn != nil {
|
||||
return conn.Close()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -70,7 +70,12 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
|
||||
http.Redirect(writer, request.WAFRaw(), Get302Path+"?info="+url.QueryEscape(info), http.StatusFound)
|
||||
|
||||
if request.WAFRaw().ProtoMajor == 1 {
|
||||
_ = this.CloseConn(writer)
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
request.WAFClose()
|
||||
}
|
||||
|
||||
return false, false
|
||||
|
||||
@@ -87,7 +87,12 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
|
||||
http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusTemporaryRedirect)
|
||||
|
||||
if request.WAFRaw().ProtoMajor == 1 {
|
||||
_ = this.CloseConn(writer)
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
request.WAFClose()
|
||||
}
|
||||
|
||||
return false, false
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
type recordIPTask struct {
|
||||
ip string
|
||||
listId int64
|
||||
expiredAt int64
|
||||
expiresAt int64
|
||||
level string
|
||||
serverId int64
|
||||
|
||||
@@ -54,7 +54,7 @@ func init() {
|
||||
IpListId: task.listId,
|
||||
IpFrom: task.ip,
|
||||
IpTo: "",
|
||||
ExpiredAt: task.expiredAt,
|
||||
ExpiredAt: task.expiresAt,
|
||||
Reason: reason,
|
||||
Type: ipType,
|
||||
EventLevel: task.level,
|
||||
@@ -105,11 +105,13 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
|
||||
return true, false
|
||||
}
|
||||
|
||||
timeout := this.Timeout
|
||||
var timeout = this.Timeout
|
||||
var isForever = false
|
||||
if timeout <= 0 {
|
||||
isForever = true
|
||||
timeout = 86400 // 1天
|
||||
}
|
||||
expiredAt := time.Now().Unix() + int64(timeout)
|
||||
var expiresAt = time.Now().Unix() + int64(timeout)
|
||||
|
||||
if this.Type == "black" {
|
||||
writer.WriteHeader(http.StatusForbidden)
|
||||
@@ -117,10 +119,10 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
|
||||
request.WAFClose()
|
||||
|
||||
// 先加入本地的黑名单
|
||||
SharedIPBlackList.Add(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), expiredAt)
|
||||
SharedIPBlackList.Add(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), expiresAt)
|
||||
} else {
|
||||
// 加入本地白名单
|
||||
SharedIPWhiteList.Add("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP(), expiredAt)
|
||||
SharedIPWhiteList.Add("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP(), expiresAt)
|
||||
}
|
||||
|
||||
// 上报
|
||||
@@ -130,11 +132,16 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
|
||||
serverId = request.WAFServerId()
|
||||
}
|
||||
|
||||
var realExpiresAt = expiresAt
|
||||
if isForever {
|
||||
realExpiresAt = 0
|
||||
}
|
||||
|
||||
select {
|
||||
case recordIPTaskChan <- &recordIPTask{
|
||||
ip: request.WAFRemoteIP(),
|
||||
listId: this.IPListId,
|
||||
expiredAt: expiredAt,
|
||||
expiresAt: realExpiresAt,
|
||||
level: this.Level,
|
||||
serverId: serverId,
|
||||
sourceServerId: request.WAFServerId(),
|
||||
|
||||
@@ -100,7 +100,7 @@ func (this *IPList) RecordIP(ipType string,
|
||||
case recordIPTaskChan <- &recordIPTask{
|
||||
ip: ip,
|
||||
listId: firewallconfigs.GlobalListId,
|
||||
expiredAt: expiresAt,
|
||||
expiresAt: expiresAt,
|
||||
level: firewallconfigs.DefaultEventLevel,
|
||||
serverId: serverId,
|
||||
sourceServerId: serverId,
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/checkpoints"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/values"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
@@ -47,6 +48,10 @@ type Rule struct {
|
||||
isIP bool
|
||||
ipValue net.IP
|
||||
|
||||
ipRangeListValue *values.IPRangeList
|
||||
stringValues []string
|
||||
ipList *values.StringList
|
||||
|
||||
floatValue float64
|
||||
reg *re.Regexp
|
||||
}
|
||||
@@ -70,6 +75,21 @@ func (this *Rule) Init() error {
|
||||
this.floatValue = types.Float64(this.Value)
|
||||
case RuleOperatorNeq:
|
||||
this.floatValue = types.Float64(this.Value)
|
||||
case RuleOperatorContainsAny, RuleOperatorContainsAll:
|
||||
this.stringValues = []string{}
|
||||
if len(this.Value) > 0 {
|
||||
var lines = strings.Split(this.Value, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if len(line) > 0 {
|
||||
if this.IsCaseInsensitive {
|
||||
this.stringValues = append(this.stringValues, strings.ToLower(line))
|
||||
} else {
|
||||
this.stringValues = append(this.stringValues, line)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case RuleOperatorMatch:
|
||||
v := this.Value
|
||||
if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") {
|
||||
@@ -103,33 +123,10 @@ func (this *Rule) Init() error {
|
||||
if !this.isIP {
|
||||
return errors.New("value should be a valid ip")
|
||||
}
|
||||
case RuleOperatorInIPList:
|
||||
this.ipList = values.ParseStringList(this.Value, true)
|
||||
case RuleOperatorIPRange, RuleOperatorNotIPRange:
|
||||
if strings.Contains(this.Value, ",") {
|
||||
ipList := strings.SplitN(this.Value, ",", 2)
|
||||
ipString1 := strings.TrimSpace(ipList[0])
|
||||
ipString2 := strings.TrimSpace(ipList[1])
|
||||
|
||||
if len(ipString1) > 0 {
|
||||
ip1 := net.ParseIP(ipString1)
|
||||
if ip1 == nil {
|
||||
return errors.New("start ip is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
if len(ipString2) > 0 {
|
||||
ip2 := net.ParseIP(ipString2)
|
||||
if ip2 == nil {
|
||||
return errors.New("end ip is invalid")
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(this.Value, "/") {
|
||||
_, _, err := net.ParseCIDR(this.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return errors.New("invalid ip range")
|
||||
}
|
||||
this.ipRangeListValue = values.ParseIPRangeList(this.Value)
|
||||
}
|
||||
|
||||
if singleParamRegexp.MatchString(this.Param) {
|
||||
@@ -362,13 +359,13 @@ func (this *Rule) Test(value interface{}) bool {
|
||||
return types.Float64(value) != this.floatValue
|
||||
case RuleOperatorEqString:
|
||||
if this.IsCaseInsensitive {
|
||||
return strings.ToLower(types.String(value)) == strings.ToLower(this.Value)
|
||||
return strings.EqualFold(types.String(value), this.Value)
|
||||
} else {
|
||||
return types.String(value) == this.Value
|
||||
}
|
||||
case RuleOperatorNeqString:
|
||||
if this.IsCaseInsensitive {
|
||||
return strings.ToLower(types.String(value)) != strings.ToLower(this.Value)
|
||||
return !strings.EqualFold(types.String(value), this.Value)
|
||||
} else {
|
||||
return types.String(value) != this.Value
|
||||
}
|
||||
@@ -472,6 +469,33 @@ func (this *Rule) Test(value interface{}) bool {
|
||||
} else {
|
||||
return strings.HasSuffix(types.String(value), this.Value)
|
||||
}
|
||||
case RuleOperatorContainsAny:
|
||||
var stringValue = types.String(value)
|
||||
if this.IsCaseInsensitive {
|
||||
stringValue = strings.ToLower(stringValue)
|
||||
}
|
||||
if len(stringValue) > 0 && len(this.stringValues) > 0 {
|
||||
for _, v := range this.stringValues {
|
||||
if strings.Contains(stringValue, v) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
case RuleOperatorContainsAll:
|
||||
var stringValue = types.String(value)
|
||||
if this.IsCaseInsensitive {
|
||||
stringValue = strings.ToLower(stringValue)
|
||||
}
|
||||
if len(stringValue) > 0 && len(this.stringValues) > 0 {
|
||||
for _, v := range this.stringValues {
|
||||
if !strings.Contains(stringValue, v) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
case RuleOperatorContainsBinary:
|
||||
data, _ := base64.StdEncoding.DecodeString(types.String(this.Value))
|
||||
if this.IsCaseInsensitive {
|
||||
@@ -563,12 +587,12 @@ func (this *Rule) Test(value interface{}) bool {
|
||||
case RuleOperatorNotIPRange:
|
||||
return !this.containsIP(value)
|
||||
case RuleOperatorIPMod:
|
||||
pieces := strings.SplitN(this.Value, ",", 2)
|
||||
var pieces = strings.SplitN(this.Value, ",", 2)
|
||||
if len(pieces) == 1 {
|
||||
rem := types.Int64(pieces[0])
|
||||
var rem = types.Int64(pieces[0])
|
||||
return this.ipToInt64(net.ParseIP(types.String(value)))%10 == rem
|
||||
}
|
||||
div := types.Int64(pieces[0])
|
||||
var div = types.Int64(pieces[0])
|
||||
if div == 0 {
|
||||
return false
|
||||
}
|
||||
@@ -578,6 +602,11 @@ func (this *Rule) Test(value interface{}) bool {
|
||||
return this.ipToInt64(net.ParseIP(types.String(value)))%10 == types.Int64(this.Value)
|
||||
case RuleOperatorIPMod100:
|
||||
return this.ipToInt64(net.ParseIP(types.String(value)))%100 == types.Int64(this.Value)
|
||||
case RuleOperatorInIPList:
|
||||
if this.ipList != nil {
|
||||
return this.ipList.Contains(types.String(value))
|
||||
}
|
||||
return false
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -590,65 +619,64 @@ func (this *Rule) SetCheckpointFinder(finder func(prefix string) checkpoints.Che
|
||||
this.checkpointFinder = finder
|
||||
}
|
||||
|
||||
var unescapeChars = [][2]string{
|
||||
{`\s`, `(\s|%09|%0A|\+)`},
|
||||
{`\(`, `(\(|%28)`},
|
||||
{`=`, `(=|%3D)`},
|
||||
{`<`, `(<|%3C)`},
|
||||
{`\*`, `(\*|%2A)`},
|
||||
{`\\`, `(\\|%2F)`},
|
||||
{`!`, `(!|%21)`},
|
||||
{`/`, `(/|%2F)`},
|
||||
{`;`, `(;|%3B)`},
|
||||
{`\+`, `(\+|%20)`},
|
||||
}
|
||||
|
||||
func (this *Rule) unescape(v string) string {
|
||||
//replace urlencoded characters
|
||||
v = strings.Replace(v, `\s`, `(\s|%09|%0A|\+)`, -1)
|
||||
v = strings.Replace(v, `\(`, `(\(|%28)`, -1)
|
||||
v = strings.Replace(v, `=`, `(=|%3D)`, -1)
|
||||
v = strings.Replace(v, `<`, `(<|%3C)`, -1)
|
||||
v = strings.Replace(v, `\*`, `(\*|%2A)`, -1)
|
||||
v = strings.Replace(v, `\\`, `(\\|%2F)`, -1)
|
||||
v = strings.Replace(v, `!`, `(!|%21)`, -1)
|
||||
v = strings.Replace(v, `/`, `(/|%2F)`, -1)
|
||||
v = strings.Replace(v, `;`, `(;|%3B)`, -1)
|
||||
v = strings.Replace(v, `\+`, `(\+|%20)`, -1)
|
||||
// replace urlencoded characters
|
||||
|
||||
for _, c := range unescapeChars {
|
||||
if !strings.Contains(v, c[0]) {
|
||||
continue
|
||||
}
|
||||
var pieces = strings.Split(v, c[0])
|
||||
|
||||
// 修复piece中错误的\
|
||||
for pieceIndex, piece := range pieces {
|
||||
var l = len(piece)
|
||||
if l == 0 {
|
||||
continue
|
||||
}
|
||||
if piece[l-1] != '\\' {
|
||||
continue
|
||||
}
|
||||
|
||||
// 计算\的数量
|
||||
var countBackSlashes = 0
|
||||
for i := l - 1; i >= 0; i-- {
|
||||
if piece[i] == '\\' {
|
||||
countBackSlashes++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
if countBackSlashes%2 == 1 {
|
||||
// 去掉最后一个
|
||||
pieces[pieceIndex] = piece[:len(piece)-1]
|
||||
}
|
||||
}
|
||||
|
||||
v = strings.Join(pieces, c[1])
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func (this *Rule) containsIP(value interface{}) bool {
|
||||
ip := net.ParseIP(types.String(value))
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查IP范围格式
|
||||
if strings.Contains(this.Value, ",") {
|
||||
ipList := strings.SplitN(this.Value, ",", 2)
|
||||
ipString1 := strings.TrimSpace(ipList[0])
|
||||
ipString2 := strings.TrimSpace(ipList[1])
|
||||
|
||||
if len(ipString1) > 0 {
|
||||
ip1 := net.ParseIP(ipString1)
|
||||
if ip1 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if bytes.Compare(ip, ip1) < 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if len(ipString2) > 0 {
|
||||
ip2 := net.ParseIP(ipString2)
|
||||
if ip2 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if bytes.Compare(ip, ip2) > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
} else if strings.Contains(this.Value, "/") {
|
||||
_, ipNet, err := net.ParseCIDR(this.Value)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return ipNet.Contains(ip)
|
||||
} else {
|
||||
func (this *Rule) containsIP(value any) bool {
|
||||
if this.ipRangeListValue == nil {
|
||||
return false
|
||||
}
|
||||
return this.ipRangeListValue.Contains(types.String(value))
|
||||
}
|
||||
|
||||
func (this *Rule) ipToInt64(ip net.IP) int64 {
|
||||
|
||||
@@ -18,6 +18,9 @@ const (
|
||||
RuleOperatorNotContains RuleOperator = "not contains"
|
||||
RuleOperatorPrefix RuleOperator = "prefix"
|
||||
RuleOperatorSuffix RuleOperator = "suffix"
|
||||
RuleOperatorContainsAny RuleOperator = "contains any"
|
||||
RuleOperatorContainsAll RuleOperator = "contains all"
|
||||
RuleOperatorInIPList RuleOperator = "in ip list"
|
||||
RuleOperatorHasKey RuleOperator = "has key" // has key in slice or map
|
||||
RuleOperatorVersionGt RuleOperator = "version gt"
|
||||
RuleOperatorVersionLt RuleOperator = "version lt"
|
||||
|
||||
@@ -5,5 +5,5 @@ package utils
|
||||
import "github.com/TeaOSLab/EdgeNode/internal/utils/sizes"
|
||||
|
||||
const (
|
||||
MaxBodySize = 4 * sizes.M
|
||||
MaxBodySize = 2 * sizes.M
|
||||
)
|
||||
|
||||
@@ -13,6 +13,10 @@ var cache = ttlcache.NewCache()
|
||||
|
||||
// MatchStringCache 正则表达式匹配字符串,并缓存结果
|
||||
func MatchStringCache(regex *re.Regexp, s string) bool {
|
||||
if regex == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果长度超过4096,大概率是不能重用的
|
||||
if len(s) > 4096 {
|
||||
return regex.MatchString(s)
|
||||
@@ -35,6 +39,10 @@ func MatchStringCache(regex *re.Regexp, s string) bool {
|
||||
|
||||
// MatchBytesCache 正则表达式匹配字节slice,并缓存结果
|
||||
func MatchBytesCache(regex *re.Regexp, byteSlice []byte) bool {
|
||||
if regex == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果长度超过4096,大概率是不能重用的
|
||||
if len(byteSlice) > 4096 {
|
||||
return regex.Match(byteSlice)
|
||||
|
||||
132
internal/waf/values/ip_range.go
Normal file
132
internal/waf/values/ip_range.go
Normal file
@@ -0,0 +1,132 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package values
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type IPRangeType = string
|
||||
|
||||
const (
|
||||
IPRangeTypeCIDR IPRangeType = "cidr" // CIDR
|
||||
IPRangeTypeSingeIP IPRangeType = "singleIP" // 单个IP
|
||||
IPRangeTypeRange IPRangeType = "range" // IP范围,IP1-IP2
|
||||
)
|
||||
|
||||
type IPRange struct {
|
||||
Type IPRangeType
|
||||
CIDR *net.IPNet
|
||||
IPFrom net.IP
|
||||
IPTo net.IP
|
||||
}
|
||||
|
||||
func (this *IPRange) Contains(netIP net.IP) bool {
|
||||
if netIP == nil {
|
||||
return false
|
||||
}
|
||||
switch this.Type {
|
||||
case IPRangeTypeCIDR:
|
||||
if this.CIDR != nil {
|
||||
return this.CIDR.Contains(netIP)
|
||||
}
|
||||
case IPRangeTypeSingeIP:
|
||||
if this.IPFrom != nil {
|
||||
return bytes.Equal(this.IPFrom, netIP)
|
||||
}
|
||||
case IPRangeTypeRange:
|
||||
return bytes.Compare(this.IPFrom, netIP) <= 0 && bytes.Compare(this.IPTo, netIP) >= 0
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type IPRangeList struct {
|
||||
Ranges []*IPRange
|
||||
}
|
||||
|
||||
func NewIPRangeList() *IPRangeList {
|
||||
return &IPRangeList{}
|
||||
}
|
||||
|
||||
func ParseIPRangeList(value string) *IPRangeList {
|
||||
var list = NewIPRangeList()
|
||||
|
||||
if len(value) == 0 {
|
||||
return list
|
||||
}
|
||||
|
||||
var lines = strings.Split(value, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.Contains(line, ",") { // IPFrom,IPTo
|
||||
var pieces = strings.SplitN(line, ",", 2)
|
||||
if len(pieces) == 2 {
|
||||
var ipFrom = net.ParseIP(strings.TrimSpace(pieces[0]))
|
||||
var ipTo = net.ParseIP(strings.TrimSpace(pieces[1]))
|
||||
if ipFrom != nil && ipTo != nil {
|
||||
if bytes.Compare(ipFrom, ipTo) > 0 {
|
||||
ipFrom, ipTo = ipTo, ipFrom
|
||||
}
|
||||
list.Ranges = append(list.Ranges, &IPRange{
|
||||
Type: IPRangeTypeRange,
|
||||
IPFrom: ipFrom,
|
||||
IPTo: ipTo,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(line, "-") { // IPFrom-IPTo
|
||||
var pieces = strings.SplitN(line, "-", 2)
|
||||
if len(pieces) == 2 {
|
||||
var ipFrom = net.ParseIP(strings.TrimSpace(pieces[0]))
|
||||
var ipTo = net.ParseIP(strings.TrimSpace(pieces[1]))
|
||||
if ipFrom != nil && ipTo != nil {
|
||||
if bytes.Compare(ipFrom, ipTo) > 0 {
|
||||
ipFrom, ipTo = ipTo, ipFrom
|
||||
}
|
||||
list.Ranges = append(list.Ranges, &IPRange{
|
||||
Type: IPRangeTypeRange,
|
||||
IPFrom: ipFrom,
|
||||
IPTo: ipTo,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(line, "/") { // CIDR
|
||||
_, cidr, _ := net.ParseCIDR(line)
|
||||
if cidr != nil {
|
||||
list.Ranges = append(list.Ranges, &IPRange{
|
||||
Type: IPRangeTypeCIDR,
|
||||
CIDR: cidr,
|
||||
})
|
||||
}
|
||||
} else { // single ip
|
||||
var netIP = net.ParseIP(line)
|
||||
if netIP != nil {
|
||||
list.Ranges = append(list.Ranges, &IPRange{
|
||||
Type: IPRangeTypeSingeIP,
|
||||
IPFrom: netIP,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return list
|
||||
}
|
||||
|
||||
func (this *IPRangeList) Contains(ip string) bool {
|
||||
var netIP = net.ParseIP(ip)
|
||||
if netIP == nil {
|
||||
return false
|
||||
}
|
||||
for _, r := range this.Ranges {
|
||||
if r.Contains(netIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
54
internal/waf/values/ip_range_test.go
Normal file
54
internal/waf/values/ip_range_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package values_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/values"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIPRange_ParseIPRangeList(t *testing.T) {
|
||||
{
|
||||
var r = values.ParseIPRangeList("")
|
||||
logs.PrintAsJSON(r, t)
|
||||
}
|
||||
{
|
||||
var r = values.ParseIPRangeList("192.168.2.1")
|
||||
logs.PrintAsJSON(r, t)
|
||||
}
|
||||
{
|
||||
var r = values.ParseIPRangeList(`192.168.2.1
|
||||
192.168.1.100/24
|
||||
192.168.1.1-192.168.2.100
|
||||
192.168.1.2,192.168.2.200
|
||||
192.168.2.200 - 192.168.2.100
|
||||
# 以下是错误的
|
||||
192.168
|
||||
192.168.100.1-1`)
|
||||
logs.PrintAsJSON(r, t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPRange_Contains(t *testing.T) {
|
||||
{
|
||||
var a = assert.NewAssertion(t)
|
||||
var r = values.ParseIPRangeList(`192.168.2.1
|
||||
192.168.1.100/24
|
||||
192.168.1.1-192.168.2.100
|
||||
192.168.2.2,192.168.2.200
|
||||
192.168.3.200 - 192.168.3.100
|
||||
192.168.4.100
|
||||
192.168.5.1/26`)
|
||||
a.IsTrue(r.Contains("192.168.1.102"))
|
||||
a.IsTrue(r.Contains("192.168.2.101"))
|
||||
a.IsTrue(r.Contains("192.168.1.1"))
|
||||
a.IsTrue(r.Contains("192.168.2.100"))
|
||||
a.IsFalse(r.Contains("192.168.2.201"))
|
||||
a.IsTrue(r.Contains("192.168.3.101"))
|
||||
a.IsTrue(r.Contains("192.168.4.100"))
|
||||
a.IsTrue(r.Contains("192.168.5.63"))
|
||||
a.IsFalse(r.Contains("192.168.5.128"))
|
||||
}
|
||||
}
|
||||
48
internal/waf/values/number_list.go
Normal file
48
internal/waf/values/number_list.go
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package values
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/zero"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type NumberList struct {
|
||||
ValueMap map[float64]zero.Zero
|
||||
}
|
||||
|
||||
func NewNumberList() *NumberList {
|
||||
return &NumberList{
|
||||
ValueMap: map[float64]zero.Zero{},
|
||||
}
|
||||
}
|
||||
|
||||
func ParseNumberList(v string) *NumberList {
|
||||
var list = NewNumberList()
|
||||
if len(v) == 0 {
|
||||
return list
|
||||
}
|
||||
|
||||
var lines = strings.Split(v, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var values = strings.Split(line, ",")
|
||||
for _, value := range values {
|
||||
value = strings.TrimSpace(value)
|
||||
if len(value) > 0 {
|
||||
list.ValueMap[types.Float64(value)] = zero.Zero{}
|
||||
}
|
||||
}
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
func (this *NumberList) Contains(f float64) bool {
|
||||
_, ok := this.ValueMap[f]
|
||||
return ok
|
||||
}
|
||||
29
internal/waf/values/number_list_test.go
Normal file
29
internal/waf/values/number_list_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package values_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/values"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseNumberList(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
|
||||
{
|
||||
var list = values.ParseNumberList("")
|
||||
a.IsFalse(list.Contains(123))
|
||||
}
|
||||
|
||||
{
|
||||
var list = values.ParseNumberList(`123
|
||||
456
|
||||
|
||||
789.1234`)
|
||||
a.IsTrue(list.Contains(123))
|
||||
a.IsFalse(list.Contains(0))
|
||||
a.IsFalse(list.Contains(789.123))
|
||||
a.IsTrue(list.Contains(789.1234))
|
||||
}
|
||||
}
|
||||
55
internal/waf/values/string_list.go
Normal file
55
internal/waf/values/string_list.go
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package values
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/zero"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type StringList struct {
|
||||
ValueMap map[string]zero.Zero
|
||||
CaseInsensitive bool
|
||||
}
|
||||
|
||||
func NewStringList(caseInsensitive bool) *StringList {
|
||||
return &StringList{
|
||||
ValueMap: map[string]zero.Zero{},
|
||||
CaseInsensitive: caseInsensitive,
|
||||
}
|
||||
}
|
||||
|
||||
func ParseStringList(v string, caseInsensitive bool) *StringList {
|
||||
var list = NewStringList(caseInsensitive)
|
||||
if len(v) == 0 {
|
||||
return list
|
||||
}
|
||||
|
||||
var lines = strings.Split(v, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var values = strings.Split(line, ",")
|
||||
for _, value := range values {
|
||||
value = strings.TrimSpace(value)
|
||||
if len(value) > 0 {
|
||||
if caseInsensitive {
|
||||
value = strings.ToLower(value)
|
||||
}
|
||||
list.ValueMap[value] = zero.Zero{}
|
||||
}
|
||||
}
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
func (this *StringList) Contains(f string) bool {
|
||||
if this.CaseInsensitive {
|
||||
f = strings.ToLower(f)
|
||||
}
|
||||
_, ok := this.ValueMap[f]
|
||||
return ok
|
||||
}
|
||||
43
internal/waf/values/string_list_test.go
Normal file
43
internal/waf/values/string_list_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package values_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/values"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseStringList(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
|
||||
{
|
||||
var list = values.ParseStringList("", false)
|
||||
a.IsFalse(list.Contains("hello"))
|
||||
}
|
||||
|
||||
{
|
||||
var list = values.ParseStringList(`hello
|
||||
|
||||
world
|
||||
hi
|
||||
|
||||
people`, false)
|
||||
a.IsTrue(list.Contains("hello"))
|
||||
a.IsFalse(list.Contains("hello1"))
|
||||
a.IsFalse(list.Contains("Hello"))
|
||||
a.IsTrue(list.Contains("hi"))
|
||||
}
|
||||
{
|
||||
var list = values.ParseStringList(`Hello
|
||||
|
||||
world
|
||||
hi
|
||||
|
||||
people`, true)
|
||||
a.IsTrue(list.Contains("hello"))
|
||||
a.IsTrue(list.Contains("Hello"))
|
||||
a.IsTrue(list.Contains("HELLO"))
|
||||
a.IsFalse(list.Contains("How"))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user