Compare commits
43 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d634d8fa5 | ||
|
|
a538282e4f | ||
|
|
df31921954 | ||
|
|
299abb7b04 | ||
|
|
85ce63b4d3 | ||
|
|
9d9ae288bd | ||
|
|
c0a35eb5e7 | ||
|
|
d1d0ff062b | ||
|
|
396e8a22c4 | ||
|
|
4f040db1ef | ||
|
|
e3cf111344 | ||
|
|
28cb3c383d | ||
|
|
10665c0f37 | ||
|
|
4096f11909 | ||
|
|
5df209b6d5 | ||
|
|
db353fe025 | ||
|
|
30c3e143b8 | ||
|
|
15fe7b33a4 | ||
|
|
17e0666ba4 | ||
|
|
3f24bfaaf5 | ||
|
|
ceadcfece9 | ||
|
|
bc706237ef | ||
|
|
b8c5a78f2e | ||
|
|
9ad1c3a3c8 | ||
|
|
8cfde43f5d | ||
|
|
6f843b071a | ||
|
|
a72b025900 | ||
|
|
dfb66775d7 | ||
|
|
c4dca2df30 | ||
|
|
a3525bdaa4 | ||
|
|
09390bbb97 | ||
|
|
70ae4391d7 | ||
|
|
85931f55e1 | ||
|
|
0f5f03c9ed | ||
|
|
1181a0585b | ||
|
|
fd6fa929de | ||
|
|
73888c98a8 | ||
|
|
04ebfbea8a | ||
|
|
2b650fd285 | ||
|
|
515a590681 | ||
|
|
62f9d1f09a | ||
|
|
25c11f3d69 | ||
|
|
fde18c3b82 |
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/apps"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
@@ -19,7 +20,7 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
app := apps.NewAppCmd().
|
||||
var app = apps.NewAppCmd().
|
||||
Version(teaconst.Version).
|
||||
Product(teaconst.ProductName).
|
||||
Usage(teaconst.ProcessName + " [-v|start|stop|restart|status|quit|test|reload|service|daemon|pprof|accesslog]").
|
||||
@@ -67,24 +68,30 @@ func main() {
|
||||
fmt.Println("done")
|
||||
})
|
||||
app.On("pprof", func() {
|
||||
// TODO 自己指定端口
|
||||
addr := "127.0.0.1:6060"
|
||||
var flagSet = flag.NewFlagSet("pprof", flag.ExitOnError)
|
||||
var addr string
|
||||
flagSet.StringVar(&addr, "addr", "", "")
|
||||
_ = flagSet.Parse(os.Args[2:])
|
||||
|
||||
if len(addr) == 0 {
|
||||
addr = "127.0.0.1:6060"
|
||||
}
|
||||
logs.Println("starting with pprof '" + addr + "'...")
|
||||
|
||||
go func() {
|
||||
err := http.ListenAndServe(addr, nil)
|
||||
if err != nil {
|
||||
logs.Println("[error]" + err.Error())
|
||||
logs.Println("[ERROR]" + err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
node := nodes.NewNode()
|
||||
var node = nodes.NewNode()
|
||||
node.Start()
|
||||
})
|
||||
app.On("dbstat", func() {
|
||||
teaconst.EnableDBStat = true
|
||||
|
||||
node := nodes.NewNode()
|
||||
var node = nodes.NewNode()
|
||||
node.Start()
|
||||
})
|
||||
app.On("trackers", func() {
|
||||
@@ -154,7 +161,7 @@ func main() {
|
||||
app.On("ip.drop", func() {
|
||||
var args = os.Args[2:]
|
||||
if len(args) == 0 {
|
||||
fmt.Println("Usage: edge-node ip.drop IP [--timeout=SECONDS]")
|
||||
fmt.Println("Usage: edge-node ip.drop IP [--timeout=SECONDS] [--async]")
|
||||
return
|
||||
}
|
||||
var ip = args[0]
|
||||
@@ -168,6 +175,11 @@ func main() {
|
||||
if ok {
|
||||
timeoutSeconds = types.Int(timeout[0])
|
||||
}
|
||||
var async = false
|
||||
_, ok = options["async"]
|
||||
if ok {
|
||||
async = true
|
||||
}
|
||||
|
||||
fmt.Println("drop ip '" + ip + "' for '" + types.String(timeoutSeconds) + "' seconds")
|
||||
var sock = gosock.NewTmpSock(teaconst.ProcessName)
|
||||
@@ -176,6 +188,7 @@ func main() {
|
||||
Params: map[string]interface{}{
|
||||
"ip": ip,
|
||||
"timeoutSeconds": timeoutSeconds,
|
||||
"async": async,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
@@ -306,7 +319,7 @@ func main() {
|
||||
}
|
||||
})
|
||||
app.Run(func() {
|
||||
node := nodes.NewNode()
|
||||
var node = nodes.NewNode()
|
||||
node.Start()
|
||||
})
|
||||
}
|
||||
|
||||
8
go.mod
8
go.mod
@@ -4,6 +4,8 @@ go 1.18
|
||||
|
||||
replace (
|
||||
github.com/TeaOSLab/EdgeCommon => ../EdgeCommon
|
||||
github.com/fsnotify/fsnotify => /Users/WorkSpace/Projects/fsnotify
|
||||
rogchap.com/v8go => /Users/Workspace/Projects/v8go
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -16,7 +18,7 @@ require (
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/golang/protobuf v1.5.2
|
||||
github.com/google/nftables v0.0.0-20220407195405-950e408d48c6
|
||||
github.com/iwind/TeaGo v0.0.0-20220304043459-0dd944a5b475
|
||||
github.com/iwind/TeaGo v0.0.0-20220807030847-31de8e1cbe55
|
||||
github.com/iwind/gofcgi v0.0.0-20210528023741-a92711d45f11
|
||||
github.com/iwind/gosock v0.0.0-20211103081026-ee4652210ca4
|
||||
github.com/iwind/gowebp v0.0.0-20211029040624-7331ecc78ed8
|
||||
@@ -32,7 +34,8 @@ require (
|
||||
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad
|
||||
golang.org/x/text v0.3.7
|
||||
google.golang.org/grpc v1.45.0
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
rogchap.com/v8go v0.7.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -41,6 +44,7 @@ require (
|
||||
github.com/chai2010/webp v1.1.0 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/go-sql-driver/mysql v1.5.0 // indirect
|
||||
github.com/google/go-cmp v0.5.7 // indirect
|
||||
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect
|
||||
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
|
||||
|
||||
23
go.sum
23
go.sum
@@ -47,6 +47,10 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.m
|
||||
github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0=
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||
github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
github.com/fsnotify/fsnotify v1.5.1 h1:mZcQUHVQUQWoPXXtuf9yuEXKudkV2sx1E06UadKWpgI=
|
||||
github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5Ai1i3InKU=
|
||||
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
@@ -89,8 +93,10 @@ github.com/google/nftables v0.0.0-20220407195405-950e408d48c6/go.mod h1:0F8on3JW
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/iwind/TeaGo v0.0.0-20220304043459-0dd944a5b475 h1:EseyfFaQOjWanGiby9KMw7PjDBMg/95tLDgIw/ns0Cw=
|
||||
github.com/iwind/TeaGo v0.0.0-20220304043459-0dd944a5b475/go.mod h1:HRHK0zoC/og3c9/hKosD9yYVMTnnzm3PgXUdhRYHaLc=
|
||||
github.com/iwind/TeaGo v0.0.0-20220807023459-448081424640 h1:nBVQzDI4mrQS+Egg+Li6BGiTToBsv+XTck+BItgI52k=
|
||||
github.com/iwind/TeaGo v0.0.0-20220807023459-448081424640/go.mod h1:+K2l6Num4Evl0jH7TYlZJ1oFJX8sA8YUC31Pb+I1mJk=
|
||||
github.com/iwind/TeaGo v0.0.0-20220807030847-31de8e1cbe55 h1:shQNx0flJFBwKsGE7Hs3bI2bDz+YF0zl/4qE8B2KRiY=
|
||||
github.com/iwind/TeaGo v0.0.0-20220807030847-31de8e1cbe55/go.mod h1:fi/Pq+/5m2HZoseM+39dMF57ANXRt6w4PkGu3NXPc5s=
|
||||
github.com/iwind/gofcgi v0.0.0-20210528023741-a92711d45f11 h1:DaQjoWZhLNxjhIXedVg4/vFEtHkZhK4IjIwsWdyzBLg=
|
||||
github.com/iwind/gofcgi v0.0.0-20210528023741-a92711d45f11/go.mod h1:JtbX20untAjUVjZs1ZBtq80f5rJWvwtQNRL6EnuYRnY=
|
||||
github.com/iwind/gosock v0.0.0-20211103081026-ee4652210ca4 h1:VWGsCqTzObdlbf7UUE3oceIpcEKi4C/YBUszQXk118A=
|
||||
@@ -109,11 +115,8 @@ github.com/jsimonetti/rtnetlink v0.0.0-20210212075122-66c871082f2b/go.mod h1:8w9
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20210525051524-4cc836578190/go.mod h1:NmKSdU4VGSiv1bMsdqNALI4RSvvjtz65tTMCnD05qLo=
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20211022192332-93da33804786 h1:N527AHMa793TP5z5GNAn/VLPzlc0ewzWdeP/25gDfgQ=
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20211022192332-93da33804786/go.mod h1:v4hqbTdfQngbVSZJVWUhGE/lbTFf9jb+ygmNUDQMuOs=
|
||||
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e h1:LvL4XsI70QxOGHed6yhQtAU34Kx3Qq2wwBzGFKY8zKk=
|
||||
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
|
||||
github.com/klauspost/compress v1.15.6 h1:6D9PcO8QWu0JyaQ2zUMmu16T1T+zjjEpP91guRsvDfY=
|
||||
github.com/klauspost/compress v1.15.6/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
|
||||
github.com/klauspost/compress v1.15.8 h1:JahtItbkWjf2jzm/T+qgMxkP9EMHsqEUA6vCMGmXvhA=
|
||||
github.com/klauspost/compress v1.15.8/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
@@ -150,8 +153,6 @@ github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb h1:2dC7L10LmTqlyMV
|
||||
github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb/go.mod h1:nFZ1EtZYK8Gi/k6QNu7z7CgO20i/4ExeQswwWuPmG/g=
|
||||
github.com/miekg/dns v1.1.43 h1:JKfpVSCB84vrAmHzyrsxB5NAr5kLoMXZArPSw7Qlgyg=
|
||||
github.com/miekg/dns v1.1.43/go.mod h1:+evo5L0630/F6ca/Z9+GAqzhjGyn8/c+TBaOyfEl0V4=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/mssola/user_agent v0.5.3 h1:lBRPML9mdFuIZgI2cmlQ+atbpJdLdeVl2IDodjBR578=
|
||||
github.com/mssola/user_agent v0.5.3/go.mod h1:TTPno8LPY3wAIEKRpAtkdMT0f8SE24pLRGPahjCH4uw=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||
@@ -180,7 +181,6 @@ github.com/shirou/gopsutil/v3 v3.22.2/go.mod h1:WapW1AOOPlHyXr+yOyw3uYx36enocrtS
|
||||
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ=
|
||||
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
|
||||
@@ -268,6 +268,7 @@ golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -291,6 +292,7 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210816074244-15123e1e1f71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210906170528-6f6e22806c34/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -373,10 +375,13 @@ gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.2.1/go.mod h1:lPVVZ2BS5TfnjLyizF7o7hv7j9/L+8cZY2hLyjP9cGY=
|
||||
honnef.co/go/tools v0.2.2 h1:MNh1AVMyVX23VUHE2O27jm6lNj3vjO5DexS4A1xvnzk=
|
||||
honnef.co/go/tools v0.2.2/go.mod h1:lPVVZ2BS5TfnjLyizF7o7hv7j9/L+8cZY2hLyjP9cGY=
|
||||
rogchap.com/v8go v0.7.0 h1:kgjbiO4zE5itA962ze6Hqmbs4HgZbGzmueCXsZtremg=
|
||||
rogchap.com/v8go v0.7.0/go.mod h1:MxgP3pL2MW4dpme/72QRs8sgNMmM0pRc8DPhcuLWPAs=
|
||||
|
||||
@@ -89,7 +89,10 @@ func (this *LogWriter) Write(message string) {
|
||||
}
|
||||
}
|
||||
|
||||
this.c <- message
|
||||
select {
|
||||
case this.c <- message:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (this *LogWriter) Close() {
|
||||
|
||||
@@ -227,7 +227,7 @@ func (this *FileList) PurgeLFU(count int, callback func(hash string) error) erro
|
||||
if notFound {
|
||||
_, err = db.deleteHitByHashStmt.Exec(hash)
|
||||
if err != nil {
|
||||
return err
|
||||
return db.WrapError(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -359,14 +359,14 @@ func (this *FileList) remove(hash string) (notFound bool, err error) {
|
||||
|
||||
_, err = db.deleteByHashStmt.Exec(hash)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, db.WrapError(err)
|
||||
}
|
||||
|
||||
atomic.AddInt64(&this.total, -1)
|
||||
|
||||
_, err = db.deleteHitByHashStmt.Exec(hash)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, db.WrapError(err)
|
||||
}
|
||||
|
||||
if this.onRemove != nil {
|
||||
|
||||
@@ -16,6 +16,8 @@ import (
|
||||
)
|
||||
|
||||
type FileListDB struct {
|
||||
dbPath string
|
||||
|
||||
readDB *dbs.DB
|
||||
writeDB *dbs.DB
|
||||
|
||||
@@ -49,6 +51,8 @@ func NewFileListDB() *FileListDB {
|
||||
}
|
||||
|
||||
func (this *FileListDB) Open(dbPath string) error {
|
||||
this.dbPath = dbPath
|
||||
|
||||
// write db
|
||||
writeDB, err := sql.Open("sqlite3", "file:"+dbPath+"?cache=private&mode=rwc&_journal_mode=WAL&_sync=OFF&_cache_size=32000&_secure_delete=FAST")
|
||||
if err != nil {
|
||||
@@ -185,7 +189,7 @@ func (this *FileListDB) Add(hash string, item *Item) error {
|
||||
// 放入队列
|
||||
_, err := this.insertStmt.Exec(hash, item.Key, item.HeaderSize, item.BodySize, item.MetaSize, item.ExpiredAt, item.StaleAt, item.Host, item.ServerId, utils.UnixTime())
|
||||
if err != nil {
|
||||
return err
|
||||
return this.WrapError(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -249,7 +253,7 @@ func (this *FileListDB) ListLFUItems(count int) (hashList []string, err error) {
|
||||
func (this *FileListDB) IncreaseHit(hash string) error {
|
||||
var week = timeutil.Format("YW")
|
||||
_, err := this.increaseHitStmt.Exec(hash, week, week, week, week)
|
||||
return err
|
||||
return this.WrapError(err)
|
||||
}
|
||||
|
||||
func (this *FileListDB) CleanPrefix(prefix string) error {
|
||||
@@ -262,7 +266,7 @@ func (this *FileListDB) CleanPrefix(prefix string) error {
|
||||
for {
|
||||
result, err := this.writeDB.Exec(`UPDATE "`+this.itemsTableName+`" SET expiredAt=0,staleAt=? WHERE id IN (SELECT id FROM "`+this.itemsTableName+`" WHERE expiredAt>0 AND createdAt<=? AND INSTR("key", ?)=1 LIMIT `+types.String(count)+`)`, unixTime+int64(staleLife), unixTime, prefix)
|
||||
if err != nil {
|
||||
return err
|
||||
return this.WrapError(err)
|
||||
}
|
||||
affectedRows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
@@ -281,7 +285,7 @@ func (this *FileListDB) CleanAll() error {
|
||||
|
||||
_, err := this.deleteAllStmt.Exec()
|
||||
if err != nil {
|
||||
return err
|
||||
return this.WrapError(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -351,6 +355,13 @@ func (this *FileListDB) Close() error {
|
||||
return errors.New("close database failed: " + strings.Join(errStrings, ", "))
|
||||
}
|
||||
|
||||
func (this *FileListDB) WrapError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return errors.New(err.Error() + "(file: " + this.dbPath + ")")
|
||||
}
|
||||
|
||||
// 初始化
|
||||
func (this *FileListDB) initTables(times int) error {
|
||||
{
|
||||
@@ -393,10 +404,10 @@ ON "` + this.itemsTableName + `" (
|
||||
if dropErr == nil {
|
||||
return this.initTables(times + 1)
|
||||
}
|
||||
return err
|
||||
return this.WrapError(err)
|
||||
}
|
||||
|
||||
return err
|
||||
return this.WrapError(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -421,10 +432,10 @@ ON "` + this.hitsTableName + `" (
|
||||
if dropErr == nil {
|
||||
return this.initTables(times + 1)
|
||||
}
|
||||
return err
|
||||
return this.WrapError(err)
|
||||
}
|
||||
|
||||
return err
|
||||
return this.WrapError(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ package caches
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
)
|
||||
|
||||
// PartialRanges 内容分区范围定义
|
||||
@@ -30,7 +30,7 @@ func NewPartialRangesFromJSON(data []byte) (*PartialRanges, error) {
|
||||
}
|
||||
|
||||
func NewPartialRangesFromFile(path string) (*PartialRanges, error) {
|
||||
data, err := ioutil.ReadFile(path)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -116,12 +116,12 @@ func (this *PartialRanges) WriteToFile(path string) error {
|
||||
if err != nil {
|
||||
return errors.New("convert to json failed: " + err.Error())
|
||||
}
|
||||
return ioutil.WriteFile(path, data, 0666)
|
||||
return os.WriteFile(path, data, 0666)
|
||||
}
|
||||
|
||||
// ReadFromFile 从文件中读取
|
||||
func (this *PartialRanges) ReadFromFile(path string) (*PartialRanges, error) {
|
||||
data, err := ioutil.ReadFile(path)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -65,11 +65,13 @@ var sharedWritingFileKeyLocker = sync.Mutex{}
|
||||
|
||||
var maxOpenFiles = NewMaxOpenFiles()
|
||||
|
||||
const maxOpenFilesSlowCost = 500 * time.Microsecond // 0.5ms
|
||||
const maxOpenFilesSlowCost = 1000 * time.Microsecond // us
|
||||
const protectingLoadWhenDump = false
|
||||
|
||||
// FileStorage 文件缓存
|
||||
// 文件结构:
|
||||
// [expires time] | [ status ] | [url length] | [header length] | [body length] | [url] [header data] [body data]
|
||||
//
|
||||
// 文件结构:
|
||||
// [expires time] | [ status ] | [url length] | [header length] | [body length] | [url] [header data] [body data]
|
||||
type FileStorage struct {
|
||||
policy *serverconfigs.HTTPCachePolicy
|
||||
options *serverconfigs.HTTPFileCacheStorage // 二级缓存
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@@ -152,7 +152,7 @@ func TestFileStorage_OpenWriter_HTTP(t *testing.T) {
|
||||
"Last-Modified": []string{"Wed, 06 Jan 2021 10:03:29 GMT"},
|
||||
"Server": []string{"CDN-Server"},
|
||||
},
|
||||
Body: ioutil.NopCloser(bytes.NewBuffer([]byte("THIS IS HTTP BODY"))),
|
||||
Body: io.NopCloser(bytes.NewBuffer([]byte("THIS IS HTTP BODY"))),
|
||||
}
|
||||
|
||||
for k, v := range resp.Header {
|
||||
|
||||
@@ -95,7 +95,8 @@ func (this *MemoryStorage) Init() error {
|
||||
// 启动定时Flush memory to disk任务
|
||||
if this.parentStorage != nil {
|
||||
// TODO 应该根据磁盘性能决定线程数
|
||||
var threads = 1
|
||||
// TODO 线程数应该可以在缓存策略和节点中设定
|
||||
var threads = runtime.NumCPU()
|
||||
|
||||
for i := 0; i < threads; i++ {
|
||||
goman.New(func() {
|
||||
@@ -438,16 +439,18 @@ func (this *MemoryStorage) startFlush() {
|
||||
if statCount == 100 {
|
||||
statCount = 0
|
||||
|
||||
loadStat, err := load.Avg()
|
||||
if err == nil && loadStat != nil {
|
||||
if loadStat.Load1 > 10 {
|
||||
writeDelayMS = 100
|
||||
} else if loadStat.Load1 > 3 {
|
||||
writeDelayMS = 50
|
||||
} else if loadStat.Load1 > 2 {
|
||||
writeDelayMS = 10
|
||||
} else {
|
||||
writeDelayMS = 0
|
||||
if protectingLoadWhenDump {
|
||||
loadStat, err := load.Avg()
|
||||
if err == nil && loadStat != nil {
|
||||
if loadStat.Load1 > 10 {
|
||||
writeDelayMS = 100
|
||||
} else if loadStat.Load1 > 3 {
|
||||
writeDelayMS = 50
|
||||
} else if loadStat.Load1 > 2 {
|
||||
writeDelayMS = 10
|
||||
} else {
|
||||
writeDelayMS = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -304,3 +305,30 @@ func TestMemoryStorage_Stop(t *testing.T) {
|
||||
|
||||
t.Log(len(m))
|
||||
}
|
||||
|
||||
func BenchmarkValuesMap(b *testing.B) {
|
||||
var m = map[uint64]*MemoryItem{}
|
||||
var count = 1_000_000
|
||||
for i := 0; i < count; i++ {
|
||||
m[uint64(i)] = &MemoryItem{
|
||||
ExpiresAt: time.Now().Unix(),
|
||||
}
|
||||
}
|
||||
b.Log(len(m))
|
||||
|
||||
var locker = sync.Mutex{}
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
locker.Lock()
|
||||
_, ok := m[uint64(rands.Int(0, 1_000_000))]
|
||||
_ = ok
|
||||
locker.Unlock()
|
||||
|
||||
locker.Lock()
|
||||
delete(m, uint64(rands.Int(2, 1000000)))
|
||||
locker.Unlock()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ package caches_test
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -16,7 +15,7 @@ func TestPartialFileWriter_Write(t *testing.T) {
|
||||
_ = os.Remove(path)
|
||||
|
||||
var reader = func() {
|
||||
data, err := ioutil.ReadFile(path)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package compressions
|
||||
|
||||
import (
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"io"
|
||||
)
|
||||
@@ -10,6 +11,10 @@ import (
|
||||
var sharedBrotliReaderPool *ReaderPool
|
||||
|
||||
func init() {
|
||||
if teaconst.IsDaemon {
|
||||
return
|
||||
}
|
||||
|
||||
var maxSize = utils.SystemMemoryGB() * 256
|
||||
if maxSize == 0 {
|
||||
maxSize = 256
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package compressions
|
||||
|
||||
import (
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"io"
|
||||
)
|
||||
@@ -10,6 +11,10 @@ import (
|
||||
var sharedDeflateReaderPool *ReaderPool
|
||||
|
||||
func init() {
|
||||
if teaconst.IsDaemon {
|
||||
return
|
||||
}
|
||||
|
||||
var maxSize = utils.SystemMemoryGB() * 256
|
||||
if maxSize == 0 {
|
||||
maxSize = 256
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package compressions
|
||||
|
||||
import (
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"io"
|
||||
)
|
||||
@@ -10,6 +11,10 @@ import (
|
||||
var sharedGzipReaderPool *ReaderPool
|
||||
|
||||
func init() {
|
||||
if teaconst.IsDaemon {
|
||||
return
|
||||
}
|
||||
|
||||
var maxSize = utils.SystemMemoryGB() * 256
|
||||
if maxSize == 0 {
|
||||
maxSize = 256
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package compressions
|
||||
|
||||
import (
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"io"
|
||||
)
|
||||
@@ -10,6 +11,10 @@ import (
|
||||
var sharedZSTDReaderPool *ReaderPool
|
||||
|
||||
func init() {
|
||||
if teaconst.IsDaemon {
|
||||
return
|
||||
}
|
||||
|
||||
var maxSize = utils.SystemMemoryGB() * 256
|
||||
if maxSize == 0 {
|
||||
maxSize = 256
|
||||
|
||||
@@ -5,10 +5,46 @@ package compressions_test
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/compressions"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBrotliWriter_LargeFile(t *testing.T) {
|
||||
var data = []byte{}
|
||||
for i := 0; i < 1024*1024; i++ {
|
||||
data = append(data, stringutil.Rand(32)...)
|
||||
}
|
||||
t.Log(len(data)/1024/1024, "M")
|
||||
|
||||
var before = time.Now()
|
||||
defer func() {
|
||||
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||
}()
|
||||
|
||||
var buf = &bytes.Buffer{}
|
||||
writer, err := compressions.NewBrotliWriter(buf, 5)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var offset = 0
|
||||
var size = 4096
|
||||
for offset < len(data) {
|
||||
_, err = writer.Write(data[offset : offset+size])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
offset += size
|
||||
}
|
||||
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBrotliWriter_Write(b *testing.B) {
|
||||
var data = []byte(strings.Repeat("A", 1024))
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package compressions
|
||||
|
||||
import (
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/andybalholm/brotli"
|
||||
"io"
|
||||
@@ -11,6 +12,10 @@ import (
|
||||
var sharedBrotliWriterPool *WriterPool
|
||||
|
||||
func init() {
|
||||
if teaconst.IsDaemon {
|
||||
return
|
||||
}
|
||||
|
||||
var maxSize = utils.SystemMemoryGB() * 256
|
||||
if maxSize == 0 {
|
||||
maxSize = 256
|
||||
|
||||
@@ -4,6 +4,7 @@ package compressions
|
||||
|
||||
import (
|
||||
"compress/flate"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"io"
|
||||
)
|
||||
@@ -11,6 +12,10 @@ import (
|
||||
var sharedDeflateWriterPool *WriterPool
|
||||
|
||||
func init() {
|
||||
if teaconst.IsDaemon {
|
||||
return
|
||||
}
|
||||
|
||||
var maxSize = utils.SystemMemoryGB() * 256
|
||||
if maxSize == 0 {
|
||||
maxSize = 256
|
||||
|
||||
@@ -4,6 +4,7 @@ package compressions
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"io"
|
||||
)
|
||||
@@ -11,6 +12,10 @@ import (
|
||||
var sharedGzipWriterPool *WriterPool
|
||||
|
||||
func init() {
|
||||
if teaconst.IsDaemon {
|
||||
return
|
||||
}
|
||||
|
||||
var maxSize = utils.SystemMemoryGB() * 256
|
||||
if maxSize == 0 {
|
||||
maxSize = 256
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package compressions
|
||||
|
||||
import (
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"io"
|
||||
@@ -11,6 +12,10 @@ import (
|
||||
var sharedZSTDWriterPool *WriterPool
|
||||
|
||||
func init() {
|
||||
if teaconst.IsDaemon {
|
||||
return
|
||||
}
|
||||
|
||||
var maxSize = utils.SystemMemoryGB() * 256
|
||||
if maxSize == 0 {
|
||||
maxSize = 256
|
||||
|
||||
@@ -3,20 +3,21 @@ package configs
|
||||
import (
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"gopkg.in/yaml.v3"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
)
|
||||
|
||||
// APIConfig 节点API配置
|
||||
type APIConfig struct {
|
||||
RPC struct {
|
||||
Endpoints []string `yaml:"endpoints"`
|
||||
Endpoints []string `yaml:"endpoints"`
|
||||
DisableUpdate bool `yaml:"disableUpdate"`
|
||||
} `yaml:"rpc"`
|
||||
NodeId string `yaml:"nodeId"`
|
||||
Secret string `yaml:"secret"`
|
||||
}
|
||||
|
||||
func LoadAPIConfig() (*APIConfig, error) {
|
||||
data, err := ioutil.ReadFile(Tea.ConfigFile("api.yaml"))
|
||||
data, err := os.ReadFile(Tea.ConfigFile("api.yaml"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -30,12 +31,12 @@ func LoadAPIConfig() (*APIConfig, error) {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// 保存到文件
|
||||
// WriteFile 保存到文件
|
||||
func (this *APIConfig) WriteFile(path string) error {
|
||||
data, err := yaml.Marshal(this)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = ioutil.WriteFile(path, data, 0666)
|
||||
err = os.WriteFile(path, data, 0666)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package configs
|
||||
package configs_test
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadAPIConfig(t *testing.T) {
|
||||
config, err := LoadAPIConfig()
|
||||
config, err := configs.LoadAPIConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(config)
|
||||
t.Logf("%+v", config)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package configs
|
||||
|
||||
// 集群配置
|
||||
// ClusterConfig 集群配置
|
||||
type ClusterConfig struct {
|
||||
RPC struct {
|
||||
Endpoints []string `yaml:"endpoints"`
|
||||
Endpoints []string `yaml:"endpoints"`
|
||||
DisableUpdate bool `yaml:"disableUpdate"`
|
||||
} `yaml:"rpc"`
|
||||
ClusterId string `yaml:"clusterId"`
|
||||
Secret string `yaml:"secret"`
|
||||
Secret string `yaml:"secret"`
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package teaconst
|
||||
|
||||
const (
|
||||
Version = "0.4.9"
|
||||
Version = "0.5.1"
|
||||
|
||||
ProductName = "Edge Node"
|
||||
ProcessName = "edge-node"
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
|
||||
package teaconst
|
||||
|
||||
import "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"os"
|
||||
)
|
||||
|
||||
var (
|
||||
// 流量统计
|
||||
@@ -12,6 +15,7 @@ var (
|
||||
|
||||
NodeId int64 = 0
|
||||
NodeIdString = ""
|
||||
IsDaemon = len(os.Args) > 1 && os.Args[1] == "daemon"
|
||||
|
||||
GlobalProductName = nodeconfigs.DefaultProductName
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
@@ -75,6 +76,24 @@ func (this *Firewalld) AllowPort(port int, protocol string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Firewalld) AllowPortRangesPermanently(portRanges [][2]int, protocol string) error {
|
||||
for _, portRange := range portRanges {
|
||||
var port = this.PortRangeString(portRange, protocol)
|
||||
|
||||
{
|
||||
var cmd = exec.Command(this.exe, "--add-port="+port, "--permanent")
|
||||
this.pushCmd(cmd)
|
||||
}
|
||||
|
||||
{
|
||||
var cmd = exec.Command(this.exe, "--add-port="+port)
|
||||
this.pushCmd(cmd)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Firewalld) RemovePort(port int, protocol string) error {
|
||||
if !this.isReady {
|
||||
return nil
|
||||
@@ -84,6 +103,30 @@ func (this *Firewalld) RemovePort(port int, protocol string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Firewalld) RemovePortRangePermanently(portRange [2]int, protocol string) error {
|
||||
var port = this.PortRangeString(portRange, protocol)
|
||||
|
||||
{
|
||||
var cmd = exec.Command(this.exe, "--remove-port="+port, "--permanent")
|
||||
this.pushCmd(cmd)
|
||||
}
|
||||
|
||||
{
|
||||
var cmd = exec.Command(this.exe, "--remove-port="+port)
|
||||
this.pushCmd(cmd)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Firewalld) PortRangeString(portRange [2]int, protocol string) string {
|
||||
if portRange[0] == portRange[1] {
|
||||
return types.String(portRange[0]) + "/" + protocol
|
||||
} else {
|
||||
return types.String(portRange[0]) + "-" + types.String(portRange[1]) + "/" + protocol
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Firewalld) RejectSourceIP(ip string, timeoutSeconds int) error {
|
||||
if !this.isReady {
|
||||
return nil
|
||||
@@ -101,7 +144,7 @@ func (this *Firewalld) RejectSourceIP(ip string, timeoutSeconds int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Firewalld) DropSourceIP(ip string, timeoutSeconds int) error {
|
||||
func (this *Firewalld) DropSourceIP(ip string, timeoutSeconds int, async bool) error {
|
||||
if !this.isReady {
|
||||
return nil
|
||||
}
|
||||
@@ -114,7 +157,15 @@ func (this *Firewalld) DropSourceIP(ip string, timeoutSeconds int) error {
|
||||
args = append(args, "--timeout="+types.String(timeoutSeconds)+"s")
|
||||
}
|
||||
var cmd = exec.Command(this.exe, args...)
|
||||
this.pushCmd(cmd)
|
||||
if async {
|
||||
this.pushCmd(cmd)
|
||||
return nil
|
||||
}
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return errors.New("run command failed '" + cmd.String() + "': " + err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -23,7 +23,10 @@ type FirewallInterface interface {
|
||||
RejectSourceIP(ip string, timeoutSeconds int) error
|
||||
|
||||
// DropSourceIP 丢弃某个源IP数据
|
||||
DropSourceIP(ip string, timeoutSeconds int) error
|
||||
// ip 要封禁的IP
|
||||
// timeoutSeconds 过期时间
|
||||
// async 是否异步
|
||||
DropSourceIP(ip string, timeoutSeconds int, async bool) error
|
||||
|
||||
// RemoveSourceIP 删除某个源IP
|
||||
RemoveSourceIP(ip string) error
|
||||
|
||||
@@ -47,7 +47,7 @@ func (this *MockFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
|
||||
}
|
||||
|
||||
// DropSourceIP 丢弃某个源IP数据
|
||||
func (this *MockFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
|
||||
func (this *MockFirewall) DropSourceIP(ip string, timeoutSeconds int, async bool) error {
|
||||
_ = ip
|
||||
_ = timeoutSeconds
|
||||
return nil
|
||||
|
||||
@@ -7,8 +7,10 @@ package firewalls
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
@@ -21,9 +23,13 @@ import (
|
||||
|
||||
// check nft status, if being enabled we load it automatically
|
||||
func init() {
|
||||
if teaconst.IsDaemon {
|
||||
return
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
var ticker = time.NewTicker(3 * time.Minute)
|
||||
go func() {
|
||||
goman.New(func() {
|
||||
for range ticker.C {
|
||||
// if already ready, we break
|
||||
if nftablesIsReady {
|
||||
@@ -48,7 +54,7 @@ func init() {
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,9 +80,16 @@ func (this *nftablesTableDefinition) protocol() string {
|
||||
return "ip"
|
||||
}
|
||||
|
||||
type blockIPItem struct {
|
||||
action string
|
||||
ip string
|
||||
timeoutSeconds int
|
||||
}
|
||||
|
||||
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
|
||||
var firewall = &NFTablesFirewall{
|
||||
conn: nftables.NewConn(),
|
||||
conn: nftables.NewConn(),
|
||||
dropIPQueue: make(chan *blockIPItem, 4096),
|
||||
}
|
||||
err := firewall.init()
|
||||
if err != nil {
|
||||
@@ -98,6 +111,8 @@ type NFTablesFirewall struct {
|
||||
denyIPv6Set *nftables.Set
|
||||
|
||||
firewalld *Firewalld
|
||||
|
||||
dropIPQueue chan *blockIPItem
|
||||
}
|
||||
|
||||
func (this *NFTablesFirewall) init() error {
|
||||
@@ -243,6 +258,18 @@ func (this *NFTablesFirewall) init() error {
|
||||
nftablesIsReady = true
|
||||
nftablesInstance = this
|
||||
|
||||
goman.New(func() {
|
||||
for ipItem := range this.dropIPQueue {
|
||||
switch ipItem.action {
|
||||
case "drop":
|
||||
err = this.DropSourceIP(ipItem.ip, ipItem.timeoutSeconds, false)
|
||||
if err != nil {
|
||||
remotelogs.Warn("NFTABLES", "drop ip '"+ipItem.ip+"' failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// load firewalld
|
||||
var firewalld = NewFirewalld()
|
||||
if firewalld.IsReady() {
|
||||
@@ -307,16 +334,29 @@ func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
|
||||
// RejectSourceIP 拒绝某个源IP连接
|
||||
// we did not create set for drop ip, so we reuse DropSourceIP() method here
|
||||
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
|
||||
return this.DropSourceIP(ip, timeoutSeconds)
|
||||
return this.DropSourceIP(ip, timeoutSeconds, true)
|
||||
}
|
||||
|
||||
// DropSourceIP 丢弃某个源IP数据
|
||||
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
|
||||
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async bool) error {
|
||||
var data = net.ParseIP(ip)
|
||||
if data == nil {
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
if async {
|
||||
select {
|
||||
case this.dropIPQueue <- &blockIPItem{
|
||||
action: "drop",
|
||||
ip: ip,
|
||||
timeoutSeconds: timeoutSeconds,
|
||||
}:
|
||||
default:
|
||||
return errors.New("drop ip queue is full")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if strings.Contains(ip, ":") { // ipv6
|
||||
if this.denyIPv6Set == nil {
|
||||
return errors.New("ipv6 ip set is nil")
|
||||
|
||||
@@ -51,7 +51,7 @@ func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) erro
|
||||
}
|
||||
|
||||
// DropSourceIP 丢弃某个源IP数据
|
||||
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
|
||||
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package goman
|
||||
|
||||
import (
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -14,6 +15,10 @@ var instanceId = uint64(0)
|
||||
|
||||
// New 新创建goroutine
|
||||
func New(f func()) {
|
||||
if teaconst.IsDaemon {
|
||||
return
|
||||
}
|
||||
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
|
||||
go func() {
|
||||
@@ -42,6 +47,10 @@ func New(f func()) {
|
||||
|
||||
// NewWithArgs 创建带有参数的goroutine
|
||||
func NewWithArgs(f func(args ...interface{}), args ...interface{}) {
|
||||
if teaconst.IsDaemon {
|
||||
return
|
||||
}
|
||||
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
|
||||
go func() {
|
||||
|
||||
@@ -4,7 +4,7 @@ package iplibrary
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
@@ -62,7 +62,7 @@ func getIpInfo(cityId int64, line []byte) *IpInfo {
|
||||
|
||||
func NewIP2Region(path string) (*IP2Region, error) {
|
||||
var region = &IP2Region{}
|
||||
region.dbData, err = ioutil.ReadFile(path)
|
||||
region.dbData, err = os.ReadFile(path)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -79,7 +79,7 @@ func TestIPItem_Memory(t *testing.T) {
|
||||
for i := 0; i < 2_000_000; i ++ {
|
||||
list.Add(&IPItem{
|
||||
Type: "ip",
|
||||
Id: int64(i),
|
||||
Id: uint64(i),
|
||||
IPFrom: utils.IP2Long("192.168.1.1"),
|
||||
IPTo: 0,
|
||||
ExpiredAt: time.Now().Unix(),
|
||||
|
||||
@@ -28,6 +28,8 @@ type IPListDB struct {
|
||||
cleanTicker *time.Ticker
|
||||
|
||||
dir string
|
||||
|
||||
isClosed bool
|
||||
}
|
||||
|
||||
func NewIPListDB() (*IPListDB, error) {
|
||||
@@ -56,6 +58,12 @@ func (this *IPListDB) init() error {
|
||||
return err
|
||||
}
|
||||
db.SetMaxOpenConns(1)
|
||||
|
||||
//_, err = db.Exec("VACUUM")
|
||||
//if err != nil {
|
||||
// return err
|
||||
//}
|
||||
|
||||
this.db = db
|
||||
|
||||
// 初始化数据库
|
||||
@@ -117,6 +125,7 @@ ON "` + this.itemTableName + `" (
|
||||
|
||||
goman.New(func() {
|
||||
events.On(events.EventQuit, func() {
|
||||
_ = this.Close()
|
||||
this.cleanTicker.Stop()
|
||||
})
|
||||
|
||||
@@ -133,11 +142,19 @@ ON "` + this.itemTableName + `" (
|
||||
|
||||
// DeleteExpiredItems 删除过期的条目
|
||||
func (this *IPListDB) DeleteExpiredItems() error {
|
||||
if this.isClosed {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := this.deleteExpiredItemsStmt.Exec(time.Now().Unix() - 7*86400)
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *IPListDB) AddItem(item *pb.IPItem) error {
|
||||
if this.isClosed {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := this.deleteItemStmt.Exec(item.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -147,6 +164,10 @@ func (this *IPListDB) AddItem(item *pb.IPItem) error {
|
||||
}
|
||||
|
||||
func (this *IPListDB) ReadItems(offset int64, size int64) (items []*pb.IPItem, err error) {
|
||||
if this.isClosed {
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := this.selectItemsStmt.Query(offset, size)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -169,6 +190,10 @@ func (this *IPListDB) ReadItems(offset int64, size int64) (items []*pb.IPItem, e
|
||||
|
||||
// ReadMaxVersion 读取当前最大版本号
|
||||
func (this *IPListDB) ReadMaxVersion() int64 {
|
||||
if this.isClosed {
|
||||
return 0
|
||||
}
|
||||
|
||||
row := this.selectMaxVersionStmt.QueryRow()
|
||||
if row == nil {
|
||||
return 0
|
||||
@@ -182,6 +207,8 @@ func (this *IPListDB) ReadMaxVersion() int64 {
|
||||
}
|
||||
|
||||
func (this *IPListDB) Close() error {
|
||||
this.isClosed = true
|
||||
|
||||
if this.db != nil {
|
||||
_ = this.deleteExpiredItemsStmt.Close()
|
||||
_ = this.deleteItemStmt.Close()
|
||||
|
||||
@@ -53,6 +53,11 @@ func TestIPListDB_ReadItems(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
|
||||
items, err := db.ReadItems(0, 2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -144,7 +144,7 @@ func TestIPList_Contains(t *testing.T) {
|
||||
list := NewIPList()
|
||||
for i := 0; i < 255; i++ {
|
||||
list.AddDelay(&IPItem{
|
||||
Id: int64(i),
|
||||
Id: uint64(i),
|
||||
IPFrom: utils.IP2Long(strconv.Itoa(i) + ".168.0.1"),
|
||||
IPTo: utils.IP2Long(strconv.Itoa(i) + ".168.255.1"),
|
||||
ExpiredAt: 0,
|
||||
@@ -152,7 +152,7 @@ func TestIPList_Contains(t *testing.T) {
|
||||
}
|
||||
for i := 0; i < 255; i++ {
|
||||
list.AddDelay(&IPItem{
|
||||
Id: int64(1000 + i),
|
||||
Id: uint64(1000 + i),
|
||||
IPFrom: utils.IP2Long("192.167.2." + strconv.Itoa(i)),
|
||||
})
|
||||
}
|
||||
@@ -172,7 +172,7 @@ func TestIPList_Contains_Many(t *testing.T) {
|
||||
list := NewIPList()
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
list.AddDelay(&IPItem{
|
||||
Id: int64(i),
|
||||
Id: uint64(i),
|
||||
IPFrom: utils.IP2Long(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255))),
|
||||
IPTo: utils.IP2Long(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255))),
|
||||
ExpiredAt: 0,
|
||||
@@ -217,7 +217,7 @@ func TestIPList_ContainsIPStrings(t *testing.T) {
|
||||
list := NewIPList()
|
||||
for i := 0; i < 255; i++ {
|
||||
list.Add(&IPItem{
|
||||
Id: int64(i),
|
||||
Id: uint64(i),
|
||||
IPFrom: utils.IP2Long(strconv.Itoa(i) + ".168.0.1"),
|
||||
IPTo: utils.IP2Long(strconv.Itoa(i) + ".168.255.1"),
|
||||
ExpiredAt: 0,
|
||||
@@ -305,7 +305,7 @@ func BenchmarkIPList_Contains(b *testing.B) {
|
||||
var list = NewIPList()
|
||||
for i := 1; i < 200_000; i++ {
|
||||
list.AddDelay(&IPItem{
|
||||
Id: int64(i),
|
||||
Id: uint64(i),
|
||||
IPFrom: utils.IP2Long(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1"),
|
||||
IPTo: utils.IP2Long(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1"),
|
||||
ExpiredAt: time.Now().Unix() + 60,
|
||||
|
||||
@@ -106,6 +106,7 @@ func BenchmarkIP2RegionLibrary_Lookup(b *testing.B) {
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = library.Lookup("8.8.8.8")
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -90,7 +89,7 @@ func (this *CityManager) Lookup(provinceId int64, cityName string) (cityId int64
|
||||
|
||||
// 从缓存中读取
|
||||
func (this *CityManager) load() error {
|
||||
data, err := ioutil.ReadFile(this.cacheFile)
|
||||
data, err := os.ReadFile(this.cacheFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
@@ -119,7 +118,7 @@ func (this *CityManager) loop() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := rpcClient.RegionCityRPC().FindAllEnabledRegionCities(rpcClient.Context(), &pb.FindAllEnabledRegionCitiesRequest{})
|
||||
resp, err := rpcClient.RegionCityRPC().FindAllRegionCities(rpcClient.Context(), &pb.FindAllRegionCitiesRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -151,6 +150,6 @@ func (this *CityManager) loop() error {
|
||||
|
||||
// 保存到本地缓存
|
||||
|
||||
err = ioutil.WriteFile(this.cacheFile, data, 0666)
|
||||
err = os.WriteFile(this.cacheFile, data, 0666)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -89,7 +88,7 @@ func (this *CountryManager) Lookup(countryName string) (countryId int64) {
|
||||
|
||||
// 从缓存中读取
|
||||
func (this *CountryManager) load() error {
|
||||
data, err := ioutil.ReadFile(this.cacheFile)
|
||||
data, err := os.ReadFile(this.cacheFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
@@ -118,7 +117,7 @@ func (this *CountryManager) loop() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := rpcClient.RegionCountryRPC().FindAllEnabledRegionCountries(rpcClient.Context(), &pb.FindAllEnabledRegionCountriesRequest{})
|
||||
resp, err := rpcClient.RegionCountryRPC().FindAllRegionCountries(rpcClient.Context(), &pb.FindAllRegionCountriesRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -149,6 +148,6 @@ func (this *CountryManager) loop() error {
|
||||
this.locker.Unlock()
|
||||
|
||||
// 保存到本地缓存
|
||||
err = ioutil.WriteFile(this.cacheFile, data, 0666)
|
||||
err = os.WriteFile(this.cacheFile, data, 0666)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -89,7 +88,7 @@ func (this *ProviderManager) Lookup(providerName string) (providerId int64) {
|
||||
|
||||
// 从缓存中读取
|
||||
func (this *ProviderManager) load() error {
|
||||
data, err := ioutil.ReadFile(this.cacheFile)
|
||||
data, err := os.ReadFile(this.cacheFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
@@ -118,7 +117,7 @@ func (this *ProviderManager) loop() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := rpcClient.RegionProviderRPC().FindAllEnabledRegionProviders(rpcClient.Context(), &pb.FindAllEnabledRegionProvidersRequest{})
|
||||
resp, err := rpcClient.RegionProviderRPC().FindAllRegionProviders(rpcClient.Context(), &pb.FindAllRegionProvidersRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -150,6 +149,6 @@ func (this *ProviderManager) loop() error {
|
||||
|
||||
// 保存到本地缓存
|
||||
|
||||
err = ioutil.WriteFile(this.cacheFile, data, 0666)
|
||||
err = os.WriteFile(this.cacheFile, data, 0666)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -93,7 +92,7 @@ func (this *ProvinceManager) Lookup(provinceName string) (provinceId int64) {
|
||||
|
||||
// 从缓存中读取
|
||||
func (this *ProvinceManager) load() error {
|
||||
data, err := ioutil.ReadFile(this.cacheFile)
|
||||
data, err := os.ReadFile(this.cacheFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
@@ -122,7 +121,7 @@ func (this *ProvinceManager) loop() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := rpcClient.RegionProvinceRPC().FindAllEnabledRegionProvincesWithCountryId(rpcClient.Context(), &pb.FindAllEnabledRegionProvincesWithCountryIdRequest{
|
||||
resp, err := rpcClient.RegionProvinceRPC().FindAllRegionProvincesWithRegionCountryId(rpcClient.Context(), &pb.FindAllRegionProvincesWithRegionCountryIdRequest{
|
||||
RegionCountryId: ChinaCountryId,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -156,6 +155,6 @@ func (this *ProvinceManager) loop() error {
|
||||
|
||||
// 保存到本地缓存
|
||||
|
||||
err = ioutil.WriteFile(this.cacheFile, data, 0666)
|
||||
err = os.WriteFile(this.cacheFile, data, 0666)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -23,18 +23,19 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const MaxQueueSize = 2048
|
||||
const MaxQueueSize = 256 // TODO 可以配置,可以在单个任务里配置
|
||||
|
||||
// Task 单个指标任务
|
||||
// 数据库存储:
|
||||
// data/
|
||||
// metric.$ID.db
|
||||
// stats
|
||||
// id, keys, value, time, serverId, hash
|
||||
// 原理:
|
||||
// 添加或者有变更时 isUploaded = false
|
||||
// 上传时检查 isUploaded 状态
|
||||
// 只上传每个服务中排序最前面的 N 个数据
|
||||
//
|
||||
// data/
|
||||
// metric.$ID.db
|
||||
// stats
|
||||
// id, keys, value, time, serverId, hash
|
||||
// 原理:
|
||||
// 添加或者有变更时 isUploaded = false
|
||||
// 上传时检查 isUploaded 状态
|
||||
// 只上传每个服务中排序最前面的 N 个数据
|
||||
type Task struct {
|
||||
item *serverconfigs.MetricItemConfig
|
||||
isLoaded bool
|
||||
@@ -372,7 +373,9 @@ func (this *Task) Upload(pauseDuration time.Duration) error {
|
||||
for _, serverId := range serverIds {
|
||||
for _, currentTime := range times {
|
||||
idStrings, err := func(serverId int64, currentTime string) (ids []string, err error) {
|
||||
var t = trackers.Begin("[METRIC]SELECT_TOP_STMT")
|
||||
rows, err := this.selectTopStmt.Query(serverId, this.item.Version, currentTime)
|
||||
t.End()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -129,9 +129,8 @@ func (this *ClientConn) Close() error {
|
||||
err := this.rawConn.Close()
|
||||
|
||||
// 单个服务并发数限制
|
||||
if this.hasLimit {
|
||||
sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String())
|
||||
}
|
||||
// 不能加条件限制,因为服务配置随时有变化
|
||||
sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String())
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ func (this *ClientListener) Accept() (net.Conn, error) {
|
||||
if beingDenied {
|
||||
var fw = firewalls.Firewall()
|
||||
if fw != nil && !fw.IsMock() {
|
||||
_ = fw.DropSourceIP(ip, 60)
|
||||
_ = fw.DropSourceIP(ip, 120, true)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
@@ -236,7 +235,7 @@ func (this *HTTPCacheTaskManager) fetchKey(key *pb.HTTPCacheTaskKey) error {
|
||||
}()
|
||||
|
||||
// 读取内容,以便于生成缓存
|
||||
_, _ = io.Copy(ioutil.Discard, resp.Body)
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
|
||||
// 处理502
|
||||
if resp.StatusCode == http.StatusBadGateway {
|
||||
|
||||
@@ -54,6 +54,7 @@ func (this *HTTPClientPool) Client(req *HTTPRequest,
|
||||
}
|
||||
|
||||
var key = origin.UniqueKey() + "@" + originAddr
|
||||
var isLnRequest = origin.Id == 0
|
||||
|
||||
this.locker.RLock()
|
||||
client, found := this.clientsMap[key]
|
||||
@@ -101,6 +102,17 @@ func (this *HTTPClientPool) Client(req *HTTPRequest,
|
||||
idleConns = numberCPU * 8
|
||||
}
|
||||
|
||||
// 可以判断为Ln节点请求
|
||||
if isLnRequest {
|
||||
maxConnections *= 8
|
||||
idleConns *= 8
|
||||
idleTimeout *= 4
|
||||
} else if sharedNodeConfig != nil && sharedNodeConfig.Level > 1 {
|
||||
// Ln节点可以适当增加连接数
|
||||
maxConnections *= 2
|
||||
idleConns *= 2
|
||||
}
|
||||
|
||||
// TLS通讯
|
||||
var tlsConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
@@ -149,6 +161,7 @@ func (this *HTTPClientPool) Client(req *HTTPRequest,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
TLSHandshakeTimeout: 5 * time.Second,
|
||||
TLSClientConfig: tlsConfig,
|
||||
ReadBufferSize: 8 * 1024,
|
||||
Proxy: nil,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ func TestHTTPClientPool_Client(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHTTPClientPool_cleanClients(t *testing.T) {
|
||||
origin := &serverconfigs.OriginConfig{
|
||||
var origin = &serverconfigs.OriginConfig{
|
||||
Id: 1,
|
||||
Version: 2,
|
||||
Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"},
|
||||
@@ -48,8 +48,7 @@ func TestHTTPClientPool_cleanClients(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
pool := NewHTTPClientPool()
|
||||
pool.clientExpiredDuration = 2 * time.Second
|
||||
var pool = NewHTTPClientPool()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
t.Log("get", i)
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -151,7 +150,7 @@ func (this *HTTPRequest) Do() {
|
||||
// Web配置
|
||||
err := this.configureWeb(this.ReqServer.Web, true, 0)
|
||||
if err != nil {
|
||||
this.write50x(err, http.StatusInternalServerError, false)
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to configure the server", "配置服务失败", false)
|
||||
this.doEnd()
|
||||
return
|
||||
}
|
||||
@@ -284,12 +283,12 @@ func (this *HTTPRequest) doBegin() {
|
||||
this.web.AccessLogRef.IsOn &&
|
||||
this.web.AccessLogRef.ContainsField(serverconfigs.HTTPAccessLogFieldRequestBody) {
|
||||
var err error
|
||||
this.requestBodyData, err = ioutil.ReadAll(io.LimitReader(this.RawReq.Body, AccessLogMaxRequestBodySize))
|
||||
this.requestBodyData, err = io.ReadAll(io.LimitReader(this.RawReq.Body, AccessLogMaxRequestBodySize))
|
||||
if err != nil {
|
||||
this.write50x(err, http.StatusBadGateway, false)
|
||||
this.write50x(err, http.StatusBadGateway, "Failed to read request body for access log", "为访问日志读取请求Body失败", false)
|
||||
return
|
||||
}
|
||||
this.RawReq.Body = ioutil.NopCloser(io.MultiReader(bytes.NewBuffer(this.requestBodyData), this.RawReq.Body))
|
||||
this.RawReq.Body = io.NopCloser(io.MultiReader(bytes.NewBuffer(this.requestBodyData), this.RawReq.Body))
|
||||
}
|
||||
|
||||
// 跳转
|
||||
@@ -1390,12 +1389,12 @@ func (this *HTTPRequest) Cookie(name string) string {
|
||||
return c.Value
|
||||
}
|
||||
|
||||
// DeleteHeader 删除Header
|
||||
// DeleteHeader 删除请求Header
|
||||
func (this *HTTPRequest) DeleteHeader(name string) {
|
||||
this.RawReq.Header.Del(name)
|
||||
}
|
||||
|
||||
// SetHeader 设置Header
|
||||
// SetHeader 设置请求Header
|
||||
func (this *HTTPRequest) SetHeader(name string, values []string) {
|
||||
this.RawReq.Header[name] = values
|
||||
}
|
||||
@@ -1712,7 +1711,12 @@ func (this *HTTPRequest) canIgnore(err error) bool {
|
||||
}
|
||||
|
||||
// 客户端主动取消
|
||||
if err == errWritingToClient || err == context.Canceled || err == io.ErrShortWrite || strings.Contains(err.Error(), "write: connection timed out") || strings.Contains(err.Error(), "write: broken pipe") {
|
||||
if err == errWritingToClient ||
|
||||
err == context.Canceled ||
|
||||
err == io.ErrShortWrite ||
|
||||
strings.Contains(err.Error(), "write: connection") ||
|
||||
strings.Contains(err.Error(), "write: broken pipe") ||
|
||||
strings.Contains(err.Error(), "write tcp") {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ package nodes
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -26,14 +26,14 @@ func (this *HTTPRequest) doAuth() (shouldStop bool) {
|
||||
subReq.Proto = this.RawReq.Proto
|
||||
subReq.ProtoMinor = this.RawReq.ProtoMinor
|
||||
subReq.ProtoMajor = this.RawReq.ProtoMajor
|
||||
subReq.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
|
||||
subReq.Body = io.NopCloser(bytes.NewReader([]byte{}))
|
||||
subReq.Header.Set("Referer", this.URL())
|
||||
var writer = NewEmptyResponseWriter(this.writer)
|
||||
this.doSubRequest(writer, subReq)
|
||||
return writer.StatusCode(), nil
|
||||
}, this.Format)
|
||||
if err != nil {
|
||||
this.write50x(err, http.StatusInternalServerError, false)
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to execute the AuthPolicy", "认证策略执行失败", false)
|
||||
return
|
||||
}
|
||||
if b {
|
||||
|
||||
@@ -1,32 +1,59 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const httpStatusPageTemplate = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>${status} ${statusMessage}</title>
|
||||
<meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<h1>${status} ${statusMessage}</h1>
|
||||
<p>${message}</p>
|
||||
|
||||
<address>Request ID: ${requestId}.</address>
|
||||
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
func (this *HTTPRequest) write404() {
|
||||
if this.doPage(http.StatusNotFound) {
|
||||
this.writeCode(http.StatusNotFound)
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) writeCode(statusCode int) {
|
||||
if this.doPage(statusCode) {
|
||||
return
|
||||
}
|
||||
|
||||
this.processResponseHeaders(http.StatusNotFound)
|
||||
this.writer.WriteHeader(http.StatusNotFound)
|
||||
_, _ = this.writer.Write([]byte("404 page not found: '" + this.URL() + "'" + " (Request Id: " + this.requestId + ")"))
|
||||
var pageContent = configutils.ParseVariables(httpStatusPageTemplate, func(varName string) (value string) {
|
||||
switch varName {
|
||||
case "status":
|
||||
return types.String(statusCode)
|
||||
case "statusMessage":
|
||||
return http.StatusText(statusCode)
|
||||
case "requestId":
|
||||
return this.requestId
|
||||
case "message":
|
||||
return "" // 空
|
||||
}
|
||||
return "${" + varName + "}"
|
||||
})
|
||||
|
||||
this.processResponseHeaders(statusCode)
|
||||
this.writer.WriteHeader(statusCode)
|
||||
|
||||
_, _ = this.writer.Write([]byte(pageContent))
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) writeCode(code int) {
|
||||
if this.doPage(code) {
|
||||
return
|
||||
}
|
||||
|
||||
this.processResponseHeaders(code)
|
||||
this.writer.WriteHeader(code)
|
||||
_, _ = this.writer.Write([]byte(types.String(code) + " " + http.StatusText(code) + ": '" + this.URL() + "'" + " (Request Id: " + this.requestId + ")"))
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) write50x(err error, statusCode int, canTryStale bool) {
|
||||
func (this *HTTPRequest) write50x(err error, statusCode int, enMessage string, zhMessage string, canTryStale bool) {
|
||||
if err != nil {
|
||||
this.addError(err)
|
||||
}
|
||||
@@ -37,7 +64,7 @@ func (this *HTTPRequest) write50x(err error, statusCode int, canTryStale bool) {
|
||||
this.web.Cache.Stale != nil &&
|
||||
this.web.Cache.Stale.IsOn &&
|
||||
(len(this.web.Cache.Stale.Status) == 0 || lists.ContainsInt(this.web.Cache.Stale.Status, statusCode)) {
|
||||
ok := this.doCacheRead(true)
|
||||
var ok = this.doCacheRead(true)
|
||||
if ok {
|
||||
return
|
||||
}
|
||||
@@ -47,7 +74,34 @@ func (this *HTTPRequest) write50x(err error, statusCode int, canTryStale bool) {
|
||||
if this.doPage(statusCode) {
|
||||
return
|
||||
}
|
||||
|
||||
// 内置HTML模板
|
||||
var pageContent = configutils.ParseVariables(httpStatusPageTemplate, func(varName string) (value string) {
|
||||
switch varName {
|
||||
case "status":
|
||||
return types.String(statusCode)
|
||||
case "statusMessage":
|
||||
return http.StatusText(statusCode)
|
||||
case "requestId":
|
||||
return this.requestId
|
||||
case "message":
|
||||
var acceptLanguages = this.RawReq.Header.Get("Accept-Language")
|
||||
if len(acceptLanguages) > 0 {
|
||||
var index = strings.Index(acceptLanguages, ",")
|
||||
if index > 0 {
|
||||
var firstLanguage = acceptLanguages[:index]
|
||||
if firstLanguage == "zh-CN" {
|
||||
return "网站出了一点小问题,原因:" + zhMessage + "。"
|
||||
}
|
||||
}
|
||||
}
|
||||
return "The site is unavailable now, cause: " + enMessage + "."
|
||||
}
|
||||
return "${" + varName + "}"
|
||||
})
|
||||
|
||||
this.processResponseHeaders(statusCode)
|
||||
this.writer.WriteHeader(statusCode)
|
||||
_, _ = this.writer.Write([]byte(types.String(statusCode) + " " + http.StatusText(statusCode) + " (Request Id: " + this.requestId + ")"))
|
||||
|
||||
_, _ = this.writer.Write([]byte(pageContent))
|
||||
}
|
||||
|
||||
@@ -81,7 +81,7 @@ func (this *HTTPRequest) doFastcgi() (shouldStop bool) {
|
||||
|
||||
client, err := fcgi.SharedPool(fastcgi.Network(), fastcgi.RealAddress(), uint(poolSize)).Client()
|
||||
if err != nil {
|
||||
this.write50x(err, http.StatusInternalServerError, false)
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to create Fastcgi pool", "Fastcgi池生成失败", false)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -159,13 +159,13 @@ func (this *HTTPRequest) doFastcgi() (shouldStop bool) {
|
||||
|
||||
resp, stderr, err := client.Call(fcgiReq)
|
||||
if err != nil {
|
||||
this.write50x(err, http.StatusInternalServerError, false)
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to read Fastcgi", "读取Fastcgi失败", false)
|
||||
return
|
||||
}
|
||||
|
||||
if len(stderr) > 0 {
|
||||
err := errors.New("Fastcgi Error: " + strings.TrimSpace(string(stderr)) + " script: " + maps.NewMap(params).GetString("SCRIPT_FILENAME"))
|
||||
this.write50x(err, http.StatusInternalServerError, false)
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to read Fastcgi", "读取Fastcgi失败", false)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,6 @@ func (this *HTTPRequest) checkLnRequest() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) getLnOrigin() *serverconfigs.OriginConfig {
|
||||
return nil
|
||||
func (this *HTTPRequest) getLnOrigin(excludingNodeIds []int64) (originConfig *serverconfigs.OriginConfig, lnNodeId int64, hasMultipleNodes bool) {
|
||||
return nil, 0, false
|
||||
}
|
||||
|
||||
@@ -21,10 +21,32 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
return
|
||||
}
|
||||
|
||||
var retries = 3
|
||||
|
||||
var failedOriginIds []int64
|
||||
var failedLnNodeIds []int64
|
||||
|
||||
for i := 0; i < retries; i++ {
|
||||
originId, lnNodeId, shouldRetry := this.doOriginRequest(failedOriginIds, failedLnNodeIds, i == 0, i == retries-1)
|
||||
if !shouldRetry {
|
||||
break
|
||||
}
|
||||
if originId > 0 {
|
||||
failedOriginIds = append(failedOriginIds, originId)
|
||||
}
|
||||
if lnNodeId > 0 {
|
||||
failedLnNodeIds = append(failedLnNodeIds, lnNodeId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 请求源站
|
||||
func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeIds []int64, isFirstTry bool, isLastRetry bool) (originId int64, lnNodeId int64, shouldRetry bool) {
|
||||
// 对URL的处理
|
||||
var stripPrefix = this.reverseProxy.StripPrefix
|
||||
var requestURI = this.reverseProxy.RequestURI
|
||||
var requestURIHasVariables = this.reverseProxy.RequestURIHasVariables()
|
||||
var oldURI = this.uri
|
||||
|
||||
var requestHost = ""
|
||||
if this.reverseProxy.RequestHostType == serverconfigs.RequestHostTypeCustomized {
|
||||
@@ -41,24 +63,39 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
var origin *serverconfigs.OriginConfig
|
||||
|
||||
// 二级节点
|
||||
var hasMultipleLnNodes = false
|
||||
if this.cacheRef != nil {
|
||||
origin = this.getLnOrigin()
|
||||
origin, lnNodeId, hasMultipleLnNodes = this.getLnOrigin(failedLnNodeIds)
|
||||
if origin != nil {
|
||||
// 强制变更原来访问的域名
|
||||
requestHost = this.ReqHost
|
||||
}
|
||||
|
||||
// 回源Header中去除If-None-Match和If-Modified-Since
|
||||
if !this.cacheRef.EnableIfNoneMatch {
|
||||
this.DeleteHeader("If-None-Match")
|
||||
}
|
||||
if !this.cacheRef.EnableIfModifiedSince {
|
||||
this.DeleteHeader("If-Modified-Since")
|
||||
}
|
||||
}
|
||||
|
||||
// 自定义源站
|
||||
if origin == nil {
|
||||
origin = this.reverseProxy.NextOrigin(requestCall)
|
||||
if !isFirstTry {
|
||||
origin = this.reverseProxy.AnyOrigin(requestCall, failedOriginIds)
|
||||
}
|
||||
if origin == nil {
|
||||
origin = this.reverseProxy.NextOrigin(requestCall)
|
||||
}
|
||||
requestCall.CallResponseCallbacks(this.writer)
|
||||
if origin == nil {
|
||||
err := errors.New(this.URL() + ": no available origin sites for reverse proxy")
|
||||
remotelogs.ServerError(this.ReqServer.Id, "HTTP_REQUEST_REVERSE_PROXY", err.Error(), "", nil)
|
||||
this.write50x(err, http.StatusBadGateway, true)
|
||||
this.write50x(err, http.StatusBadGateway, "No origin site yet", "尚未配置源站", true)
|
||||
return
|
||||
}
|
||||
originId = origin.Id
|
||||
|
||||
if len(origin.StripPrefix) > 0 {
|
||||
stripPrefix = origin.StripPrefix
|
||||
@@ -80,7 +117,7 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
if origin.Addr == nil {
|
||||
err := errors.New(this.URL() + ": Origin '" + strconv.FormatInt(origin.Id, 10) + "' does not has a address")
|
||||
remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", err.Error())
|
||||
this.write50x(err, http.StatusBadGateway, true)
|
||||
this.write50x(err, http.StatusBadGateway, "Origin site did not has a valid address", "源站尚未配置地址", true)
|
||||
return
|
||||
}
|
||||
this.RawReq.URL.Scheme = origin.Addr.Protocol.Primary().Scheme()
|
||||
@@ -132,7 +169,7 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
if originHostIndex < 0 {
|
||||
var originErr = errors.New(this.URL() + ": Invalid origin address '" + originAddr + "', lacking port")
|
||||
remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", originErr.Error())
|
||||
this.write50x(originErr, http.StatusBadGateway, true)
|
||||
this.write50x(originErr, http.StatusBadGateway, "No port in origin site address", "源站地址中没有配置端口", true)
|
||||
return
|
||||
}
|
||||
originAddr = originAddr[:originHostIndex+1] + types.String(this.requestServerPort())
|
||||
@@ -211,12 +248,12 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol, this.reverseProxy.FollowRedirects)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Create client failed: "+err.Error())
|
||||
this.write50x(err, http.StatusBadGateway, true)
|
||||
this.write50x(err, http.StatusBadGateway, "Failed to create origin site client", "构造源站客户端失败", true)
|
||||
return
|
||||
}
|
||||
|
||||
// 在HTTP/2下需要防止因为requestBody而导致Content-Length为空的问题
|
||||
if this.RawReq.ProtoMajor == 2 && this.RawReq.ContentLength == 0 {
|
||||
if this.RawReq.ProtoMajor == 2 && this.RawReq.ContentLength == 0 && this.RawReq.Body != nil {
|
||||
_ = this.RawReq.Body.Close()
|
||||
this.RawReq.Body = nil
|
||||
}
|
||||
@@ -230,18 +267,35 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() {
|
||||
this.reverseProxy.ResetScheduling()
|
||||
})
|
||||
this.write50x(err, http.StatusBadGateway, true)
|
||||
this.write50x(err, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
|
||||
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.RawReq.URL.String()+": Request origin server failed: "+err.Error())
|
||||
} else if httpErr.Err != context.Canceled {
|
||||
SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() {
|
||||
this.reverseProxy.ResetScheduling()
|
||||
})
|
||||
|
||||
// 是否需要重试
|
||||
if (originId > 0 || (lnNodeId > 0 && hasMultipleLnNodes)) && !isLastRetry {
|
||||
shouldRetry = true
|
||||
this.uri = oldURI // 恢复备份
|
||||
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
if httpErr.Err != io.EOF {
|
||||
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Request origin server failed: "+err.Error())
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if httpErr.Timeout() {
|
||||
this.write50x(err, http.StatusGatewayTimeout, true)
|
||||
this.write50x(err, http.StatusGatewayTimeout, "Read origin site timeout", "源站读取超时", true)
|
||||
} else if httpErr.Temporary() {
|
||||
this.write50x(err, http.StatusServiceUnavailable, true)
|
||||
this.write50x(err, http.StatusServiceUnavailable, "Origin site unavailable now", "源站当前不可用", true)
|
||||
} else {
|
||||
this.write50x(err, http.StatusBadGateway, true)
|
||||
this.write50x(err, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
|
||||
}
|
||||
if httpErr.Err != io.EOF {
|
||||
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Request origin server failed: "+err.Error())
|
||||
@@ -264,7 +318,7 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
}
|
||||
|
||||
if !isClientError {
|
||||
this.write50x(err, http.StatusBadGateway, true)
|
||||
this.write50x(err, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
|
||||
}
|
||||
}
|
||||
if resp != nil && resp.Body != nil {
|
||||
@@ -410,4 +464,6 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
if (err == nil || err == io.EOF) && (closeErr == nil || closeErr == io.EOF) {
|
||||
this.writer.SetOk()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -114,7 +114,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
}
|
||||
return
|
||||
} else {
|
||||
this.write50x(err, http.StatusInternalServerError, true)
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to stat the file", "查看文件统计信息失败", true)
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
@@ -145,7 +145,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
}
|
||||
return
|
||||
} else {
|
||||
this.write50x(err, http.StatusInternalServerError, true)
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to stat the file", "查看文件统计信息失败", true)
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
@@ -285,7 +285,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
|
||||
fileReader, err := os.OpenFile(filePath, os.O_RDONLY, 0444)
|
||||
if err != nil {
|
||||
this.write50x(err, http.StatusInternalServerError, true)
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to open the file", "试图打开文件失败", true)
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ func (this *HTTPRequest) doURL(method string, url string, host string, statusCod
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST_URL", req.URL.String()+": "+err.Error())
|
||||
this.write50x(err, http.StatusInternalServerError, false)
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to read url", "读取URL失败", false)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
|
||||
@@ -180,14 +180,14 @@ var httpRequestTimestamp int64
|
||||
var httpRequestId int32 = 1_000_000
|
||||
|
||||
func httpRequestNextId() string {
|
||||
var unixTime = utils.UnixTimeMilli()
|
||||
unixTime, unixTimeString := utils.UnixTimeMilliString()
|
||||
if unixTime > httpRequestTimestamp {
|
||||
atomic.StoreInt32(&httpRequestId, 1_000_000)
|
||||
httpRequestTimestamp = unixTime
|
||||
}
|
||||
|
||||
// timestamp + requestId + nodeId
|
||||
return strconv.FormatInt(unixTime, 10) + teaconst.NodeIdString + strconv.Itoa(int(atomic.AddInt32(&httpRequestId, 1)))
|
||||
// timestamp + nodeId + requestId
|
||||
return unixTimeString + teaconst.NodeIdString + strconv.Itoa(int(atomic.AddInt32(&httpRequestId, 1)))
|
||||
}
|
||||
|
||||
// 检查是否可以接受某个编码
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -177,7 +176,8 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
|
||||
if countryId > 0 && lists.ContainsInt64(regionConfig.DenyCountryIds, countryId) {
|
||||
this.firewallPolicyId = firewallPolicy.Id
|
||||
|
||||
this.writer.WriteHeader(http.StatusForbidden)
|
||||
this.writeCode(http.StatusForbidden)
|
||||
this.writer.Flush()
|
||||
this.writer.Close()
|
||||
|
||||
// 停止日志
|
||||
@@ -197,7 +197,8 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
|
||||
if provinceId > 0 && lists.ContainsInt64(regionConfig.DenyProvinceIds, provinceId) {
|
||||
this.firewallPolicyId = firewallPolicy.Id
|
||||
|
||||
this.writer.WriteHeader(http.StatusForbidden)
|
||||
this.writeCode(http.StatusForbidden)
|
||||
this.writer.Flush()
|
||||
this.writer.Close()
|
||||
|
||||
// 停止日志
|
||||
@@ -357,7 +358,7 @@ func (this *HTTPRequest) WAFSetCacheBody(body []byte) {
|
||||
// WAFReadBody 读取Body
|
||||
func (this *HTTPRequest) WAFReadBody(max int64) (data []byte, err error) {
|
||||
if this.RawReq.ContentLength > 0 {
|
||||
data, err = ioutil.ReadAll(io.LimitReader(this.RawReq.Body, max))
|
||||
data, err = io.ReadAll(io.LimitReader(this.RawReq.Body, max))
|
||||
}
|
||||
|
||||
return
|
||||
@@ -366,7 +367,7 @@ func (this *HTTPRequest) WAFReadBody(max int64) (data []byte, err error) {
|
||||
// WAFRestoreBody 恢复Body
|
||||
func (this *HTTPRequest) WAFRestoreBody(data []byte) {
|
||||
if len(data) > 0 {
|
||||
this.RawReq.Body = ioutil.NopCloser(io.MultiReader(bytes.NewBuffer(data), this.RawReq.Body))
|
||||
this.RawReq.Body = io.NopCloser(io.MultiReader(bytes.NewBuffer(data), this.RawReq.Body))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ func (this *HTTPRequest) doWebsocket(requestHost string) {
|
||||
// TODO 增加N次错误重试,重试的时候需要尝试不同的源站
|
||||
originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost)
|
||||
if err != nil {
|
||||
this.write50x(err, http.StatusBadGateway, false)
|
||||
this.write50x(err, http.StatusBadGateway, "Failed to connect origin site", "源站连接失败", false)
|
||||
|
||||
// 增加失败次数
|
||||
SharedOriginStateManager.Fail(this.origin, requestHost, this.reverseProxy, func() {
|
||||
@@ -65,13 +65,13 @@ func (this *HTTPRequest) doWebsocket(requestHost string) {
|
||||
|
||||
err = this.RawReq.Write(originConn)
|
||||
if err != nil {
|
||||
this.write50x(err, http.StatusBadGateway, false)
|
||||
this.write50x(err, http.StatusBadGateway, "Failed to write request to origin site", "源站请求初始化失败", false)
|
||||
return
|
||||
}
|
||||
|
||||
clientConn, _, err := this.writer.Hijack()
|
||||
if err != nil || clientConn == nil {
|
||||
this.write50x(err, http.StatusInternalServerError, false)
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to get origin site connection", "获取源站连接失败", false)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
@@ -524,7 +523,7 @@ func (this *HTTPWriter) PrepareWebP(resp *http.Response, size int64) {
|
||||
|
||||
this.webpOriginContentType = contentType
|
||||
this.webpIsEncoding = true
|
||||
resp.Body = ioutil.NopCloser(&bytes.Buffer{})
|
||||
resp.Body = io.NopCloser(&bytes.Buffer{})
|
||||
this.delayRead = true
|
||||
|
||||
this.Header().Del("Content-Length")
|
||||
|
||||
@@ -5,11 +5,14 @@ import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
@@ -31,15 +34,19 @@ type ListenerManager struct {
|
||||
retryListenerMap map[string]*Listener // 需要重试的监听器 addr => Listener
|
||||
ticker *time.Ticker
|
||||
|
||||
lastPortStrings string
|
||||
firewalld *firewalls.Firewalld
|
||||
lastPortStrings string
|
||||
lastTCPPortRanges [][2]int
|
||||
lastUDPPortRanges [][2]int
|
||||
}
|
||||
|
||||
// NewListenerManager 获取新对象
|
||||
func NewListenerManager() *ListenerManager {
|
||||
manager := &ListenerManager{
|
||||
var manager = &ListenerManager{
|
||||
listenersMap: map[string]*Listener{},
|
||||
retryListenerMap: map[string]*Listener{},
|
||||
ticker: time.NewTicker(1 * time.Minute),
|
||||
firewalld: firewalls.NewFirewalld(),
|
||||
}
|
||||
|
||||
// 提升测试效率
|
||||
@@ -147,7 +154,7 @@ func (this *ListenerManager) Start(node *nodeconfigs.NodeConfig) error {
|
||||
}
|
||||
|
||||
// 加入到firewalld
|
||||
this.addToFirewalld(groupAddrs)
|
||||
go this.addToFirewalld(groupAddrs)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -226,8 +233,14 @@ func (this *ListenerManager) addToFirewalld(groupAddrs []string) {
|
||||
return
|
||||
}
|
||||
|
||||
if this.firewalld == nil || !this.firewalld.IsReady() {
|
||||
return
|
||||
}
|
||||
|
||||
// 组合端口号
|
||||
var ports = []string{}
|
||||
var portStrings = []string{}
|
||||
var udpPorts = []int{}
|
||||
var tcpPorts = []int{}
|
||||
for _, addr := range groupAddrs {
|
||||
var protocol = "tcp"
|
||||
if strings.HasPrefix(addr, "udp") {
|
||||
@@ -237,52 +250,72 @@ func (this *ListenerManager) addToFirewalld(groupAddrs []string) {
|
||||
var lastIndex = strings.LastIndex(addr, ":")
|
||||
if lastIndex > 0 {
|
||||
var portString = addr[lastIndex+1:]
|
||||
ports = append(ports, portString+"/"+protocol)
|
||||
portStrings = append(portStrings, portString+"/"+protocol)
|
||||
|
||||
switch protocol {
|
||||
case "tcp":
|
||||
tcpPorts = append(tcpPorts, types.Int(portString))
|
||||
case "udp":
|
||||
udpPorts = append(udpPorts, types.Int(portString))
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(ports) == 0 {
|
||||
if len(portStrings) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否有变化
|
||||
sort.Strings(ports)
|
||||
var newPortStrings = strings.Join(ports, ",")
|
||||
sort.Strings(portStrings)
|
||||
var newPortStrings = strings.Join(portStrings, ",")
|
||||
if newPortStrings == this.lastPortStrings {
|
||||
return
|
||||
}
|
||||
this.lastPortStrings = newPortStrings
|
||||
|
||||
firewallCmd, err := exec.LookPath("firewall-cmd")
|
||||
if err != nil || len(firewallCmd) == 0 {
|
||||
return
|
||||
remotelogs.Println("FIREWALLD", "opening ports automatically ...")
|
||||
defer func() {
|
||||
remotelogs.Println("FIREWALLD", "open ports successfully")
|
||||
}()
|
||||
|
||||
// 合并端口
|
||||
var tcpPortRanges = utils.MergePorts(tcpPorts)
|
||||
var udpPortRanges = utils.MergePorts(udpPorts)
|
||||
|
||||
defer func() {
|
||||
this.lastTCPPortRanges = tcpPortRanges
|
||||
this.lastUDPPortRanges = udpPortRanges
|
||||
}()
|
||||
|
||||
// 删除老的不存在的端口
|
||||
var tcpPortRangesMap = map[string]bool{}
|
||||
var udpPortRangesMap = map[string]bool{}
|
||||
for _, portRange := range tcpPortRanges {
|
||||
tcpPortRangesMap[this.firewalld.PortRangeString(portRange, "tcp")] = true
|
||||
}
|
||||
for _, portRange := range udpPortRanges {
|
||||
udpPortRangesMap[this.firewalld.PortRangeString(portRange, "udp")] = true
|
||||
}
|
||||
|
||||
// 检查状态
|
||||
err = exec.Command(firewallCmd, "--state").Run()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
remotelogs.Println("FIREWALLD", "open ports automatically")
|
||||
for _, port := range ports {
|
||||
{
|
||||
// TODO 需要支持sudo
|
||||
var cmd = exec.Command(firewallCmd, "--add-port="+port, "--permanent")
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
remotelogs.Warn("FIREWALLD", "'"+cmd.String()+"': "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// TODO 需要支持sudo
|
||||
var cmd = exec.Command(firewallCmd, "--add-port="+port)
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
remotelogs.Warn("FIREWALLD", "'"+cmd.String()+"': "+err.Error())
|
||||
return
|
||||
}
|
||||
for _, portRange := range this.lastTCPPortRanges {
|
||||
var s = this.firewalld.PortRangeString(portRange, "tcp")
|
||||
_, ok := tcpPortRangesMap[s]
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
remotelogs.Println("FIREWALLD", "remove port '"+s+"'")
|
||||
_ = this.firewalld.RemovePortRangePermanently(portRange, "tcp")
|
||||
}
|
||||
for _, portRange := range this.lastUDPPortRanges {
|
||||
var s = this.firewalld.PortRangeString(portRange, "udp")
|
||||
_, ok := udpPortRangesMap[s]
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
remotelogs.Println("FIREWALLD", "remove port '"+s+"'")
|
||||
_ = this.firewalld.RemovePortRangePermanently(portRange, "udp")
|
||||
}
|
||||
|
||||
// 添加新的
|
||||
_ = this.firewalld.AllowPortRangesPermanently(tcpPortRanges, "tcp")
|
||||
_ = this.firewalld.AllowPortRangesPermanently(udpPortRanges, "udp")
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
@@ -108,8 +109,9 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
// 记录域名排行
|
||||
tlsConn, ok := conn.(*tls.Conn)
|
||||
var recordStat = false
|
||||
var serverName = ""
|
||||
if ok {
|
||||
var serverName = tlsConn.ConnectionState().ServerName
|
||||
serverName = tlsConn.ConnectionState().ServerName
|
||||
if len(serverName) > 0 {
|
||||
// 统计
|
||||
stats.SharedTrafficStatManager.Add(server.Id, serverName, 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
@@ -122,8 +124,9 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
stats.SharedTrafficStatManager.Add(server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
}
|
||||
|
||||
originConn, err := this.connectOrigin(server.Id, server.ReverseProxy, conn.RemoteAddr().String())
|
||||
originConn, err := this.connectOrigin(server.Id, serverName, server.ReverseProxy, conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -218,37 +221,63 @@ func (this *TCPListener) Close() error {
|
||||
}
|
||||
|
||||
// 连接源站
|
||||
func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
|
||||
func (this *TCPListener) connectOrigin(serverId int64, requestHost string, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
|
||||
if reverseProxy == nil {
|
||||
return nil, errors.New("no reverse proxy config")
|
||||
}
|
||||
|
||||
var requestCall = shared.NewRequestCall()
|
||||
requestCall.Domain = requestHost
|
||||
|
||||
var retries = 3
|
||||
var addr string
|
||||
|
||||
var failedOriginIds []int64
|
||||
|
||||
for i := 0; i < retries; i++ {
|
||||
var origin = reverseProxy.NextOrigin(nil)
|
||||
var origin *serverconfigs.OriginConfig
|
||||
if len(failedOriginIds) > 0 {
|
||||
origin = reverseProxy.AnyOrigin(requestCall, failedOriginIds)
|
||||
}
|
||||
if origin == nil {
|
||||
origin = reverseProxy.NextOrigin(requestCall)
|
||||
}
|
||||
if origin == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 回源主机名
|
||||
var requestHost = ""
|
||||
if len(reverseProxy.RequestHost) > 0 {
|
||||
requestHost = reverseProxy.RequestHost
|
||||
}
|
||||
if len(origin.RequestHost) > 0 {
|
||||
requestHost = origin.RequestHost
|
||||
} else if len(reverseProxy.RequestHost) > 0 {
|
||||
requestHost = reverseProxy.RequestHost
|
||||
}
|
||||
|
||||
conn, addr, err = OriginConnect(origin, this.port, remoteAddr, requestHost)
|
||||
if err != nil {
|
||||
failedOriginIds = append(failedOriginIds, origin.Id)
|
||||
|
||||
remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin server: "+addr+": "+err.Error(), "", nil)
|
||||
|
||||
SharedOriginStateManager.Fail(origin, requestHost, reverseProxy, func() {
|
||||
reverseProxy.ResetScheduling()
|
||||
})
|
||||
|
||||
continue
|
||||
} else {
|
||||
if !origin.IsOk {
|
||||
SharedOriginStateManager.Success(origin, func() {
|
||||
reverseProxy.ResetScheduling()
|
||||
})
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
err = errors.New("server '" + types.String(serverId) + "': no available origin server can be used")
|
||||
|
||||
if err == nil {
|
||||
err = errors.New("server '" + types.String(serverId) + "': no available origin server can be used")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,10 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
UDPConnLifeSeconds = 30
|
||||
)
|
||||
|
||||
type UDPListener struct {
|
||||
BaseListener
|
||||
|
||||
@@ -135,16 +139,39 @@ func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfi
|
||||
|
||||
var retries = 3
|
||||
var addr string
|
||||
|
||||
var failedOriginIds []int64
|
||||
|
||||
for i := 0; i < retries; i++ {
|
||||
var origin = reverseProxy.NextOrigin(nil)
|
||||
var origin *serverconfigs.OriginConfig
|
||||
if len(failedOriginIds) > 0 {
|
||||
origin = reverseProxy.AnyOrigin(nil, failedOriginIds)
|
||||
}
|
||||
if origin == nil {
|
||||
origin = reverseProxy.NextOrigin(nil)
|
||||
}
|
||||
if origin == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
conn, addr, err = OriginConnect(origin, this.port, remoteAddr.String(), "")
|
||||
if err != nil {
|
||||
failedOriginIds = append(failedOriginIds, origin.Id)
|
||||
|
||||
remotelogs.ServerError(serverId, "UDP_LISTENER", "unable to connect origin server: "+addr+": "+err.Error(), "", nil)
|
||||
|
||||
SharedOriginStateManager.Fail(origin, "", reverseProxy, func() {
|
||||
reverseProxy.ResetScheduling()
|
||||
})
|
||||
|
||||
continue
|
||||
} else {
|
||||
if !origin.IsOk {
|
||||
SharedOriginStateManager.Success(origin, func() {
|
||||
reverseProxy.ResetScheduling()
|
||||
})
|
||||
}
|
||||
|
||||
// PROXY Protocol
|
||||
if reverseProxy != nil &&
|
||||
reverseProxy.ProxyProtocol != nil &&
|
||||
@@ -171,14 +198,17 @@ func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfi
|
||||
return
|
||||
}
|
||||
}
|
||||
err = errors.New("server '" + types.String(serverId) + "': no available origin server can be used")
|
||||
|
||||
if err == nil {
|
||||
err = errors.New("server '" + types.String(serverId) + "': no available origin server can be used")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 回收连接
|
||||
func (this *UDPListener) gcConns() {
|
||||
this.connLocker.Lock()
|
||||
closingConns := []*UDPConn{}
|
||||
var closingConns = []*UDPConn{}
|
||||
for addr, conn := range this.connMap {
|
||||
if !conn.IsOk() {
|
||||
closingConns = append(closingConns, conn)
|
||||
@@ -203,7 +233,7 @@ type UDPConn struct {
|
||||
}
|
||||
|
||||
func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyConn *net.UDPConn, serverConn *net.UDPConn) *UDPConn {
|
||||
conn := &UDPConn{
|
||||
var conn = &UDPConn{
|
||||
addr: addr,
|
||||
proxyConn: proxyConn,
|
||||
serverConn: serverConn,
|
||||
@@ -217,7 +247,7 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyConn *ne
|
||||
}
|
||||
|
||||
goman.New(func() {
|
||||
buffer := utils.BytePool4k.Get()
|
||||
var buffer = utils.BytePool4k.Get()
|
||||
defer func() {
|
||||
utils.BytePool4k.Put(buffer)
|
||||
}()
|
||||
@@ -232,9 +262,13 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyConn *ne
|
||||
break
|
||||
}
|
||||
|
||||
// 记录流量
|
||||
// 记录流量和带宽
|
||||
if server != nil {
|
||||
// 流量
|
||||
stats.SharedTrafficStatManager.Add(server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
|
||||
// 带宽
|
||||
stats.SharedBandwidthStatManager.Add(server.UserId, server.Id, int64(n))
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
@@ -268,5 +302,5 @@ func (this *UDPConn) IsOk() bool {
|
||||
if !this.isOk {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix()-this.activatedAt < 30 // 如果超过 N 秒没有活动我们认为是超时
|
||||
return time.Now().Unix()-this.activatedAt < UDPConnLifeSeconds // 如果超过 N 秒没有活动我们认为是超时
|
||||
}
|
||||
|
||||
@@ -30,7 +30,6 @@ import (
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/iwind/gosock/pkg/gosock"
|
||||
"gopkg.in/yaml.v3"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -207,8 +206,9 @@ func (this *Node) Start() {
|
||||
|
||||
// Daemon 实现守护进程
|
||||
func (this *Node) Daemon() {
|
||||
isDebug := lists.ContainsString(os.Args, "debug")
|
||||
isDebug = true
|
||||
teaconst.IsDaemon = true
|
||||
|
||||
var isDebug = lists.ContainsString(os.Args, "debug")
|
||||
for {
|
||||
conn, err := this.sock.Dial()
|
||||
if err != nil {
|
||||
@@ -227,13 +227,18 @@ func (this *Node) Daemon() {
|
||||
_ = os.Setenv("EdgeDaemon", "on")
|
||||
_ = os.Setenv("EdgeBackground", "on")
|
||||
|
||||
cmd := exec.Command(exe)
|
||||
var cmd = exec.Command(exe)
|
||||
var buf = &bytes.Buffer{}
|
||||
cmd.Stderr = buf
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = cmd.Wait()
|
||||
if err != nil {
|
||||
if isDebug {
|
||||
log.Println("[DAEMON]" + buf.String())
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -291,12 +296,20 @@ func (this *Node) loop() error {
|
||||
var nodeCtx = rpcClient.Context()
|
||||
tasksResp, err := rpcClient.NodeTaskRPC().FindNodeTasks(nodeCtx, &pb.FindNodeTasksRequest{})
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) && !Tea.IsTesting() {
|
||||
return nil
|
||||
}
|
||||
return errors.New("read node tasks failed: " + err.Error())
|
||||
}
|
||||
for _, task := range tasksResp.NodeTasks {
|
||||
switch task.Type {
|
||||
case "ipItemChanged":
|
||||
iplibrary.IPListUpdateNotify <- true
|
||||
// 防止阻塞
|
||||
select {
|
||||
case iplibrary.IPListUpdateNotify <- true:
|
||||
default:
|
||||
|
||||
}
|
||||
|
||||
// 修改为已同步
|
||||
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
|
||||
@@ -333,11 +346,12 @@ func (this *Node) loop() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case "nodeVersionChanged":
|
||||
goman.New(func() {
|
||||
sharedUpgradeManager.Start()
|
||||
})
|
||||
if !sharedUpgradeManager.IsInstalling() {
|
||||
goman.New(func() {
|
||||
sharedUpgradeManager.Start()
|
||||
})
|
||||
}
|
||||
case "scriptsChanged":
|
||||
err = this.reloadCommonScripts()
|
||||
if err != nil {
|
||||
@@ -605,7 +619,7 @@ func (this *Node) startSyncTimer() {
|
||||
// 检查集群设置
|
||||
func (this *Node) checkClusterConfig() error {
|
||||
configFile := Tea.ConfigFile("cluster.yaml")
|
||||
data, err := ioutil.ReadFile(configFile)
|
||||
data, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -635,11 +649,13 @@ func (this *Node) checkClusterConfig() error {
|
||||
if len(resp.Endpoints) == 0 {
|
||||
resp.Endpoints = []string{}
|
||||
}
|
||||
apiConfig := &configs.APIConfig{
|
||||
var apiConfig = &configs.APIConfig{
|
||||
RPC: struct {
|
||||
Endpoints []string `yaml:"endpoints"`
|
||||
Endpoints []string `yaml:"endpoints"`
|
||||
DisableUpdate bool `yaml:"disableUpdate"`
|
||||
}{
|
||||
Endpoints: resp.Endpoints,
|
||||
Endpoints: resp.Endpoints,
|
||||
DisableUpdate: false,
|
||||
},
|
||||
NodeId: resp.UniqueId,
|
||||
Secret: resp.Secret,
|
||||
@@ -673,7 +689,7 @@ func (this *Node) listenSock() error {
|
||||
if this.sock.IsListening() {
|
||||
reply, err := this.sock.Send(&gosock.Command{Code: "pid"})
|
||||
if err == nil {
|
||||
return errors.New("error: the process is already running, pid: " + maps.NewMap(reply.Params).GetString("pid"))
|
||||
return errors.New("error: the process is already running, pid: " + types.String(maps.NewMap(reply.Params).GetInt("pid")))
|
||||
} else {
|
||||
return errors.New("error: the process is already running")
|
||||
}
|
||||
@@ -711,6 +727,7 @@ func (this *Node) listenSock() error {
|
||||
_ = this.sock.Close()
|
||||
|
||||
events.Notify(events.EventQuit)
|
||||
events.Notify(events.EventTerminated)
|
||||
|
||||
// 监控连接数,如果连接数为0,则退出进程
|
||||
goman.New(func() {
|
||||
@@ -774,7 +791,8 @@ func (this *Node) listenSock() error {
|
||||
var m = maps.NewMap(cmd.Params)
|
||||
var ip = m.GetString("ip")
|
||||
var timeSeconds = m.GetInt("timeoutSeconds")
|
||||
err := firewalls.Firewall().DropSourceIP(ip, timeSeconds)
|
||||
var async = m.GetBool("async")
|
||||
err := firewalls.Firewall().DropSourceIP(ip, timeSeconds, async)
|
||||
if err != nil {
|
||||
_ = cmd.Reply(&gosock.Command{
|
||||
Params: map[string]interface{}{
|
||||
@@ -997,7 +1015,7 @@ func (this *Node) checkDisk() {
|
||||
"/sys/block/vda/queue/rotational",
|
||||
"/sys/block/sda/queue/rotational",
|
||||
} {
|
||||
data, err := ioutil.ReadFile(path)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
@@ -25,7 +24,7 @@ func (this *Node) handlePanic() {
|
||||
var panicFile = Tea.Root + "/logs/panic.log"
|
||||
|
||||
// 分析panic
|
||||
data, err := ioutil.ReadFile(panicFile)
|
||||
data, err := os.ReadFile(panicFile)
|
||||
if err == nil {
|
||||
var index = bytes.Index(data, []byte("panic:"))
|
||||
if index >= 0 {
|
||||
|
||||
@@ -30,10 +30,14 @@ type NodeStatusExecutor struct {
|
||||
cpuUpdatedTime time.Time
|
||||
cpuLogicalCount int
|
||||
cpuPhysicalCount int
|
||||
|
||||
ticker *time.Ticker
|
||||
}
|
||||
|
||||
func NewNodeStatusExecutor() *NodeStatusExecutor {
|
||||
return &NodeStatusExecutor{}
|
||||
return &NodeStatusExecutor{
|
||||
ticker: time.NewTicker(30 * time.Second),
|
||||
}
|
||||
}
|
||||
|
||||
func (this *NodeStatusExecutor) Listen() {
|
||||
@@ -41,15 +45,12 @@ func (this *NodeStatusExecutor) Listen() {
|
||||
this.cpuUpdatedTime = time.Now()
|
||||
this.update()
|
||||
|
||||
// TODO 这个时间间隔可以配置
|
||||
var ticker = time.NewTicker(30 * time.Second)
|
||||
|
||||
events.OnKey(events.EventQuit, this, func() {
|
||||
remotelogs.Println("NODE_STATUS", "quit executor")
|
||||
ticker.Stop()
|
||||
this.ticker.Stop()
|
||||
})
|
||||
|
||||
for range ticker.C {
|
||||
for range this.ticker.C {
|
||||
this.isFirstTime = false
|
||||
this.update()
|
||||
}
|
||||
@@ -68,6 +69,8 @@ func (this *NodeStatusExecutor) update() {
|
||||
status.BuildVersionCode = utils.VersionToLong(teaconst.Version)
|
||||
status.OS = runtime.GOOS
|
||||
status.Arch = runtime.GOARCH
|
||||
exe, _ := os.Executable()
|
||||
status.ExePath = exe
|
||||
status.ConfigVersion = sharedNodeConfig.Version
|
||||
status.IsActive = true
|
||||
status.ConnectionCount = sharedListenerManager.TotalActiveConnections()
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
@@ -104,7 +103,7 @@ func (this *SystemServiceManager) setupSystemd(params maps.Map) error {
|
||||
|
||||
if output == "enabled" {
|
||||
// 检查文件路径是否变化
|
||||
data, err := ioutil.ReadFile("/etc/systemd/system/" + teaconst.SystemdServiceName + ".service")
|
||||
data, err := os.ReadFile("/etc/systemd/system/" + teaconst.SystemdServiceName + ".service")
|
||||
if err == nil && bytes.Index(data, []byte(exe)) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -63,6 +63,16 @@ func (this *SyncAPINodesTask) Stop() {
|
||||
}
|
||||
|
||||
func (this *SyncAPINodesTask) Loop() error {
|
||||
config, err := configs.LoadAPIConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 是否禁止自动升级
|
||||
if config.RPC.DisableUpdate {
|
||||
return nil
|
||||
}
|
||||
|
||||
var tr = trackers.Begin("SYNC_API_NODES")
|
||||
defer tr.End()
|
||||
|
||||
@@ -76,7 +86,7 @@ func (this *SyncAPINodesTask) Loop() error {
|
||||
return err
|
||||
}
|
||||
|
||||
newEndpoints := []string{}
|
||||
var newEndpoints = []string{}
|
||||
for _, node := range resp.ApiNodes {
|
||||
if !node.IsOn {
|
||||
continue
|
||||
@@ -85,16 +95,12 @@ func (this *SyncAPINodesTask) Loop() error {
|
||||
}
|
||||
|
||||
// 和现有的对比
|
||||
config, err := configs.LoadAPIConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if this.isSame(newEndpoints, config.RPC.Endpoints) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 测试是否有API节点可用
|
||||
hasOk := this.testEndpoints(newEndpoints)
|
||||
var hasOk = this.testEndpoints(newEndpoints)
|
||||
if !hasOk {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -14,8 +14,10 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"github.com/iwind/gosock/pkg/gosock"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
)
|
||||
@@ -46,15 +48,12 @@ func (this *UpgradeManager) Start() {
|
||||
}
|
||||
this.isInstalling = true
|
||||
|
||||
// 还原安装状态
|
||||
defer func() {
|
||||
this.isInstalling = false
|
||||
}()
|
||||
|
||||
remotelogs.Println("UPGRADE_MANAGER", "upgrading node ...")
|
||||
err := this.install()
|
||||
if err != nil {
|
||||
remotelogs.Error("UPGRADE_MANAGER", "download failed: "+err.Error())
|
||||
|
||||
this.isInstalling = false
|
||||
return
|
||||
}
|
||||
|
||||
@@ -65,9 +64,16 @@ func (this *UpgradeManager) Start() {
|
||||
if err != nil {
|
||||
remotelogs.Error("UPGRADE_MANAGER", err.Error())
|
||||
}
|
||||
|
||||
this.isInstalling = false
|
||||
})
|
||||
}
|
||||
|
||||
// IsInstalling 检查是否正在安装
|
||||
func (this *UpgradeManager) IsInstalling() bool {
|
||||
return this.isInstalling
|
||||
}
|
||||
|
||||
func (this *UpgradeManager) install() error {
|
||||
// 检查是否有已下载但未安装成功的
|
||||
if len(this.lastFile) > 0 {
|
||||
@@ -83,7 +89,7 @@ func (this *UpgradeManager) install() error {
|
||||
}
|
||||
|
||||
// 创建临时文件
|
||||
dir := Tea.Root + "/tmp"
|
||||
var dir = Tea.Root + "/tmp"
|
||||
_, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
@@ -98,7 +104,7 @@ func (this *UpgradeManager) install() error {
|
||||
|
||||
remotelogs.Println("UPGRADE_MANAGER", "downloading new node ...")
|
||||
|
||||
path := dir + "/edge-node" + ".tmp"
|
||||
var path = dir + "/edge-node" + ".tmp"
|
||||
fp, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0777)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -203,16 +209,16 @@ func (this *UpgradeManager) unzip(zipPath string) error {
|
||||
}
|
||||
|
||||
// 先改先前的可执行文件
|
||||
err := os.Rename(target+"/bin/edge-node", target+"/bin/.edge-node.dist")
|
||||
hasBackup := err == nil
|
||||
err := os.Rename(target+"/bin/"+teaconst.ProcessName, target+"/bin/."+teaconst.ProcessName+".dist")
|
||||
var hasBackup = err == nil
|
||||
defer func() {
|
||||
if !isOk && hasBackup {
|
||||
// 失败时还原
|
||||
_ = os.Rename(target+"/bin/.edge-node.dist", target+"/bin/edge-node")
|
||||
_ = os.Rename(target+"/bin/."+teaconst.ProcessName+".dist", target+"/bin/"+teaconst.ProcessName)
|
||||
}
|
||||
}()
|
||||
|
||||
unzip := utils.NewUnzip(zipPath, target, "edge-node/")
|
||||
var unzip = utils.NewUnzip(zipPath, target, "edge-node/")
|
||||
err = unzip.Run()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -225,6 +231,9 @@ func (this *UpgradeManager) unzip(zipPath string) error {
|
||||
|
||||
// 重启
|
||||
func (this *UpgradeManager) restart() error {
|
||||
// 关闭当前sock,防止无法重启
|
||||
_ = gosock.NewTmpSock(teaconst.ProcessName).Close()
|
||||
|
||||
// 重新启动
|
||||
if DaemonIsOn && DaemonPid == os.Getppid() {
|
||||
utils.Exit() // TODO 试着更优雅重启
|
||||
@@ -241,7 +250,9 @@ func (this *UpgradeManager) restart() error {
|
||||
events.Notify(events.EventTerminated)
|
||||
|
||||
// 启动
|
||||
cmd := exec.Command(exe, "start")
|
||||
exe = filepath.Dir(exe) + "/" + teaconst.ProcessName
|
||||
|
||||
var cmd = exec.Command(exe, "start")
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -109,6 +109,8 @@ func TestRegexp_ParseKeywords3(t *testing.T) {
|
||||
|
||||
func BenchmarkRegexp_MatchString(b *testing.B) {
|
||||
var r = re.MustCompile("(?i)(onmouseover|onmousemove|onmousedown|onmouseup|onerror|onload|onclick|ondblclick|onkeydown|onkeyup|onkeypress)(\\s|%09|%0A|(\\+|%20))*(=|%3D)")
|
||||
b.ResetTimer()
|
||||
|
||||
//b.Log("keywords:", r.Keywords())
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.MatchString("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36")
|
||||
@@ -117,6 +119,8 @@ func BenchmarkRegexp_MatchString(b *testing.B) {
|
||||
|
||||
func BenchmarkRegexp_MatchString2(b *testing.B) {
|
||||
var r = regexp.MustCompile("(?i)(onmouseover|onmousemove|onmousedown|onmouseup|onerror|onload|onclick|ondblclick|onkeydown|onkeyup|onkeypress)(\\s|%09|%0A|(\\+|%20))*(=|%3D)")
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
r.MatchString("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36")
|
||||
}
|
||||
|
||||
@@ -343,9 +343,11 @@ func (this *HTTPRequestStatManager) Upload() error {
|
||||
if strings.Contains(err.Error(), "string field contains invalid UTF-8") {
|
||||
for _, system := range pbSystems {
|
||||
system.Name = utils.ToValidUTF8string(system.Name)
|
||||
system.Version = utils.ToValidUTF8string(system.Version)
|
||||
}
|
||||
for _, browser := range pbBrowsers {
|
||||
browser.Name = utils.ToValidUTF8string(browser.Name)
|
||||
browser.Version = utils.ToValidUTF8string(browser.Version)
|
||||
}
|
||||
|
||||
// 再次尝试
|
||||
|
||||
@@ -106,7 +106,7 @@ func (this *TrafficStatManager) Add(serverId int64, domain string, bytes int64,
|
||||
this.totalRequests++
|
||||
|
||||
var timestamp = utils.FloorUnixTime(300)
|
||||
key := strconv.FormatInt(timestamp, 10) + strconv.FormatInt(serverId, 10)
|
||||
var key = strconv.FormatInt(timestamp, 10) + strconv.FormatInt(serverId, 10)
|
||||
this.locker.Lock()
|
||||
|
||||
// 总的流量
|
||||
|
||||
@@ -1,108 +1,46 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"time"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var BytePool1k = NewBytePool(20480, 1024)
|
||||
var BytePool4k = NewBytePool(20480, 4*1024)
|
||||
var BytePool16k = NewBytePool(40960, 16*1024)
|
||||
var BytePool32k = NewBytePool(20480, 32*1024)
|
||||
var BytePool1k = NewBytePool(1024)
|
||||
var BytePool4k = NewBytePool(4 * 1024)
|
||||
var BytePool16k = NewBytePool(16 * 1024)
|
||||
var BytePool32k = NewBytePool(32 * 1024)
|
||||
|
||||
// BytePool pool for get byte slice
|
||||
type BytePool struct {
|
||||
c chan []byte
|
||||
maxSize int
|
||||
length int
|
||||
hasNew bool
|
||||
rawPool *sync.Pool
|
||||
}
|
||||
|
||||
// NewBytePool 创建新对象
|
||||
func NewBytePool(maxSize, length int) *BytePool {
|
||||
if maxSize <= 0 {
|
||||
maxSize = 1024
|
||||
func NewBytePool(length int) *BytePool {
|
||||
if length < 0 {
|
||||
length = 1024
|
||||
}
|
||||
if length <= 0 {
|
||||
length = 128
|
||||
return &BytePool{
|
||||
length: length,
|
||||
rawPool: &sync.Pool{
|
||||
New: func() any {
|
||||
return make([]byte, length)
|
||||
},
|
||||
},
|
||||
}
|
||||
var pool = &BytePool{
|
||||
c: make(chan []byte, maxSize),
|
||||
maxSize: maxSize,
|
||||
length: length,
|
||||
}
|
||||
|
||||
pool.init()
|
||||
|
||||
return pool
|
||||
}
|
||||
|
||||
// 初始化
|
||||
func (this *BytePool) init() {
|
||||
var ticker = time.NewTicker(2 * time.Minute)
|
||||
if Tea.IsTesting() {
|
||||
ticker = time.NewTicker(5 * time.Second)
|
||||
}
|
||||
goman.New(func() {
|
||||
for range ticker.C {
|
||||
if this.hasNew {
|
||||
this.hasNew = false
|
||||
continue
|
||||
}
|
||||
|
||||
this.Purge()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Get 获取一个新的byte slice
|
||||
func (this *BytePool) Get() (b []byte) {
|
||||
select {
|
||||
case b = <-this.c:
|
||||
default:
|
||||
b = make([]byte, this.length)
|
||||
this.hasNew = true
|
||||
}
|
||||
return
|
||||
func (this *BytePool) Get() []byte {
|
||||
return this.rawPool.Get().([]byte)
|
||||
}
|
||||
|
||||
// Put 放回一个使用过的byte slice
|
||||
func (this *BytePool) Put(b []byte) {
|
||||
if cap(b) != this.length {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case this.c <- b:
|
||||
default:
|
||||
// 已达最大容量,则抛弃
|
||||
}
|
||||
this.rawPool.Put(b)
|
||||
}
|
||||
|
||||
// Length 单个字节slice长度
|
||||
func (this *BytePool) Length() int {
|
||||
return this.length
|
||||
}
|
||||
|
||||
// Size 当前的数量
|
||||
func (this *BytePool) Size() int {
|
||||
return len(this.c)
|
||||
}
|
||||
|
||||
// Purge 清理
|
||||
func (this *BytePool) Purge() {
|
||||
// 1%
|
||||
var count = len(this.c) / 100
|
||||
if count == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
Loop:
|
||||
for i := 0; i < count; i++ {
|
||||
select {
|
||||
case <-this.c:
|
||||
default:
|
||||
break Loop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,37 +1,16 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewBytePool(t *testing.T) {
|
||||
a := assert.NewAssertion(t)
|
||||
|
||||
pool := NewBytePool(5, 8)
|
||||
buf := pool.Get()
|
||||
a.IsTrue(len(buf) == 8)
|
||||
a.IsTrue(len(pool.c) == 0)
|
||||
|
||||
pool.Put(buf)
|
||||
a.IsTrue(len(pool.c) == 1)
|
||||
|
||||
pool.Get()
|
||||
a.IsTrue(len(pool.c) == 0)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
pool.Put(buf)
|
||||
}
|
||||
t.Log(len(pool.c))
|
||||
a.IsTrue(len(pool.c) == 5)
|
||||
}
|
||||
|
||||
func TestBytePool_Memory(t *testing.T) {
|
||||
var stat1 = &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stat1)
|
||||
|
||||
var pool = NewBytePool(20480, 32*1024)
|
||||
var pool = NewBytePool(32 * 1024)
|
||||
for i := 0; i < 20480; i++ {
|
||||
pool.Put(make([]byte, 32*1024))
|
||||
}
|
||||
@@ -44,18 +23,50 @@ func TestBytePool_Memory(t *testing.T) {
|
||||
|
||||
var stat2 = &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stat2)
|
||||
t.Log((stat2.HeapInuse-stat1.HeapInuse)/1024/1024, "MB,", pool.Size(), "slices")
|
||||
t.Log((stat2.HeapInuse-stat1.HeapInuse)/1024/1024, "MB,")
|
||||
}
|
||||
|
||||
func BenchmarkBytePool_Get(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
pool := NewBytePool(1024, 1)
|
||||
var pool = NewBytePool(1)
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
buf := pool.Get()
|
||||
var buf = pool.Get()
|
||||
_ = buf
|
||||
pool.Put(buf)
|
||||
}
|
||||
|
||||
b.Log(pool.Size())
|
||||
}
|
||||
|
||||
func BenchmarkBytePool_Get_Parallel(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
var pool = NewBytePool(1024)
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
var buf = pool.Get()
|
||||
pool.Put(buf)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkBytePool_Get_Sync(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
var pool = &sync.Pool{
|
||||
New: func() any {
|
||||
return make([]byte, 1024)
|
||||
},
|
||||
}
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
var buf = pool.Get()
|
||||
pool.Put(buf)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package utils
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"sync"
|
||||
@@ -19,7 +19,7 @@ func DumpResponse(resp *http.Response) (header []byte, body []byte, err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -25,16 +25,3 @@ func ListenReuseAddr(network string, addr string) (net.Listener, error) {
|
||||
}
|
||||
return config.Listen(context.Background(), network, addr)
|
||||
}
|
||||
|
||||
// ParseAddrHost 分析地址中的主机名部分
|
||||
func ParseAddrHost(addr string) string {
|
||||
if len(addr) == 0 {
|
||||
return addr
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
@@ -12,3 +12,15 @@ func TestParseAddrHost(t *testing.T) {
|
||||
t.Log(addr + " => " + utils.ParseAddrHost(addr))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergePorts(t *testing.T) {
|
||||
for _, ports := range [][]int{
|
||||
{},
|
||||
{80},
|
||||
{80, 83, 85},
|
||||
{80, 81, 83, 85, 86, 87, 88, 90},
|
||||
{0, 0, 1, 1, 2, 2, 2, 3, 3, 3},
|
||||
} {
|
||||
t.Log(ports, "=>", utils.MergePorts(ports))
|
||||
}
|
||||
}
|
||||
|
||||
52
internal/utils/net_utils.go
Normal file
52
internal/utils/net_utils.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// ParseAddrHost 分析地址中的主机名部分
|
||||
func ParseAddrHost(addr string) string {
|
||||
if len(addr) == 0 {
|
||||
return addr
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// MergePorts 聚合端口
|
||||
// 返回 [ [fromPort, toPort], ... ]
|
||||
func MergePorts(ports []int) [][2]int {
|
||||
if len(ports) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sort.Ints(ports)
|
||||
|
||||
var result = [][2]int{}
|
||||
var lastRange = [2]int{0, 0}
|
||||
var lastPort = -1
|
||||
for _, port := range ports {
|
||||
if port <= 0 /** 只处理有效的端口 **/ || port == lastPort /** 去重 **/ {
|
||||
continue
|
||||
}
|
||||
|
||||
if lastPort < 0 || port != lastPort+1 {
|
||||
lastRange = [2]int{port, port}
|
||||
result = append(result, lastRange)
|
||||
} else { // 如果是连续的
|
||||
lastRange[1] = port
|
||||
result[len(result)-1] = lastRange
|
||||
}
|
||||
|
||||
lastPort = port
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/readers"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/textproto"
|
||||
"testing"
|
||||
)
|
||||
@@ -17,7 +16,7 @@ func TestNewByteRangesReader(t *testing.T) {
|
||||
var dashBoundary = "--" + boundary
|
||||
var b = bytes.NewReader([]byte(dashBoundary + "\r\nContent-Range: bytes 0-4/36\r\nContent-Type: text/plain\r\n\r\n01234\r\n" + dashBoundary + "\r\nContent-Range: bytes 5-9/36\r\nContent-Type: text/plain\r\n\r\n56789\r\n--" + boundary + "\r\nContent-Range: bytes 10-12/36\r\nContent-Type: text/plain\r\n\r\nabc\r\n" + dashBoundary + "--\r\n"))
|
||||
|
||||
var reader = readers.NewByteRangesReaderCloser(ioutil.NopCloser(b), boundary)
|
||||
var reader = readers.NewByteRangesReaderCloser(io.NopCloser(b), boundary)
|
||||
var p = make([]byte, 16)
|
||||
for {
|
||||
n, err := reader.Read(p)
|
||||
@@ -38,7 +37,7 @@ func TestByteRangesReader_OnPartRead(t *testing.T) {
|
||||
var dashBoundary = "--" + boundary
|
||||
var b = bytes.NewReader([]byte(dashBoundary + "\r\nContent-Range: bytes 0-4/36\r\nContent-Type: text/plain\r\n\r\n01234\r\n" + dashBoundary + "\r\nContent-Range: bytes 5-9/36\r\nContent-Type: text/plain\r\n\r\n56789\r\n--" + boundary + "\r\nContent-Range: bytes 10-12/36\r\nContent-Type: text/plain\r\n\r\nabc\r\n" + dashBoundary + "--\r\n"))
|
||||
|
||||
var reader = readers.NewByteRangesReaderCloser(ioutil.NopCloser(b), boundary)
|
||||
var reader = readers.NewByteRangesReaderCloser(io.NopCloser(b), boundary)
|
||||
reader.OnPartRead(func(start int64, end int64, total int64, data []byte, header textproto.MIMEHeader) {
|
||||
t.Log(start, "-", end, "/", total, string(data))
|
||||
})
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/files"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
@@ -83,13 +82,13 @@ func (this *ServiceManager) installInitService(exePath string, args []string) er
|
||||
return errors.New("'scripts/" + shortName + "' file not exists")
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadFile(scriptFile)
|
||||
data, err := os.ReadFile(scriptFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data = regexp.MustCompile("INSTALL_DIR=.+").ReplaceAll(data, []byte("INSTALL_DIR="+Tea.Root))
|
||||
err = ioutil.WriteFile(initServiceFile, data, 0777)
|
||||
err = os.WriteFile(initServiceFile, data, 0777)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -137,7 +136,7 @@ ExecReload=` + exePath + ` reload
|
||||
WantedBy=multi-user.target`
|
||||
|
||||
// write file
|
||||
err := ioutil.WriteFile(systemdServiceFile, []byte(desc), 0777)
|
||||
err := os.WriteFile(systemdServiceFile, []byte(desc), 0777)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,18 +2,21 @@ package utils
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"time"
|
||||
)
|
||||
|
||||
var unixTime = time.Now().Unix()
|
||||
var unixTimeMilli = time.Now().UnixMilli()
|
||||
var unixTimeMilliString = types.String(unixTimeMilli)
|
||||
|
||||
func init() {
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
var ticker = time.NewTicker(200 * time.Millisecond)
|
||||
goman.New(func() {
|
||||
for range ticker.C {
|
||||
unixTime = time.Now().Unix()
|
||||
unixTimeMilli = time.Now().UnixMilli()
|
||||
unixTimeMilliString = types.String(unixTimeMilli)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -43,6 +46,10 @@ func UnixTimeMilli() int64 {
|
||||
return unixTimeMilli
|
||||
}
|
||||
|
||||
func UnixTimeMilliString() (int64, string) {
|
||||
return unixTimeMilli, unixTimeMilliString
|
||||
}
|
||||
|
||||
// GMTUnixTime 计算GMT时间戳
|
||||
func GMTUnixTime(timestamp int64) int64 {
|
||||
_, offset := time.Now().Zone()
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -111,7 +110,7 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
|
||||
path = Tea.Root + string(os.PathSeparator) + path
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadFile(path)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
return false
|
||||
|
||||
@@ -30,7 +30,7 @@ func (this *CCCheckpoint) Start() {
|
||||
this.cache = ttlcache.NewCache()
|
||||
}
|
||||
|
||||
func (this *CCCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *CCCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
value = 0
|
||||
|
||||
if this.cache == nil {
|
||||
@@ -114,15 +114,15 @@ func (this *CCCheckpoint) RequestValue(req requests.Request, param string, optio
|
||||
if len(key) == 0 {
|
||||
key = req.WAFRemoteIP()
|
||||
}
|
||||
value = this.cache.IncreaseInt64(key, int64(1), time.Now().Unix()+period, false)
|
||||
value = this.cache.IncreaseInt64(types.String(ruleId)+"@"+key, int64(1), time.Now().Unix()+period, false)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (this *CCCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *CCCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
return this.RequestValue(req, param, options, ruleId)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ type CC2Checkpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *CC2Checkpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *CC2Checkpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
var keys = options.GetSlice("keys")
|
||||
var keyValues = []string{}
|
||||
for _, key := range keys {
|
||||
@@ -66,11 +66,16 @@ func (this *CC2Checkpoint) RequestValue(req requests.Request, param string, opti
|
||||
}
|
||||
}
|
||||
|
||||
value = ccCache.IncreaseInt64("WAF-CC-"+strings.Join(keyValues, "@"), 1, time.Now().Unix()+period, false)
|
||||
var ccKey = "WAF-CC-" + types.String(ruleId) + "-" + strings.Join(keyValues, "@")
|
||||
value = ccCache.IncreaseInt64(ccKey, 1, time.Now().Unix()+period, false)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (this *CC2Checkpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *CC2Checkpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options, ruleId)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -23,21 +23,21 @@ func TestCCCheckpoint_RequestValue(t *testing.T) {
|
||||
options := maps.Map{
|
||||
"period": "5",
|
||||
}
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options, 1))
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options, 1))
|
||||
|
||||
req.WAFRaw().RemoteAddr = "127.0.0.2"
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options, 1))
|
||||
|
||||
req.WAFRaw().RemoteAddr = "127.0.0.1"
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options, 1))
|
||||
|
||||
req.WAFRaw().RemoteAddr = "127.0.0.2"
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options, 1))
|
||||
|
||||
req.WAFRaw().RemoteAddr = "127.0.0.2"
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options, 1))
|
||||
|
||||
req.WAFRaw().RemoteAddr = "127.0.0.2"
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options, 1))
|
||||
}
|
||||
|
||||
@@ -17,10 +17,10 @@ type CheckpointInterface interface {
|
||||
IsComposed() bool
|
||||
|
||||
// RequestValue get request value
|
||||
RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error)
|
||||
RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error)
|
||||
|
||||
// ResponseValue get response value
|
||||
ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error)
|
||||
ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error)
|
||||
|
||||
// ParamOptions param option list
|
||||
ParamOptions() *ParamOptions
|
||||
|
||||
@@ -11,7 +11,7 @@ type RequestAllCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
valueBytes := []byte{}
|
||||
if len(req.WAFRaw().RequestURI) > 0 {
|
||||
valueBytes = append(valueBytes, req.WAFRaw().RequestURI...)
|
||||
@@ -47,10 +47,10 @@ func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param strin
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestAllCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestAllCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
value = ""
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
return this.RequestValue(req, param, options, ruleId)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"bytes"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -18,7 +18,7 @@ func TestRequestAllCheckpoint_RequestValue(t *testing.T) {
|
||||
}
|
||||
|
||||
checkpoint := new(RequestAllCheckpoint)
|
||||
v, _, sysErr, userErr := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
|
||||
v, _, sysErr, userErr := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil, 1)
|
||||
if sysErr != nil {
|
||||
t.Fatal(sysErr)
|
||||
}
|
||||
@@ -28,7 +28,7 @@ func TestRequestAllCheckpoint_RequestValue(t *testing.T) {
|
||||
t.Log(v)
|
||||
t.Log(types.String(v))
|
||||
|
||||
body, err := ioutil.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -42,13 +42,13 @@ func TestRequestAllCheckpoint_RequestValue_Max(t *testing.T) {
|
||||
}
|
||||
|
||||
checkpoint := new(RequestBodyCheckpoint)
|
||||
value, _, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
|
||||
value, _, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil, 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("value bytes:", len(types.String(value)))
|
||||
|
||||
body, err := ioutil.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -65,6 +65,6 @@ func BenchmarkRequestAllCheckpoint_RequestValue(b *testing.B) {
|
||||
|
||||
checkpoint := new(RequestAllCheckpoint)
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _, _ = checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
|
||||
_, _, _, _ = checkpoint.RequestValue(requests.NewTestRequest(req), "", nil, 1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,13 +9,13 @@ type RequestArgCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
return req.WAFRaw().URL.Query().Get(param), hasRequestBody, nil, nil
|
||||
}
|
||||
|
||||
func (this *RequestArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
return this.RequestValue(req, param, options, ruleId)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ func TestArgParam_RequestValue(t *testing.T) {
|
||||
req := requests.NewTestRequest(rawReq)
|
||||
|
||||
checkpoint := new(RequestArgCheckpoint)
|
||||
t.Log(checkpoint.RequestValue(req, "name", nil))
|
||||
t.Log(checkpoint.ResponseValue(req, nil, "name", nil))
|
||||
t.Log(checkpoint.RequestValue(req, "name2", nil))
|
||||
t.Log(checkpoint.RequestValue(req, "name", nil, 1))
|
||||
t.Log(checkpoint.ResponseValue(req, nil, "name", nil, 1))
|
||||
t.Log(checkpoint.RequestValue(req, "name2", nil, 1))
|
||||
}
|
||||
|
||||
@@ -9,14 +9,14 @@ type RequestArgsCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestArgsCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestArgsCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
value = req.WAFRaw().URL.RawQuery
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestArgsCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestArgsCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
return this.RequestValue(req, param, options, ruleId)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ type RequestBodyCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestBodyCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestBodyCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
if this.RequestBodyIsEmpty(req) {
|
||||
value = ""
|
||||
return
|
||||
@@ -38,9 +38,9 @@ func (this *RequestBodyCheckpoint) RequestValue(req requests.Request, param stri
|
||||
return bodyData, hasRequestBody, nil, nil
|
||||
}
|
||||
|
||||
func (this *RequestBodyCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestBodyCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
return this.RequestValue(req, param, options, ruleId)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"bytes"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -17,9 +17,9 @@ func TestRequestBodyCheckpoint_RequestValue(t *testing.T) {
|
||||
}
|
||||
var req = requests.NewTestRequest(rawReq)
|
||||
checkpoint := new(RequestBodyCheckpoint)
|
||||
t.Log(checkpoint.RequestValue(req, "", nil))
|
||||
t.Log(checkpoint.RequestValue(req, "", nil, 1))
|
||||
|
||||
body, err := ioutil.ReadAll(rawReq.Body)
|
||||
body, err := io.ReadAll(rawReq.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -34,13 +34,13 @@ func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) {
|
||||
}
|
||||
|
||||
checkpoint := new(RequestBodyCheckpoint)
|
||||
value, _, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
|
||||
value, _, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil, 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("value bytes:", len(types.String(value)))
|
||||
|
||||
body, err := ioutil.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -9,14 +9,14 @@ type RequestContentTypeCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestContentTypeCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestContentTypeCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
value = req.WAFRaw().Header.Get("Content-Type")
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestContentTypeCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestContentTypeCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
return this.RequestValue(req, param, options, ruleId)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ type RequestCookieCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestCookieCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestCookieCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
cookie, err := req.WAFRaw().Cookie(param)
|
||||
if err != nil {
|
||||
value = ""
|
||||
@@ -20,9 +20,9 @@ func (this *RequestCookieCheckpoint) RequestValue(req requests.Request, param st
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestCookieCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestCookieCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
return this.RequestValue(req, param, options, ruleId)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ type RequestCookiesCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestCookiesCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestCookiesCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
var cookies = []string{}
|
||||
for _, cookie := range req.WAFRaw().Cookies() {
|
||||
cookies = append(cookies, url.QueryEscape(cookie.Name)+"="+url.QueryEscape(cookie.Value))
|
||||
@@ -20,9 +20,9 @@ func (this *RequestCookiesCheckpoint) RequestValue(req requests.Request, param s
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestCookiesCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestCookiesCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
return this.RequestValue(req, param, options, ruleId)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ type RequestFormArgCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestFormArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestFormArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
hasRequestBody = true
|
||||
|
||||
if this.RequestBodyIsEmpty(req) {
|
||||
@@ -42,9 +42,9 @@ func (this *RequestFormArgCheckpoint) RequestValue(req requests.Request, param s
|
||||
return values.Get(param), hasRequestBody, nil, nil
|
||||
}
|
||||
|
||||
func (this *RequestFormArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestFormArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
return this.RequestValue(req, param, options, ruleId)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package checkpoints
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
@@ -19,12 +19,12 @@ func TestRequestFormArgCheckpoint_RequestValue(t *testing.T) {
|
||||
req.WAFRaw().Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
checkpoint := new(RequestFormArgCheckpoint)
|
||||
t.Log(checkpoint.RequestValue(req, "name", nil))
|
||||
t.Log(checkpoint.RequestValue(req, "age", nil))
|
||||
t.Log(checkpoint.RequestValue(req, "Hello", nil))
|
||||
t.Log(checkpoint.RequestValue(req, "encoded", nil))
|
||||
t.Log(checkpoint.RequestValue(req, "name", nil, 1))
|
||||
t.Log(checkpoint.RequestValue(req, "age", nil, 1))
|
||||
t.Log(checkpoint.RequestValue(req, "Hello", nil, 1))
|
||||
t.Log(checkpoint.RequestValue(req, "encoded", nil, 1))
|
||||
|
||||
body, err := ioutil.ReadAll(req.WAFRaw().Body)
|
||||
body, err := io.ReadAll(req.WAFRaw().Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ func (this *RequestGeneralHeaderLengthCheckpoint) IsComposed() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
value = false
|
||||
|
||||
var headers = options.GetSlice("headers")
|
||||
@@ -35,6 +35,6 @@ func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req requests.Requ
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestGeneralHeaderLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestGeneralHeaderLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -15,11 +15,11 @@ func (this *RequestGeoCityNameCheckpoint) IsComposed() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *RequestGeoCityNameCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
func (this *RequestGeoCityNameCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
value = req.Format("${geo.city.name}")
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestGeoCityNameCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
return this.RequestValue(req, param, options)
|
||||
func (this *RequestGeoCityNameCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value interface{}, hasRequestBody bool, sysErr error, userErr error) {
|
||||
return this.RequestValue(req, param, options, ruleId)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user