Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
684ba7082b | ||
|
|
8934962de2 | ||
|
|
0cf37f25dc | ||
|
|
d7a6d71fea | ||
|
|
a26f7941d5 | ||
|
|
f5365e5420 | ||
|
|
56d21f867b | ||
|
|
d18a301c61 | ||
|
|
afb937030c | ||
|
|
8faa82c453 | ||
|
|
b17b63aec5 | ||
|
|
c30dbb811f | ||
|
|
a58816361e | ||
|
|
8d37aefd95 | ||
|
|
df7fee966e | ||
|
|
01cfccebbd | ||
|
|
6bd7da5e6e | ||
|
|
dcba9c2f3e | ||
|
|
f38e80e82d | ||
|
|
9bd38094c3 | ||
|
|
7e37fc3b80 | ||
|
|
d775dfeeaa | ||
|
|
0486f86898 | ||
|
|
102157c893 | ||
|
|
dbd92368ae | ||
|
|
9e418e73bf | ||
|
|
5d40eec163 | ||
|
|
1a7a67238d | ||
|
|
6707437bae | ||
|
|
df7859387d | ||
|
|
889c52330d |
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
*_plus.go
|
||||
*-plus.sh
|
||||
@@ -7,6 +7,7 @@ function build() {
|
||||
DIST=$ROOT/"../dist/${NAME}"
|
||||
OS=${1}
|
||||
ARCH=${2}
|
||||
TAG=${3}
|
||||
|
||||
if [ -z $OS ]; then
|
||||
echo "usage: build.sh OS ARCH"
|
||||
@@ -16,6 +17,9 @@ function build() {
|
||||
echo "usage: build.sh OS ARCH"
|
||||
exit
|
||||
fi
|
||||
if [ -z $TAG ]; then
|
||||
TAG="community"
|
||||
fi
|
||||
|
||||
echo "checking ..."
|
||||
ZIP_PATH=$(which zip)
|
||||
@@ -24,8 +28,8 @@ function build() {
|
||||
exit
|
||||
fi
|
||||
|
||||
echo "building v${VERSION}/${OS}/${ARCH} ..."
|
||||
ZIP="${NAME}-${OS}-${ARCH}-v${VERSION}.zip"
|
||||
echo "building v${VERSION}/${OS}/${ARCH}/${TAG} ..."
|
||||
ZIP="${NAME}-${OS}-${ARCH}-${TAG}-v${VERSION}.zip"
|
||||
|
||||
echo "copying ..."
|
||||
if [ ! -d $DIST ]; then
|
||||
@@ -66,6 +70,10 @@ function build() {
|
||||
CC_PATH="aarch64-linux-musl-gcc"
|
||||
CXX_PATH="aarch64-linux-musl-g++"
|
||||
fi
|
||||
if [ "${ARCH}" == "arm" ]; then
|
||||
CC_PATH="arm-linux-musleabi-gcc"
|
||||
CXX_PATH="arm-linux-musleabi-g++"
|
||||
fi
|
||||
if [ "${ARCH}" == "mips64" ]; then
|
||||
CC_PATH="mips64-linux-musl-gcc"
|
||||
CXX_PATH="mips64-linux-musl-g++"
|
||||
@@ -76,9 +84,9 @@ function build() {
|
||||
fi
|
||||
fi
|
||||
if [ ! -z $CC_PATH ]; then
|
||||
env CC=$MUSL_DIR/$CC_PATH CXX=$MUSL_DIR/$CXX_PATH GOOS=${OS} GOARCH=${ARCH} CGO_ENABLED=1 go build -o $DIST/bin/${NAME} -ldflags "-linkmode external -extldflags -static -s -w" $ROOT/../cmd/edge-node/main.go
|
||||
env CC=$MUSL_DIR/$CC_PATH CXX=$MUSL_DIR/$CXX_PATH GOOS=${OS} GOARCH=${ARCH} CGO_ENABLED=1 go build -tags $TAG -o $DIST/bin/${NAME} -ldflags "-linkmode external -extldflags -static -s -w" $ROOT/../cmd/edge-node/main.go
|
||||
else
|
||||
env GOOS=${OS} GOARCH=${ARCH} CGO_ENABLED=1 go build -o $DIST/bin/${NAME} -ldflags="-s -w" $ROOT/../cmd/edge-node/main.go
|
||||
env GOOS=${OS} GOARCH=${ARCH} CGO_ENABLED=1 go build -tags $TAG -o $DIST/bin/${NAME} -ldflags="-s -w" $ROOT/../cmd/edge-node/main.go
|
||||
fi
|
||||
|
||||
# delete hidden files
|
||||
@@ -110,4 +118,4 @@ function lookup-version() {
|
||||
fi
|
||||
}
|
||||
|
||||
build $1 $2
|
||||
build $1 $2 $3
|
||||
|
||||
2
build/data/.gitignore
vendored
2
build/data/.gitignore
vendored
@@ -1 +1 @@
|
||||
index.*
|
||||
*
|
||||
@@ -5,15 +5,12 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/apps"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/nodes"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io/ioutil"
|
||||
"github.com/iwind/gosock/pkg/gosock"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -40,25 +37,13 @@ func main() {
|
||||
fmt.Println("done")
|
||||
})
|
||||
app.On("quit", func() {
|
||||
pidFile := Tea.Root + "/bin/pid"
|
||||
data, err := ioutil.ReadFile(pidFile)
|
||||
var sock = gosock.NewTmpSock(teaconst.ProcessName)
|
||||
_, err := sock.Send(&gosock.Command{Code: "quit"})
|
||||
if err != nil {
|
||||
fmt.Println("[ERROR]quit failed: " + err.Error())
|
||||
return
|
||||
}
|
||||
pid := types.Int(string(data))
|
||||
if pid == 0 {
|
||||
fmt.Println("[ERROR]quit failed: pid=0")
|
||||
return
|
||||
}
|
||||
|
||||
process, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if process != nil {
|
||||
_ = process.Signal(syscall.SIGQUIT)
|
||||
}
|
||||
fmt.Println("done")
|
||||
})
|
||||
app.On("pprof", func() {
|
||||
// TODO 自己指定端口
|
||||
|
||||
5
go.mod
5
go.mod
@@ -12,10 +12,11 @@ require (
|
||||
github.com/go-ole/go-ole v1.2.4 // indirect
|
||||
github.com/go-yaml/yaml v2.1.0+incompatible
|
||||
github.com/golang/protobuf v1.5.2
|
||||
github.com/iwind/TeaGo v0.0.0-20210411134150-ddf57e240c2f
|
||||
github.com/iwind/TeaGo v0.0.0-20210628135026-38575a4ab060
|
||||
github.com/iwind/gofcgi v0.0.0-20210528023741-a92711d45f11
|
||||
github.com/iwind/gosock v0.0.0-20210722083328-12b2d66abec3
|
||||
github.com/lionsoul2014/ip2region v2.2.0-release+incompatible
|
||||
github.com/mattn/go-sqlite3 v1.14.7
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible
|
||||
github.com/mssola/user_agent v0.5.2
|
||||
github.com/shirou/gopsutil v3.21.5+incompatible
|
||||
github.com/tklauser/go-sysconf v0.3.6 // indirect
|
||||
|
||||
24
go.sum
24
go.sum
@@ -48,7 +48,6 @@ github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrU
|
||||
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
|
||||
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
|
||||
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
|
||||
github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0=
|
||||
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
|
||||
@@ -57,16 +56,18 @@ github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5a
|
||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w=
|
||||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/iwind/TeaGo v0.0.0-20210411134150-ddf57e240c2f h1:r2O8PONj/KiuZjJHVHn7KlCePUIjNtgAmvLfgRafQ8o=
|
||||
github.com/iwind/TeaGo v0.0.0-20210411134150-ddf57e240c2f/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc=
|
||||
github.com/iwind/TeaGo v0.0.0-20210628135026-38575a4ab060 h1:qdLtK4PDXxk2vMKkTWl5Fl9xqYuRCukzWAgJbLHdfOo=
|
||||
github.com/iwind/TeaGo v0.0.0-20210628135026-38575a4ab060/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc=
|
||||
github.com/iwind/gofcgi v0.0.0-20210528023741-a92711d45f11 h1:DaQjoWZhLNxjhIXedVg4/vFEtHkZhK4IjIwsWdyzBLg=
|
||||
github.com/iwind/gofcgi v0.0.0-20210528023741-a92711d45f11/go.mod h1:JtbX20untAjUVjZs1ZBtq80f5rJWvwtQNRL6EnuYRnY=
|
||||
github.com/iwind/gosock v0.0.0-20210722083328-12b2d66abec3 h1:aBSonas7vFcgTj9u96/bWGILGv1ZbUSTLiOzcI1ZT6c=
|
||||
github.com/iwind/gosock v0.0.0-20210722083328-12b2d66abec3/go.mod h1:H5Q7SXwbx3a97ecJkaS2sD77gspzE7HFUafBO0peEyA=
|
||||
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
@@ -74,8 +75,8 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/lionsoul2014/ip2region v2.2.0-release+incompatible h1:1qp9iks+69h7IGLazAplzS9Ca14HAxuD5c0rbFdPGy4=
|
||||
github.com/lionsoul2014/ip2region v2.2.0-release+incompatible/go.mod h1:+ZBN7PBoh5gG6/y0ZQ85vJDBe21WnfbRrQQwTfliJJI=
|
||||
github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA=
|
||||
github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/mssola/user_agent v0.5.2 h1:CZkTUahjL1+OcZ5zv3kZr8QiJ8jy2H08vZIEkBeRbxo=
|
||||
@@ -94,8 +95,6 @@ github.com/opentracing/opentracing-go v1.1.1-0.20190913142402-a7454ce5950e/go.mo
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/shirou/gopsutil v2.20.9+incompatible h1:msXs2frUV+O/JLva9EDLpuJ84PrFsdCTCQex8PUdtkQ=
|
||||
github.com/shirou/gopsutil v2.20.9+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||
github.com/shirou/gopsutil v3.21.5+incompatible h1:OloQyEerMi7JUrXiNzy8wQ5XN+baemxSl12QgIzt0jc=
|
||||
github.com/shirou/gopsutil v3.21.5+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ=
|
||||
@@ -134,7 +133,6 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7 h1:AeiKBIuRw3UomYXSbLy0Mc2dDLfdtbT/IVn4keq83P0=
|
||||
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20210614182718-04defd469f4e h1:XpT3nA5TvE525Ne3hInMh6+GETgn27Zfm9dxsThnX2Q=
|
||||
@@ -155,10 +153,8 @@ golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200519105757-fe76b779f299 h1:DYfZAGf2WMFjMxbgTjaC+2HC7NkNAQs+6Q8b9WEB/F4=
|
||||
golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210316164454-77fc1eacc6aa h1:ZYxPR6aca/uhfRJyaOAtflSHjJYiktO7QnJC5ut7iY4=
|
||||
golang.org/x/sys v0.0.0-20210316164454-77fc1eacc6aa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -167,7 +163,6 @@ golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 h1:RqytpXGR1iVNX7psjB3ff8y7s
|
||||
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
@@ -184,15 +179,14 @@ golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapK
|
||||
golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
|
||||
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
|
||||
google.golang.org/genproto v0.0.0-20191009194640-548a555dbc03/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
|
||||
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY=
|
||||
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
|
||||
google.golang.org/genproto v0.0.0-20210617175327-b9e0b3197ced h1:c5geK1iMU3cDKtFrCVQIcjR3W+JOZMuhIyICMCTbtus=
|
||||
google.golang.org/genproto v0.0.0-20210617175327-b9e0b3197ced/go.mod h1:SzzZ/N+nwJDaO1kznhnlzqS8ocJICar6hYhVyhi++24=
|
||||
@@ -201,7 +195,6 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac
|
||||
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
|
||||
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
|
||||
google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak=
|
||||
google.golang.org/grpc v1.32.0 h1:zWTV+LMdc3kaiJMSTOFz2UgSBgx8RNQoTGiZu3fR9S0=
|
||||
google.golang.org/grpc v1.32.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak=
|
||||
google.golang.org/grpc v1.38.0 h1:/9BgsAsa5nWe26HqOlvlgJnqBuktYOLCgjCPqsa56W0=
|
||||
google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM=
|
||||
@@ -213,7 +206,6 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi
|
||||
google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c=
|
||||
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk=
|
||||
|
||||
@@ -2,8 +2,11 @@ package apps
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/iwind/gosock/pkg/gosock"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
@@ -11,7 +14,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// App命令帮助
|
||||
// AppCmd App命令帮助
|
||||
type AppCmd struct {
|
||||
product string
|
||||
version string
|
||||
@@ -20,10 +23,14 @@ type AppCmd struct {
|
||||
appendStrings []string
|
||||
|
||||
directives []*Directive
|
||||
|
||||
sock *gosock.Sock
|
||||
}
|
||||
|
||||
func NewAppCmd() *AppCmd {
|
||||
return &AppCmd{}
|
||||
return &AppCmd{
|
||||
sock: gosock.NewTmpSock(teaconst.ProcessName),
|
||||
}
|
||||
}
|
||||
|
||||
type CommandHelpOption struct {
|
||||
@@ -31,25 +38,25 @@ type CommandHelpOption struct {
|
||||
Description string
|
||||
}
|
||||
|
||||
// 产品
|
||||
// Product 产品
|
||||
func (this *AppCmd) Product(product string) *AppCmd {
|
||||
this.product = product
|
||||
return this
|
||||
}
|
||||
|
||||
// 版本
|
||||
// Version 版本
|
||||
func (this *AppCmd) Version(version string) *AppCmd {
|
||||
this.version = version
|
||||
return this
|
||||
}
|
||||
|
||||
// 使用方法
|
||||
// Usage 使用方法
|
||||
func (this *AppCmd) Usage(usage string) *AppCmd {
|
||||
this.usage = usage
|
||||
return this
|
||||
}
|
||||
|
||||
// 选项
|
||||
// Option 选项
|
||||
func (this *AppCmd) Option(code string, description string) *AppCmd {
|
||||
this.options = append(this.options, &CommandHelpOption{
|
||||
Code: code,
|
||||
@@ -58,13 +65,13 @@ func (this *AppCmd) Option(code string, description string) *AppCmd {
|
||||
return this
|
||||
}
|
||||
|
||||
// 附加内容
|
||||
// Append 附加内容
|
||||
func (this *AppCmd) Append(appendString string) *AppCmd {
|
||||
this.appendStrings = append(this.appendStrings, appendString)
|
||||
return this
|
||||
}
|
||||
|
||||
// 打印
|
||||
// Print 打印
|
||||
func (this *AppCmd) Print() {
|
||||
fmt.Println(this.product + " v" + this.version)
|
||||
|
||||
@@ -103,7 +110,7 @@ func (this *AppCmd) Print() {
|
||||
}
|
||||
}
|
||||
|
||||
// 添加指令
|
||||
// On 添加指令
|
||||
func (this *AppCmd) On(arg string, callback func()) {
|
||||
this.directives = append(this.directives, &Directive{
|
||||
Arg: arg,
|
||||
@@ -111,7 +118,7 @@ func (this *AppCmd) On(arg string, callback func()) {
|
||||
})
|
||||
}
|
||||
|
||||
// 运行
|
||||
// Run 运行
|
||||
func (this *AppCmd) Run(main func()) {
|
||||
// 获取参数
|
||||
args := os.Args[1:]
|
||||
@@ -161,7 +168,7 @@ func (this *AppCmd) Run(main func()) {
|
||||
|
||||
// 版本号
|
||||
func (this *AppCmd) runVersion() {
|
||||
fmt.Println(this.product+" v"+this.version, "(build: "+runtime.Version(), runtime.GOOS, runtime.GOARCH+")")
|
||||
fmt.Println(this.product+" v"+this.version, "(build: "+runtime.Version(), runtime.GOOS, runtime.GOARCH, teaconst.Tag+")")
|
||||
}
|
||||
|
||||
// 帮助
|
||||
@@ -171,9 +178,9 @@ func (this *AppCmd) runHelp() {
|
||||
|
||||
// 启动
|
||||
func (this *AppCmd) runStart() {
|
||||
proc := this.checkPid()
|
||||
if proc != nil {
|
||||
fmt.Println(this.product+" already started, pid:", proc.Pid)
|
||||
var pid = this.getPID()
|
||||
if pid > 0 {
|
||||
fmt.Println(this.product+" already started, pid:", pid)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -189,18 +196,15 @@ func (this *AppCmd) runStart() {
|
||||
|
||||
// 停止
|
||||
func (this *AppCmd) runStop() {
|
||||
proc := this.checkPid()
|
||||
if proc == nil {
|
||||
var pid = this.getPID()
|
||||
if pid == 0 {
|
||||
fmt.Println(this.product + " not started yet")
|
||||
return
|
||||
}
|
||||
|
||||
// 停止进程
|
||||
_ = proc.Kill()
|
||||
_, _ = this.sock.Send(&gosock.Command{Code: "stop"})
|
||||
|
||||
// 在Windows上经常不能及时释放资源
|
||||
_ = DeletePid(Tea.Root + "/bin/pid")
|
||||
fmt.Println(this.product+" stopped ok, pid:", proc.Pid)
|
||||
fmt.Println(this.product+" stopped ok, pid:", types.String(pid))
|
||||
}
|
||||
|
||||
// 重启
|
||||
@@ -212,15 +216,24 @@ func (this *AppCmd) runRestart() {
|
||||
|
||||
// 状态
|
||||
func (this *AppCmd) runStatus() {
|
||||
proc := this.checkPid()
|
||||
if proc == nil {
|
||||
var pid = this.getPID()
|
||||
if pid == 0 {
|
||||
fmt.Println(this.product + " not started yet")
|
||||
} else {
|
||||
fmt.Println(this.product + " is running, pid: " + fmt.Sprintf("%d", proc.Pid))
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println(this.product + " is running, pid: " + types.String(pid))
|
||||
}
|
||||
|
||||
// 检查PID
|
||||
func (this *AppCmd) checkPid() *os.Process {
|
||||
return CheckPid(Tea.Root + "/bin/pid")
|
||||
// 获取当前的PID
|
||||
func (this *AppCmd) getPID() int {
|
||||
if !this.sock.IsListening() {
|
||||
return 0
|
||||
}
|
||||
|
||||
reply, err := this.sock.Send(&gosock.Command{Code: "pid"})
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return maps.NewMap(reply.Params).GetInt("pid")
|
||||
}
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
// +build !windows
|
||||
|
||||
package apps
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// lock file
|
||||
func LockFile(fp *os.File) error {
|
||||
return syscall.Flock(int(fp.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
|
||||
}
|
||||
|
||||
func UnlockFile(fp *os.File) error {
|
||||
return syscall.Flock(int(fp.Fd()), syscall.LOCK_UN)
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
// +build windows
|
||||
|
||||
package apps
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
)
|
||||
|
||||
// lock file
|
||||
func LockFile(fp *os.File) error {
|
||||
return errors.New("not implemented on windows")
|
||||
}
|
||||
|
||||
func UnlockFile(fp *os.File) error {
|
||||
return errors.New("not implemented on windows")
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
package apps
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
var pidFileList = []*os.File{}
|
||||
|
||||
// 检查Pid
|
||||
func CheckPid(path string) *os.Process {
|
||||
// windows上打开的文件是不能删除的
|
||||
if runtime.GOOS == "windows" {
|
||||
if os.Remove(path) == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = file.Close()
|
||||
}()
|
||||
|
||||
// 是否能取得Lock
|
||||
err = LockFile(file)
|
||||
if err == nil {
|
||||
_ = UnlockFile(file)
|
||||
return nil
|
||||
}
|
||||
|
||||
pidBytes, err := ioutil.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
pid := types.Int(string(pidBytes))
|
||||
|
||||
if pid <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
proc, _ := os.FindProcess(pid)
|
||||
return proc
|
||||
}
|
||||
|
||||
// 写入Pid
|
||||
func WritePid() error {
|
||||
path := Tea.Root + "/bin/pid"
|
||||
fp, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY|os.O_RDONLY, 0666)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
events.On(events.EventQuit, func() {
|
||||
_ = fp.Close()
|
||||
})
|
||||
|
||||
if runtime.GOOS != "windows" {
|
||||
err = LockFile(fp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
pidFileList = append(pidFileList, fp) // hold the file pointers
|
||||
|
||||
_, err = fp.WriteString(fmt.Sprintf("%d", os.Getpid()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 写入Ppid
|
||||
func WritePpid(path string) error {
|
||||
fp, err := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY|os.O_RDONLY, 0666)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if runtime.GOOS != "windows" {
|
||||
err = LockFile(fp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
pidFileList = append(pidFileList, fp) // hold the file pointers
|
||||
|
||||
_, err = fp.WriteString(fmt.Sprintf("%d", os.Getppid()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 删除Pid
|
||||
func DeletePid(path string) error {
|
||||
_, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for _, fp := range pidFileList {
|
||||
_ = UnlockFile(fp)
|
||||
_ = fp.Close()
|
||||
}
|
||||
return os.Remove(path)
|
||||
}
|
||||
@@ -54,7 +54,12 @@ func (this *FileList) Init() error {
|
||||
|
||||
this.itemsTableName = "cacheItems_v2"
|
||||
|
||||
db, err := sql.Open("sqlite3", "file:"+this.dir+"/index.db?cache=shared&mode=rwc&_journal_mode=WAL")
|
||||
var dir = this.dir
|
||||
if dir == "/" {
|
||||
// 防止sqlite提示authority错误
|
||||
dir = ""
|
||||
}
|
||||
db, err := sql.Open("sqlite3", "file:"+dir+"/index.db?cache=shared&mode=rwc&_journal_mode=WAL")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -115,7 +120,7 @@ func (this *FileList) Init() error {
|
||||
return err
|
||||
}
|
||||
|
||||
this.statStmt, err = this.db.Prepare(`SELECT COUNT(*), IFNULL(SUM(headerSize+bodySize+metaSize), 0), IFNULL(SUM(headerSize+bodySize), 0) FROM "` + this.itemsTableName + `" WHERE expiredAt>?`)
|
||||
this.statStmt, err = this.db.Prepare(`SELECT COUNT(*), IFNULL(SUM(headerSize+bodySize+metaSize), 0), IFNULL(SUM(headerSize+bodySize), 0) FROM "` + this.itemsTableName + `"`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -297,7 +302,7 @@ func (this *FileList) Stat(check func(hash string) bool) (*Stat, error) {
|
||||
}
|
||||
|
||||
// 这里不设置过期时间、不使用 check 函数,目的是让查询更快速一些
|
||||
row := this.statStmt.QueryRow(time.Now().Unix())
|
||||
row := this.statStmt.QueryRow()
|
||||
if row.Err() != nil {
|
||||
return nil, row.Err()
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
@@ -162,3 +163,25 @@ func (this *Manager) TotalMemorySize() int64 {
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// FindAllCachePaths 所有缓存路径
|
||||
func (this *Manager) FindAllCachePaths() []string {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
var result = []string{}
|
||||
for _, policy := range this.policyMap {
|
||||
if policy.Type == serverconfigs.CachePolicyStorageFile {
|
||||
if policy.Options != nil {
|
||||
dir, ok := policy.Options["dir"]
|
||||
if ok {
|
||||
var dirString = types.String(dir)
|
||||
if len(dirString) > 0 {
|
||||
result = append(result, dirString)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
8
internal/const/build.go
Normal file
8
internal/const/build.go
Normal file
@@ -0,0 +1,8 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
// +build community
|
||||
|
||||
package teaconst
|
||||
|
||||
const BuildCommunity = true
|
||||
const BuildPlus = false
|
||||
const Tag = "community"
|
||||
8
internal/const/build_plus.go
Normal file
8
internal/const/build_plus.go
Normal file
@@ -0,0 +1,8 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
// +build plus
|
||||
|
||||
package teaconst
|
||||
|
||||
const BuildCommunity = false
|
||||
const BuildPlus = true
|
||||
const Tag = "plus"
|
||||
@@ -1,7 +1,7 @@
|
||||
package teaconst
|
||||
|
||||
const (
|
||||
Version = "0.2.4"
|
||||
Version = "0.2.7"
|
||||
|
||||
ProductName = "Edge Node"
|
||||
ProcessName = "edge-node"
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// IP名单
|
||||
// IPList IP名单
|
||||
type IPList struct {
|
||||
itemsMap map[int64]*IPItem // id => item
|
||||
ipMap map[uint64][]int64 // ip => itemIds
|
||||
@@ -96,7 +96,7 @@ func (this *IPList) Delete(itemId int64) {
|
||||
this.isAll = len(this.ipMap[0]) > 0
|
||||
}
|
||||
|
||||
// 判断是否包含某个IP
|
||||
// Contains 判断是否包含某个IP
|
||||
func (this *IPList) Contains(ip uint64) bool {
|
||||
this.locker.RLock()
|
||||
if this.isAll {
|
||||
@@ -109,7 +109,7 @@ func (this *IPList) Contains(ip uint64) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
// 是否包含一组IP
|
||||
// ContainsIPStrings 是否包含一组IP
|
||||
func (this *IPList) ContainsIPStrings(ipStrings []string) (found bool, item *IPItem) {
|
||||
if len(ipStrings) == 0 {
|
||||
return
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package iplibrary
|
||||
|
||||
type LibraryInterface interface {
|
||||
// 加载数据库文件
|
||||
// Load 加载数据库文件
|
||||
Load(dbPath string) error
|
||||
|
||||
// 查询IP
|
||||
// Lookup 查询IP
|
||||
// 返回结果有可能为空
|
||||
Lookup(ip string) (*Result, error)
|
||||
|
||||
// 关闭数据库文件
|
||||
// Close 关闭数据库文件
|
||||
Close()
|
||||
}
|
||||
|
||||
121
internal/metrics/manager.go
Normal file
121
internal/metrics/manager.go
Normal file
@@ -0,0 +1,121 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var SharedManager = NewManager()
|
||||
|
||||
type Manager struct {
|
||||
tasks map[int64]*Task // itemId => *Task
|
||||
categoryTasks map[string][]*Task // category => []*Task
|
||||
locker sync.RWMutex
|
||||
|
||||
hasHTTPMetrics bool
|
||||
hasTCPMetrics bool
|
||||
hasUDPMetrics bool
|
||||
}
|
||||
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
tasks: map[int64]*Task{},
|
||||
categoryTasks: map[string][]*Task{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Manager) Update(items []*serverconfigs.MetricItemConfig) {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
var newMap = map[int64]*serverconfigs.MetricItemConfig{}
|
||||
for _, item := range items {
|
||||
newMap[item.Id] = item
|
||||
}
|
||||
|
||||
// 停用以前的 或 修改现在的
|
||||
for itemId, task := range this.tasks {
|
||||
newItem, ok := newMap[itemId]
|
||||
if !ok || !newItem.IsOn { // 停用以前的
|
||||
remotelogs.Println("METRIC_MANAGER", "stop task '"+strconv.FormatInt(itemId, 10)+"'")
|
||||
err := task.Stop()
|
||||
if err != nil {
|
||||
remotelogs.Error("METRIC_MANAGER", "stop task '"+strconv.FormatInt(itemId, 10)+"' failed: "+err.Error())
|
||||
}
|
||||
delete(this.tasks, itemId)
|
||||
} else { // 更新已存在的
|
||||
if newItem.Version != task.item.Version {
|
||||
remotelogs.Println("METRIC_MANAGER", "update task '"+strconv.FormatInt(itemId, 10)+"'")
|
||||
task.item = newItem
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 启动新的
|
||||
for _, newItem := range items {
|
||||
if !newItem.IsOn {
|
||||
continue
|
||||
}
|
||||
_, ok := this.tasks[newItem.Id]
|
||||
if !ok {
|
||||
remotelogs.Println("METRIC_MANAGER", "start task '"+strconv.FormatInt(newItem.Id, 10)+"'")
|
||||
task := NewTask(newItem)
|
||||
err := task.Init()
|
||||
if err != nil {
|
||||
remotelogs.Error("METRIC_MANAGER", "initialized task failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
err = task.Start()
|
||||
if err != nil {
|
||||
remotelogs.Error("METRIC_MANAGER", "start task failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
this.tasks[newItem.Id] = task
|
||||
}
|
||||
}
|
||||
|
||||
// 按分类存放
|
||||
this.hasHTTPMetrics = false
|
||||
this.hasTCPMetrics = false
|
||||
this.hasUDPMetrics = false
|
||||
this.categoryTasks = map[string][]*Task{}
|
||||
for _, task := range this.tasks {
|
||||
tasks := this.categoryTasks[task.item.Category]
|
||||
tasks = append(tasks, task)
|
||||
this.categoryTasks[task.item.Category] = tasks
|
||||
|
||||
switch task.item.Category {
|
||||
case serverconfigs.MetricItemCategoryHTTP:
|
||||
this.hasHTTPMetrics = true
|
||||
case serverconfigs.MetricItemCategoryTCP:
|
||||
this.hasTCPMetrics = true
|
||||
case serverconfigs.MetricItemCategoryUDP:
|
||||
this.hasUDPMetrics = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add 添加数据
|
||||
func (this *Manager) Add(obj MetricInterface) {
|
||||
this.locker.RLock()
|
||||
for _, task := range this.categoryTasks[obj.MetricCategory()] {
|
||||
task.Add(obj)
|
||||
}
|
||||
this.locker.RUnlock()
|
||||
}
|
||||
|
||||
func (this *Manager) HasHTTPMetrics() bool {
|
||||
return this.hasHTTPMetrics
|
||||
}
|
||||
|
||||
func (this *Manager) HasTCPMetrics() bool {
|
||||
return this.hasTCPMetrics
|
||||
}
|
||||
|
||||
func (this *Manager) HasUDPMetrics() bool {
|
||||
return this.hasUDPMetrics
|
||||
}
|
||||
63
internal/metrics/manager_test.go
Normal file
63
internal/metrics/manager_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
var manager = NewManager()
|
||||
{
|
||||
manager.Update([]*serverconfigs.MetricItemConfig{})
|
||||
for _, task := range manager.tasks {
|
||||
t.Log(task.item.Id)
|
||||
}
|
||||
}
|
||||
{
|
||||
t.Log("====")
|
||||
manager.Update([]*serverconfigs.MetricItemConfig{
|
||||
{
|
||||
Id: 1,
|
||||
},
|
||||
{
|
||||
Id: 2,
|
||||
},
|
||||
{
|
||||
Id: 3,
|
||||
},
|
||||
})
|
||||
for _, task := range manager.tasks {
|
||||
t.Log("task:", task.item.Id)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
t.Log("====")
|
||||
manager.Update([]*serverconfigs.MetricItemConfig{
|
||||
{
|
||||
Id: 1,
|
||||
},
|
||||
{
|
||||
Id: 2,
|
||||
},
|
||||
})
|
||||
for _, task := range manager.tasks {
|
||||
t.Log("task:", task.item.Id)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
t.Log("====")
|
||||
manager.Update([]*serverconfigs.MetricItemConfig{
|
||||
{
|
||||
Id: 1,
|
||||
Version: 1,
|
||||
},
|
||||
})
|
||||
for _, task := range manager.tasks {
|
||||
t.Log("task:", task.item.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
17
internal/metrics/metric_interface.go
Normal file
17
internal/metrics/metric_interface.go
Normal file
@@ -0,0 +1,17 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package metrics
|
||||
|
||||
type MetricInterface interface {
|
||||
// MetricKey 指标对象
|
||||
MetricKey(key string) string
|
||||
|
||||
// MetricValue 指标值
|
||||
MetricValue(value string) (result int64, ok bool)
|
||||
|
||||
// MetricServerId 服务ID
|
||||
MetricServerId() int64
|
||||
|
||||
// MetricCategory 指标分类
|
||||
MetricCategory() string
|
||||
}
|
||||
22
internal/metrics/stat.go
Normal file
22
internal/metrics/stat.go
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"github.com/cespare/xxhash"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Stat struct {
|
||||
ServerId int64
|
||||
Keys []string
|
||||
Hash string
|
||||
Value int64
|
||||
Time string
|
||||
|
||||
keysData []byte
|
||||
}
|
||||
|
||||
func (this *Stat) Sum(version int32, itemId int64) {
|
||||
this.Hash = strconv.FormatUint(xxhash.Sum64String(strconv.FormatInt(this.ServerId, 10)+"@"+string(this.keysData)+"@"+this.Time+"@"+strconv.Itoa(int(version))+"@"+strconv.FormatInt(itemId, 10)), 10)
|
||||
}
|
||||
493
internal/metrics/task.go
Normal file
493
internal/metrics/task.go
Normal file
@@ -0,0 +1,493 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Task 单个指标任务
|
||||
// 数据库存储:
|
||||
// data/
|
||||
// metric.$ID.db
|
||||
// stats
|
||||
// id, keys, value, time, serverId, hash
|
||||
// 原理:
|
||||
// 添加或者有变更时 isUploaded = false
|
||||
// 上传时检查 isUploaded 状态
|
||||
// 只上传每个服务中排序最前面的 N 个数据
|
||||
type Task struct {
|
||||
item *serverconfigs.MetricItemConfig
|
||||
isLoaded bool
|
||||
|
||||
db *sql.DB
|
||||
statTableName string
|
||||
statsChan chan *Stat
|
||||
isStopped bool
|
||||
|
||||
cleanTicker *utils.Ticker
|
||||
uploadTicker *utils.Ticker
|
||||
|
||||
cleanVersion int32
|
||||
|
||||
insertStatStmt *sql.Stmt
|
||||
deleteByVersionStmt *sql.Stmt
|
||||
deleteByExpiresTimeStmt *sql.Stmt
|
||||
selectTopStmt *sql.Stmt
|
||||
sumStmt *sql.Stmt
|
||||
|
||||
serverIdMap map[int64]bool // 所有的服务Ids
|
||||
timeMap map[string]bool // time => bool
|
||||
serverIdMapLocker sync.Mutex
|
||||
}
|
||||
|
||||
// NewTask 获取新任务
|
||||
func NewTask(item *serverconfigs.MetricItemConfig) *Task {
|
||||
return &Task{
|
||||
item: item,
|
||||
statsChan: make(chan *Stat, 40960),
|
||||
serverIdMap: map[int64]bool{},
|
||||
timeMap: map[string]bool{},
|
||||
}
|
||||
}
|
||||
|
||||
// Init 初始化
|
||||
func (this *Task) Init() error {
|
||||
this.statTableName = "stats"
|
||||
|
||||
// 检查目录是否存在
|
||||
var dir = Tea.Root + "/data"
|
||||
_, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
err = os.MkdirAll(dir, 0777)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
remotelogs.Println("METRIC", "create data dir '"+dir+"'")
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", "file:"+dir+"/metric."+strconv.FormatInt(this.item.Id, 10)+".db?cache=shared&mode=rwc&_journal_mode=WAL")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.SetMaxOpenConns(1)
|
||||
this.db = db
|
||||
|
||||
//创建统计表
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.statTableName + `" (
|
||||
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"hash" varchar(32),
|
||||
"keys" varchar(1024),
|
||||
"value" real DEFAULT 0,
|
||||
"time" varchar(32),
|
||||
"serverId" integer DEFAULT 0,
|
||||
"version" integer DEFAULT 0,
|
||||
"isUploaded" integer DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "serverId"
|
||||
ON "` + this.statTableName + `" (
|
||||
"serverId" ASC,
|
||||
"version" ASC
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS "hash"
|
||||
ON "` + this.statTableName + `" (
|
||||
"hash" ASC
|
||||
);`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// insert stat stmt
|
||||
this.insertStatStmt, err = db.Prepare(`INSERT INTO "stats" ("serverId", "hash", "keys", "value", "time", "version", "isUploaded") VALUES (?, ?, ?, ?, ?, ?, 0) ON CONFLICT("hash") DO UPDATE SET "value"="value"+?, "isUploaded"=0`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete by version
|
||||
this.deleteByVersionStmt, err = db.Prepare(`DELETE FROM "` + this.statTableName + `" WHERE "version"<?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete by expires time
|
||||
this.deleteByExpiresTimeStmt, err = db.Prepare(`DELETE FROM "` + this.statTableName + `" WHERE "time"<?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// select topN stmt
|
||||
this.selectTopStmt, err = db.Prepare(`SELECT "id", "hash", "keys", "value", "isUploaded" FROM "` + this.statTableName + `" WHERE "serverId"=? AND "version"=? AND time=? ORDER BY "value" DESC LIMIT 20`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// sum stmt
|
||||
this.sumStmt, err = db.Prepare(`SELECT COUNT(*), IFNULL(SUM(value), 0) FROM "` + this.statTableName + `" WHERE "serverId"=? AND "version"=? AND time=?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 所有的服务IDs
|
||||
err = this.loadServerIdMap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.isLoaded = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start 启动任务
|
||||
func (this *Task) Start() error {
|
||||
// 读取数据
|
||||
go func() {
|
||||
for stat := range this.statsChan {
|
||||
if stat == nil {
|
||||
return
|
||||
}
|
||||
err := this.InsertStat(stat)
|
||||
if err != nil {
|
||||
remotelogs.Error("METRIC", "insert stat failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 清理
|
||||
this.cleanTicker = utils.NewTicker(24 * time.Hour)
|
||||
go func() {
|
||||
for this.cleanTicker.Next() {
|
||||
err := this.CleanExpired()
|
||||
if err != nil {
|
||||
remotelogs.Error("METRIC", "clean expired stats failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 上传
|
||||
this.uploadTicker = utils.NewTicker(this.item.UploadDuration())
|
||||
go func() {
|
||||
for this.uploadTicker.Next() {
|
||||
err := this.Upload(1 * time.Second)
|
||||
if err != nil {
|
||||
remotelogs.Error("METRIC", "upload stats failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add 添加数据
|
||||
func (this *Task) Add(obj MetricInterface) {
|
||||
if this.isStopped || !this.isLoaded {
|
||||
return
|
||||
}
|
||||
|
||||
var keys = []string{}
|
||||
for _, key := range this.item.Keys {
|
||||
k := obj.MetricKey(key)
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
v, ok := obj.MetricValue(this.item.Value)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var stat = &Stat{
|
||||
ServerId: obj.MetricServerId(),
|
||||
Keys: keys,
|
||||
Value: v,
|
||||
Time: this.item.CurrentTime(),
|
||||
}
|
||||
|
||||
select {
|
||||
case this.statsChan <- stat:
|
||||
default:
|
||||
// 丢弃
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止任务
|
||||
func (this *Task) Stop() error {
|
||||
this.isStopped = true
|
||||
|
||||
if this.cleanTicker != nil {
|
||||
this.cleanTicker.Stop()
|
||||
}
|
||||
if this.uploadTicker != nil {
|
||||
this.uploadTicker.Stop()
|
||||
}
|
||||
|
||||
_ = this.insertStatStmt.Close()
|
||||
_ = this.deleteByVersionStmt.Close()
|
||||
_ = this.deleteByExpiresTimeStmt.Close()
|
||||
_ = this.selectTopStmt.Close()
|
||||
_ = this.sumStmt.Close()
|
||||
|
||||
if this.statsChan != nil {
|
||||
go func() {
|
||||
// 延时关闭,防止关闭时写入
|
||||
time.Sleep(5 * time.Second)
|
||||
close(this.statsChan)
|
||||
}()
|
||||
}
|
||||
|
||||
if this.db != nil {
|
||||
_ = this.db.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InsertStat 写入数据
|
||||
func (this *Task) InsertStat(stat *Stat) error {
|
||||
if this.isStopped {
|
||||
return nil
|
||||
}
|
||||
if stat == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
this.serverIdMapLocker.Lock()
|
||||
this.serverIdMap[stat.ServerId] = true
|
||||
this.timeMap[stat.Time] = true
|
||||
this.serverIdMapLocker.Unlock()
|
||||
|
||||
keyData, err := json.Marshal(stat.Keys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stat.keysData = keyData
|
||||
stat.Sum(this.item.Version, this.item.Id)
|
||||
|
||||
_, err = this.insertStatStmt.Exec(stat.ServerId, stat.Hash, stat.keysData, stat.Value, stat.Time, this.item.Version, stat.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanExpired 清理数据
|
||||
func (this *Task) CleanExpired() error {
|
||||
if this.isStopped {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 清除低版本数据
|
||||
if this.cleanVersion < this.item.Version {
|
||||
_, err := this.deleteByVersionStmt.Exec(this.item.Version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.cleanVersion = this.item.Version
|
||||
}
|
||||
|
||||
// 清除过期的数据
|
||||
_, err := this.deleteByExpiresTimeStmt.Exec(this.item.LocalExpiresTime())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Upload 上传数据
|
||||
func (this *Task) Upload(pauseDuration time.Duration) error {
|
||||
if this.isStopped {
|
||||
return nil
|
||||
}
|
||||
|
||||
this.serverIdMapLocker.Lock()
|
||||
|
||||
// 服务IDs
|
||||
var serverIds []int64
|
||||
for serverId := range this.serverIdMap {
|
||||
serverIds = append(serverIds, serverId)
|
||||
}
|
||||
this.serverIdMap = map[int64]bool{} // 清空数据
|
||||
|
||||
// 时间
|
||||
var times = []string{}
|
||||
for t := range this.timeMap {
|
||||
times = append(times, t)
|
||||
}
|
||||
this.timeMap = map[string]bool{} // 清空数据
|
||||
|
||||
this.serverIdMapLocker.Unlock()
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, serverId := range serverIds {
|
||||
for _, currentTime := range times {
|
||||
idStrings, err := func(serverId int64, currentTime string) (ids []string, err error) {
|
||||
rows, err := this.selectTopStmt.Query(serverId, this.item.Version, currentTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var isClosed bool
|
||||
defer func() {
|
||||
if isClosed {
|
||||
return
|
||||
}
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
var pbStats []*pb.UploadingMetricStat
|
||||
for rows.Next() {
|
||||
var pbStat = &pb.UploadingMetricStat{}
|
||||
// "id", "hash", "keys", "value", "isUploaded"
|
||||
var isUploaded int
|
||||
var keysData []byte
|
||||
err = rows.Scan(&pbStat.Id, &pbStat.Hash, &keysData, &pbStat.Value, &isUploaded)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO 先不判断是否已经上传,需要改造API进行配合
|
||||
/**if isUploaded == 1 {
|
||||
continue
|
||||
}**/
|
||||
if len(keysData) > 0 {
|
||||
err = json.Unmarshal(keysData, &pbStat.Keys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
pbStats = append(pbStats, pbStat)
|
||||
ids = append(ids, strconv.FormatInt(pbStat.Id, 10))
|
||||
}
|
||||
|
||||
// 提前关闭
|
||||
_ = rows.Close()
|
||||
isClosed = true
|
||||
|
||||
// 上传
|
||||
if len(pbStats) > 0 {
|
||||
// 计算总和
|
||||
count, total, err := this.sum(serverId, currentTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = rpcClient.MetricStatRPC().UploadMetricStats(rpcClient.Context(), &pb.UploadMetricStatsRequest{
|
||||
MetricStats: pbStats,
|
||||
Time: currentTime,
|
||||
ServerId: serverId,
|
||||
ItemId: this.item.Id,
|
||||
Version: this.item.Version,
|
||||
Count: count,
|
||||
Total: float32(total),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}(serverId, currentTime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(idStrings) > 0 {
|
||||
// 设置为已上传
|
||||
_, err = this.db.Exec(`UPDATE "` + this.statTableName + `" SET isUploaded=1 WHERE id IN (` + strings.Join(idStrings, ",") + `)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 休息一下,防止短时间内上传数据过多
|
||||
if pauseDuration > 0 {
|
||||
time.Sleep(pauseDuration)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 加载服务ID
|
||||
func (this *Task) loadServerIdMap() error {
|
||||
{
|
||||
rows, err := this.db.Query(`SELECT DISTINCT "serverId" FROM `+this.statTableName+" WHERE version=?", this.item.Version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
var serverId int64
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&serverId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.serverIdMapLocker.Lock()
|
||||
this.serverIdMap[serverId] = true
|
||||
this.serverIdMapLocker.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
rows, err := this.db.Query(`SELECT DISTINCT "time" FROM `+this.statTableName+" WHERE version=?", this.item.Version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
var timeString string
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&timeString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.serverIdMapLocker.Lock()
|
||||
this.timeMap[timeString] = true
|
||||
this.serverIdMapLocker.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 计算数量和综合
|
||||
func (this *Task) sum(serverId int64, time string) (count int64, total float64, err error) {
|
||||
rows, err := this.sumStmt.Query(serverId, this.item.Version, time)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
if rows.Next() {
|
||||
err = rows.Scan(&count, &total)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
210
internal/metrics/task_test.go
Normal file
210
internal/metrics/task_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package metrics_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/metrics"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type testObj struct {
|
||||
ip string
|
||||
}
|
||||
|
||||
func (this *testObj) MetricKey(key string) string {
|
||||
return this.ip
|
||||
}
|
||||
|
||||
func (this *testObj) MetricValue(value string) (int64, bool) {
|
||||
return 1, true
|
||||
}
|
||||
|
||||
func (this *testObj) MetricServerId() int64 {
|
||||
return int64(rands.Int(1, 100))
|
||||
}
|
||||
|
||||
func (this *testObj) MetricCategory() string {
|
||||
return "http"
|
||||
}
|
||||
|
||||
func TestTask_Init(t *testing.T) {
|
||||
var task = metrics.NewTask(&serverconfigs.MetricItemConfig{
|
||||
Id: 1,
|
||||
IsOn: false,
|
||||
Category: "",
|
||||
Period: 0,
|
||||
PeriodUnit: "",
|
||||
Keys: nil,
|
||||
Value: "",
|
||||
})
|
||||
err := task.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = task.Stop()
|
||||
}()
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestTask_Add(t *testing.T) {
|
||||
var task = metrics.NewTask(&serverconfigs.MetricItemConfig{
|
||||
Id: 1,
|
||||
IsOn: false,
|
||||
Category: "",
|
||||
Period: 1,
|
||||
PeriodUnit: serverconfigs.MetricItemPeriodUnitDay,
|
||||
Keys: []string{"${remoteAddr}"},
|
||||
Value: "${countRequest}",
|
||||
})
|
||||
err := task.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = task.Start()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = task.Stop()
|
||||
}()
|
||||
|
||||
task.Add(&testObj{ip: "127.0.0.2"})
|
||||
time.Sleep(1 * time.Second) // waiting for inserting
|
||||
}
|
||||
|
||||
func TestTask_Add_Many(t *testing.T) {
|
||||
var task = metrics.NewTask(&serverconfigs.MetricItemConfig{
|
||||
Id: 1,
|
||||
IsOn: false,
|
||||
Category: "",
|
||||
Period: 1,
|
||||
PeriodUnit: serverconfigs.MetricItemPeriodUnitDay,
|
||||
Keys: []string{"${remoteAddr}"},
|
||||
Value: "${countRequest}",
|
||||
Version: 1,
|
||||
})
|
||||
err := task.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = task.Start()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = task.Stop()
|
||||
}()
|
||||
|
||||
for i := 0; i < 4_000_000; i++ {
|
||||
task.Add(&testObj{
|
||||
ip: fmt.Sprintf("%d.%d.%d.%d", rands.Int(0, 255), rands.Int(0, 255), rands.Int(0, 255), rands.Int(0, 255)),
|
||||
})
|
||||
if i%10000 == 0 {
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTask_InsertStat(t *testing.T) {
|
||||
var item = &serverconfigs.MetricItemConfig{
|
||||
Id: 1,
|
||||
IsOn: false,
|
||||
Category: "",
|
||||
Period: 1,
|
||||
PeriodUnit: serverconfigs.MetricItemPeriodUnitDay,
|
||||
Keys: []string{"${remoteAddr}"},
|
||||
Value: "${countRequest}",
|
||||
Version: 1,
|
||||
}
|
||||
var task = metrics.NewTask(item)
|
||||
err := task.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = task.Start()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = task.Stop()
|
||||
}()
|
||||
|
||||
err = task.InsertStat(&metrics.Stat{
|
||||
ServerId: 1,
|
||||
Keys: []string{"127.0.0.1"},
|
||||
Hash: "",
|
||||
Value: 1,
|
||||
Time: item.CurrentTime(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestTask_CleanExpired(t *testing.T) {
|
||||
var task = metrics.NewTask(&serverconfigs.MetricItemConfig{
|
||||
Id: 1,
|
||||
IsOn: false,
|
||||
Category: "",
|
||||
Period: 1,
|
||||
PeriodUnit: serverconfigs.MetricItemPeriodUnitDay,
|
||||
Keys: []string{"${remoteAddr}"},
|
||||
Value: "${countRequest}",
|
||||
Version: 1,
|
||||
})
|
||||
err := task.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = task.Start()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = task.Stop()
|
||||
}()
|
||||
|
||||
err = task.CleanExpired()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestTask_Upload(t *testing.T) {
|
||||
var task = metrics.NewTask(&serverconfigs.MetricItemConfig{
|
||||
Id: 1,
|
||||
IsOn: false,
|
||||
Category: "",
|
||||
Period: 1,
|
||||
PeriodUnit: serverconfigs.MetricItemPeriodUnitDay,
|
||||
Keys: []string{"${remoteAddr}"},
|
||||
Value: "${countRequest}",
|
||||
Version: 1,
|
||||
})
|
||||
err := task.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = task.Start()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = task.Stop()
|
||||
}()
|
||||
|
||||
err = task.Upload(0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
@@ -72,7 +72,8 @@ func (this *ValueQueue) Loop() error {
|
||||
CreatedAt: value.CreatedAt,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
remotelogs.Error("MONITOR", err.Error())
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/metrics"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
@@ -66,12 +67,16 @@ type HTTPRequest struct {
|
||||
cacheRef *serverconfigs.HTTPCacheRef // 缓存设置
|
||||
cacheKey string // 缓存使用的Key
|
||||
isCached bool // 是否已经被缓存
|
||||
isAttack bool // 是否是攻击请求
|
||||
bodyData []byte // 读取的Body内容
|
||||
|
||||
// WAF相关
|
||||
firewallPolicyId int64
|
||||
firewallRuleGroupId int64
|
||||
firewallRuleSetId int64
|
||||
firewallRuleId int64
|
||||
firewallActions []string
|
||||
tags []string
|
||||
|
||||
logAttrs map[string]string
|
||||
|
||||
@@ -139,9 +144,10 @@ func (this *HTTPRequest) Do() {
|
||||
|
||||
// 自动跳转到HTTPS
|
||||
if this.IsHTTP && this.web.RedirectToHttps != nil && this.web.RedirectToHttps.IsOn {
|
||||
this.doRedirectToHTTPS(this.web.RedirectToHttps)
|
||||
this.doEnd()
|
||||
return
|
||||
if this.doRedirectToHTTPS(this.web.RedirectToHttps) {
|
||||
this.doEnd()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Gzip
|
||||
@@ -242,11 +248,20 @@ func (this *HTTPRequest) doEnd() {
|
||||
// TODO 增加是否开启开关
|
||||
if this.Server != nil {
|
||||
if this.isCached {
|
||||
stats.SharedTrafficStatManager.Add(this.Server.Id, this.writer.sentBodyBytes, this.writer.sentBodyBytes, 1, 1)
|
||||
stats.SharedTrafficStatManager.Add(this.Server.Id, this.Host, this.writer.sentBodyBytes, this.writer.sentBodyBytes, 1, 1, 0, 0)
|
||||
} else {
|
||||
stats.SharedTrafficStatManager.Add(this.Server.Id, this.writer.sentBodyBytes, 0, 1, 0)
|
||||
if this.isAttack {
|
||||
stats.SharedTrafficStatManager.Add(this.Server.Id, this.Host, this.writer.sentBodyBytes, 0, 1, 0, 1, this.writer.sentBodyBytes)
|
||||
} else {
|
||||
stats.SharedTrafficStatManager.Add(this.Server.Id, this.Host, this.writer.sentBodyBytes, 0, 1, 0, 0, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 指标
|
||||
if metrics.SharedManager.HasHTTPMetrics() {
|
||||
this.doMetricsResponse()
|
||||
}
|
||||
}
|
||||
|
||||
// RawURI 原始的请求URI
|
||||
@@ -500,6 +515,12 @@ func (this *HTTPRequest) Format(source string) string {
|
||||
return this.requestRemoteUser()
|
||||
case "requestURI", "requestUri":
|
||||
return this.rawURI
|
||||
case "requestURL":
|
||||
var scheme = "http"
|
||||
if this.IsHTTPS {
|
||||
scheme = "https"
|
||||
}
|
||||
return scheme + "://" + this.Host + this.rawURI
|
||||
case "requestPath":
|
||||
return this.requestPath()
|
||||
case "requestPathExtension":
|
||||
@@ -549,6 +570,12 @@ func (this *HTTPRequest) Format(source string) string {
|
||||
return this.Host
|
||||
case "referer":
|
||||
return this.RawReq.Referer()
|
||||
case "referer.host":
|
||||
u, err := url.Parse(this.RawReq.Referer())
|
||||
if err == nil {
|
||||
return u.Host
|
||||
}
|
||||
return ""
|
||||
case "userAgent":
|
||||
return this.RawReq.UserAgent()
|
||||
case "contentType":
|
||||
@@ -1186,5 +1213,10 @@ func (this *HTTPRequest) canIgnore(err error) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// HTTP内部错误
|
||||
if strings.HasPrefix(err.Error(), "http:") || strings.HasPrefix(err.Error(), "http2:") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
// 读取缓存
|
||||
func (this *HTTPRequest) doCacheRead() (shouldStop bool) {
|
||||
cachePolicy := sharedNodeConfig.HTTPCachePolicy
|
||||
cachePolicy := this.Server.HTTPCachePolicy
|
||||
if cachePolicy == nil || !cachePolicy.IsOn {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -128,6 +128,8 @@ func (this *HTTPRequest) log() {
|
||||
FirewallRuleGroupId: this.firewallRuleGroupId,
|
||||
FirewallRuleSetId: this.firewallRuleSetId,
|
||||
FirewallRuleId: this.firewallRuleId,
|
||||
FirewallActions: this.firewallActions,
|
||||
Tags: this.tags,
|
||||
|
||||
Attrs: this.logAttrs,
|
||||
}
|
||||
|
||||
58
internal/nodes/http_request_metrics.go
Normal file
58
internal/nodes/http_request_metrics.go
Normal file
@@ -0,0 +1,58 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/metrics"
|
||||
)
|
||||
|
||||
// 指标统计 - 响应
|
||||
// 只需要在结束时调用指标进行统计
|
||||
func (this *HTTPRequest) doMetricsResponse() {
|
||||
metrics.SharedManager.Add(this)
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) MetricKey(key string) string {
|
||||
return this.Format(key)
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) MetricValue(value string) (result int64, ok bool) {
|
||||
// TODO 需要忽略健康检查的请求,但是同时也要防止攻击者模拟健康检查
|
||||
switch value {
|
||||
case "${countRequest}":
|
||||
return 1, true
|
||||
case "${countTrafficOut}":
|
||||
// 这里不包括Header长度
|
||||
return this.writer.SentBodyBytes(), true
|
||||
case "${countTrafficIn}":
|
||||
var hl int64 = 0 // header length
|
||||
for k, values := range this.RawReq.Header {
|
||||
for _, v := range values {
|
||||
hl += int64(len(k) + len(v) + 2 /** k: v **/)
|
||||
}
|
||||
}
|
||||
return this.RawReq.ContentLength + hl, true
|
||||
case "${countConnection}":
|
||||
metricNewConnMapLocker.Lock()
|
||||
_, ok := metricNewConnMap[this.RawReq.RemoteAddr]
|
||||
if ok {
|
||||
delete(metricNewConnMap, this.RawReq.RemoteAddr)
|
||||
}
|
||||
metricNewConnMapLocker.Unlock()
|
||||
if ok {
|
||||
return 1, true
|
||||
} else {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) MetricServerId() int64 {
|
||||
return this.Server.Id
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) MetricCategory() string {
|
||||
return serverconfigs.MetricItemCategoryHTTP
|
||||
}
|
||||
@@ -7,9 +7,14 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (this *HTTPRequest) doRedirectToHTTPS(redirectToHTTPSConfig *serverconfigs.HTTPRedirectToHTTPSConfig) {
|
||||
func (this *HTTPRequest) doRedirectToHTTPS(redirectToHTTPSConfig *serverconfigs.HTTPRedirectToHTTPSConfig) (shouldBreak bool) {
|
||||
host := this.RawReq.Host
|
||||
|
||||
// 检查域名是否匹配
|
||||
if !redirectToHTTPSConfig.MatchDomain(host) {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(redirectToHTTPSConfig.Host) > 0 {
|
||||
if redirectToHTTPSConfig.Port > 0 && redirectToHTTPSConfig.Port != 443 {
|
||||
host = redirectToHTTPSConfig.Host + ":" + strconv.Itoa(redirectToHTTPSConfig.Port)
|
||||
@@ -38,4 +43,6 @@ func (this *HTTPRequest) doRedirectToHTTPS(redirectToHTTPSConfig *serverconfigs.
|
||||
|
||||
newURL := "https://" + host + this.RawReq.RequestURI
|
||||
http.Redirect(this.writer, this.RawReq, newURL, statusCode)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -160,6 +160,9 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
httpErr, ok := err.(*url.Error)
|
||||
if !ok || httpErr.Err != context.Canceled {
|
||||
// TODO 如果超过最大失败次数,则下线
|
||||
SharedOriginStateManager.Fail(origin, this.reverseProxy, func() {
|
||||
this.reverseProxy.ResetScheduling()
|
||||
})
|
||||
|
||||
this.write502(err)
|
||||
remotelogs.Warn("HTTP_REQUEST_REVERSE_PROXY", this.RawReq.URL.String()+"': "+err.Error())
|
||||
@@ -183,6 +186,11 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
}
|
||||
return
|
||||
}
|
||||
if !origin.IsOk {
|
||||
SharedOriginStateManager.Success(origin, func() {
|
||||
this.reverseProxy.ResetScheduling()
|
||||
})
|
||||
}
|
||||
|
||||
// WAF对出站进行检查
|
||||
if this.web.FirewallRef != nil && this.web.FirewallRef.IsOn {
|
||||
|
||||
@@ -7,6 +7,8 @@ func (this *HTTPRequest) doStat() {
|
||||
if this.Server == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 内置的统计
|
||||
stats.SharedHTTPRequestStatManager.AddRemoteAddr(this.Server.Id, this.requestRemoteAddr())
|
||||
stats.SharedHTTPRequestStatManager.AddUserAgent(this.Server.Id, this.requestHeader("User-Agent"))
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
@@ -8,6 +9,8 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -25,8 +28,8 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
|
||||
}
|
||||
|
||||
// 公用的防火墙设置
|
||||
if sharedNodeConfig.HTTPFirewallPolicy != nil {
|
||||
blocked, breakChecking := this.checkWAFRequest(sharedNodeConfig.HTTPFirewallPolicy)
|
||||
if this.Server.HTTPFirewallPolicy != nil && this.Server.HTTPFirewallPolicy.IsOn {
|
||||
blocked, breakChecking := this.checkWAFRequest(this.Server.HTTPFirewallPolicy)
|
||||
if blocked {
|
||||
return true
|
||||
}
|
||||
@@ -152,23 +155,36 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
goNext, ruleGroup, ruleSet, err := w.MatchRequest(this.RawReq, this.writer)
|
||||
|
||||
w.OnAction(func(action waf.ActionInterface) (goNext bool) {
|
||||
switch action.Code() {
|
||||
case waf.ActionTag:
|
||||
this.tags = action.(*waf.TagAction).Tags
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
goNext, ruleGroup, ruleSet, err := w.MatchRequest(this, this.writer)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if ruleSet != nil {
|
||||
if ruleSet.Action != waf.ActionAllow {
|
||||
if ruleSet.HasSpecialActions() {
|
||||
this.firewallPolicyId = firewallPolicy.Id
|
||||
this.firewallRuleGroupId = types.Int64(ruleGroup.Id)
|
||||
this.firewallRuleSetId = types.Int64(ruleSet.Id)
|
||||
|
||||
if ruleSet.HasAttackActions() {
|
||||
this.isAttack = true
|
||||
}
|
||||
|
||||
// 添加统计
|
||||
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Action)
|
||||
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Actions)
|
||||
}
|
||||
|
||||
this.logAttrs["waf.action"] = ruleSet.Action
|
||||
this.firewallActions = ruleSet.ActionCodes()
|
||||
}
|
||||
|
||||
return !goNext, false
|
||||
@@ -185,8 +201,8 @@ func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) {
|
||||
}
|
||||
|
||||
// 公用的防火墙设置
|
||||
if sharedNodeConfig.HTTPFirewallPolicy != nil {
|
||||
blocked := this.checkWAFResponse(sharedNodeConfig.HTTPFirewallPolicy, resp)
|
||||
if this.Server.HTTPFirewallPolicy != nil && this.Server.HTTPFirewallPolicy.IsOn {
|
||||
blocked := this.checkWAFResponse(this.Server.HTTPFirewallPolicy, resp)
|
||||
if blocked {
|
||||
return true
|
||||
}
|
||||
@@ -204,24 +220,79 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi
|
||||
return
|
||||
}
|
||||
|
||||
goNext, ruleGroup, ruleSet, err := w.MatchResponse(this.RawReq, resp, this.writer)
|
||||
w.OnAction(func(action waf.ActionInterface) (goNext bool) {
|
||||
switch action.Code() {
|
||||
case waf.ActionTag:
|
||||
this.tags = action.(*waf.TagAction).Tags
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
goNext, ruleGroup, ruleSet, err := w.MatchResponse(this, resp, this.writer)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if ruleSet != nil {
|
||||
if ruleSet.Action != waf.ActionAllow {
|
||||
if ruleSet.HasSpecialActions() {
|
||||
this.firewallPolicyId = firewallPolicy.Id
|
||||
this.firewallRuleGroupId = types.Int64(ruleGroup.Id)
|
||||
this.firewallRuleSetId = types.Int64(ruleSet.Id)
|
||||
|
||||
if ruleSet.HasAttackActions() {
|
||||
this.isAttack = true
|
||||
}
|
||||
|
||||
// 添加统计
|
||||
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Action)
|
||||
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Actions)
|
||||
}
|
||||
|
||||
this.logAttrs["waf.action"] = ruleSet.Action
|
||||
this.firewallActions = ruleSet.ActionCodes()
|
||||
}
|
||||
|
||||
return !goNext
|
||||
}
|
||||
|
||||
// WAFRaw 原始请求
|
||||
func (this *HTTPRequest) WAFRaw() *http.Request {
|
||||
return this.RawReq
|
||||
}
|
||||
|
||||
// WAFRemoteIP 客户端IP
|
||||
func (this *HTTPRequest) WAFRemoteIP() string {
|
||||
return this.requestRemoteAddr()
|
||||
}
|
||||
|
||||
// WAFGetCacheBody 获取缓存中的Body
|
||||
func (this *HTTPRequest) WAFGetCacheBody() []byte {
|
||||
return this.bodyData
|
||||
}
|
||||
|
||||
// WAFSetCacheBody 设置Body
|
||||
func (this *HTTPRequest) WAFSetCacheBody(body []byte) {
|
||||
this.bodyData = body
|
||||
}
|
||||
|
||||
// WAFReadBody 读取Body
|
||||
func (this *HTTPRequest) WAFReadBody(max int64) (data []byte, err error) {
|
||||
if this.RawReq.ContentLength > 0 {
|
||||
data, err = ioutil.ReadAll(io.LimitReader(this.RawReq.Body, max))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// WAFRestoreBody 恢复Body
|
||||
func (this *HTTPRequest) WAFRestoreBody(data []byte) {
|
||||
if len(data) > 0 {
|
||||
rawReader := bytes.NewBuffer(data)
|
||||
buf := make([]byte, 1024)
|
||||
_, _ = io.CopyBuffer(rawReader, this.RawReq.Body, buf)
|
||||
this.RawReq.Body = ioutil.NopCloser(rawReader)
|
||||
}
|
||||
}
|
||||
|
||||
// WAFServerId 服务ID
|
||||
func (this *HTTPRequest) WAFServerId() int64 {
|
||||
return this.Server.Id
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -223,17 +224,29 @@ func (this *HTTPWriter) Close() {
|
||||
// cache writer
|
||||
if this.cacheWriter != nil {
|
||||
if this.isOk {
|
||||
err := this.cacheWriter.Close()
|
||||
if err == nil {
|
||||
this.cacheStorage.AddToList(&caches.Item{
|
||||
Type: this.cacheWriter.ItemType(),
|
||||
Key: this.cacheWriter.Key(),
|
||||
ExpiredAt: this.cacheWriter.ExpiredAt(),
|
||||
HeaderSize: this.cacheWriter.HeaderSize(),
|
||||
BodySize: this.cacheWriter.BodySize(),
|
||||
Host: this.req.Host,
|
||||
ServerId: this.req.Server.Id,
|
||||
})
|
||||
// 对比Content-Length
|
||||
contentLengthString := this.Header().Get("Content-Length")
|
||||
if len(contentLengthString) > 0 {
|
||||
contentLength := types.Int64(contentLengthString)
|
||||
if contentLength != this.cacheWriter.BodySize() {
|
||||
this.isOk = false
|
||||
_ = this.cacheWriter.Discard()
|
||||
}
|
||||
}
|
||||
|
||||
if this.isOk {
|
||||
err := this.cacheWriter.Close()
|
||||
if err == nil {
|
||||
this.cacheStorage.AddToList(&caches.Item{
|
||||
Type: this.cacheWriter.ItemType(),
|
||||
Key: this.cacheWriter.Key(),
|
||||
ExpiredAt: this.cacheWriter.ExpiredAt(),
|
||||
HeaderSize: this.cacheWriter.HeaderSize(),
|
||||
BodySize: this.cacheWriter.BodySize(),
|
||||
Host: this.req.Host,
|
||||
ServerId: this.req.Server.Id,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
_ = this.cacheWriter.Discard()
|
||||
@@ -329,7 +342,7 @@ func (this *HTTPWriter) prepareCache(size int64) {
|
||||
return
|
||||
}
|
||||
|
||||
cachePolicy := sharedNodeConfig.HTTPCachePolicy
|
||||
cachePolicy := this.req.Server.HTTPCachePolicy
|
||||
if cachePolicy == nil || !cachePolicy.IsOn {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
http2 "golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2"
|
||||
"sync"
|
||||
)
|
||||
|
||||
|
||||
@@ -9,11 +9,14 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var httpErrorLogger = log.New(io.Discard, "", 0)
|
||||
var metricNewConnMap = map[string]bool{} // remoteAddr => bool
|
||||
var metricNewConnMapLocker = &sync.Mutex{}
|
||||
|
||||
type HTTPListener struct {
|
||||
BaseListener
|
||||
@@ -39,15 +42,27 @@ func (this *HTTPListener) Serve() error {
|
||||
this.httpServer = &http.Server{
|
||||
Addr: this.addr,
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: 3 * time.Second, // TODO 改成可以配置
|
||||
ReadHeaderTimeout: 2 * time.Second, // TODO 改成可以配置
|
||||
IdleTimeout: 2 * time.Minute, // TODO 改成可以配置
|
||||
ErrorLog: httpErrorLogger,
|
||||
ConnState: func(conn net.Conn, state http.ConnState) {
|
||||
switch state {
|
||||
case http.StateNew:
|
||||
atomic.AddInt64(&this.countActiveConnections, 1)
|
||||
|
||||
// 为指标存储连接信息
|
||||
if sharedNodeConfig.HasHTTPConnectionMetrics() {
|
||||
metricNewConnMapLocker.Lock()
|
||||
metricNewConnMap[conn.RemoteAddr().String()] = true
|
||||
metricNewConnMapLocker.Unlock()
|
||||
}
|
||||
case http.StateClosed:
|
||||
atomic.AddInt64(&this.countActiveConnections, -1)
|
||||
|
||||
// 移除指标存储连接信息
|
||||
metricNewConnMapLocker.Lock()
|
||||
delete(metricNewConnMap, conn.RemoteAddr().String())
|
||||
metricNewConnMapLocker.Unlock()
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@@ -80,7 +80,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
}
|
||||
|
||||
// 记录流量
|
||||
stats.SharedTrafficStatManager.Add(firstServer.Id, int64(n), 0, 0, 0)
|
||||
stats.SharedTrafficStatManager.Add(firstServer.Id, "", int64(n), 0, 0, 0, 0, 0)
|
||||
}
|
||||
if err != nil {
|
||||
closer()
|
||||
|
||||
@@ -164,7 +164,7 @@ func NewUDPConn(serverId int64, addr net.Addr, proxyConn *net.UDPConn, serverCon
|
||||
}
|
||||
|
||||
// 记录流量
|
||||
stats.SharedTrafficStatManager.Add(serverId, int64(n), 0, 0, 0)
|
||||
stats.SharedTrafficStatManager.Add(serverId, "", int64(n), 0, 0, 0, 0, 0)
|
||||
}
|
||||
if err != nil {
|
||||
conn.isOk = false
|
||||
|
||||
@@ -6,12 +6,12 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/apps"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/metrics"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
@@ -20,14 +20,13 @@ import (
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/gosock/pkg/gosock"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -39,10 +38,13 @@ var DaemonPid = 0
|
||||
// Node 节点
|
||||
type Node struct {
|
||||
isLoaded bool
|
||||
sock *gosock.Sock
|
||||
}
|
||||
|
||||
func NewNode() *Node {
|
||||
return &Node{}
|
||||
return &Node{
|
||||
sock: gosock.NewTmpSock(teaconst.ProcessName),
|
||||
}
|
||||
}
|
||||
|
||||
// Test 检查配置
|
||||
@@ -72,9 +74,6 @@ func (this *Node) Start() {
|
||||
// 启动事件
|
||||
events.Notify(events.EventStart)
|
||||
|
||||
// 处理信号
|
||||
this.listenSignals()
|
||||
|
||||
// 本地Sock
|
||||
err := this.listenSock()
|
||||
if err != nil {
|
||||
@@ -151,24 +150,16 @@ func (this *Node) Start() {
|
||||
return
|
||||
}
|
||||
|
||||
// 写入PID
|
||||
err = apps.WritePid()
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "write pid failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// hold住进程
|
||||
select {}
|
||||
}
|
||||
|
||||
// Daemon 实现守护进程
|
||||
func (this *Node) Daemon() {
|
||||
path := os.TempDir() + "/edge-node.sock"
|
||||
isDebug := lists.ContainsString(os.Args, "debug")
|
||||
isDebug = true
|
||||
for {
|
||||
conn, err := net.DialTimeout("unix", path, 1*time.Second)
|
||||
conn, err := this.sock.Dial()
|
||||
if err != nil {
|
||||
if isDebug {
|
||||
log.Println("[DAEMON]starting ...")
|
||||
@@ -228,32 +219,6 @@ func (this *Node) InstallSystemService() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理信号
|
||||
func (this *Node) listenSignals() {
|
||||
signals := make(chan os.Signal)
|
||||
signal.Notify(signals, syscall.SIGQUIT)
|
||||
go func() {
|
||||
for s := range signals {
|
||||
switch s {
|
||||
case syscall.SIGQUIT:
|
||||
events.Notify(events.EventQuit)
|
||||
|
||||
// 监控连接数,如果连接数为0,则退出进程
|
||||
go func() {
|
||||
for {
|
||||
countActiveConnections := sharedListenerManager.TotalActiveConnections()
|
||||
if countActiveConnections <= 0 {
|
||||
os.Exit(0)
|
||||
return
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 循环
|
||||
func (this *Node) loop() error {
|
||||
// 检查api.yaml是否存在
|
||||
@@ -382,12 +347,12 @@ func (this *Node) syncConfig() error {
|
||||
} else {
|
||||
remotelogs.Println("NODE", "loading config ...")
|
||||
}
|
||||
|
||||
|
||||
nodeconfigs.ResetNodeConfig(nodeConfig)
|
||||
caches.SharedManager.MaxDiskCapacity = nodeConfig.MaxCacheDiskCapacity
|
||||
caches.SharedManager.MaxMemoryCapacity = nodeConfig.MaxCacheMemoryCapacity
|
||||
if nodeConfig.HTTPCachePolicy != nil {
|
||||
caches.SharedManager.UpdatePolicies([]*serverconfigs.HTTPCachePolicy{nodeConfig.HTTPCachePolicy})
|
||||
if len(nodeConfig.HTTPCachePolicies) > 0 {
|
||||
caches.SharedManager.UpdatePolicies(nodeConfig.HTTPCachePolicies)
|
||||
} else {
|
||||
caches.SharedManager.UpdatePolicies([]*serverconfigs.HTTPCachePolicy{})
|
||||
}
|
||||
@@ -396,6 +361,8 @@ func (this *Node) syncConfig() error {
|
||||
iplibrary.SharedActionManager.UpdateActions(nodeConfig.FirewallActions)
|
||||
sharedNodeConfig = nodeConfig
|
||||
|
||||
metrics.SharedManager.Update(nodeConfig.MetricItems)
|
||||
|
||||
// 发送事件
|
||||
events.Notify(events.EventReload)
|
||||
|
||||
@@ -490,37 +457,63 @@ func (this *Node) checkClusterConfig() error {
|
||||
|
||||
// 监听本地sock
|
||||
func (this *Node) listenSock() error {
|
||||
path := os.TempDir() + "/edge-node.sock"
|
||||
|
||||
// 检查是否已经存在
|
||||
_, err := os.Stat(path)
|
||||
if err == nil {
|
||||
conn, err := net.Dial("unix", path)
|
||||
if err != nil {
|
||||
_ = os.Remove(path)
|
||||
// 检查是否在运行
|
||||
if this.sock.IsListening() {
|
||||
reply, err := this.sock.Send(&gosock.Command{Code: "pid"})
|
||||
if err == nil {
|
||||
return errors.New("error: the process is already running, pid: " + maps.NewMap(reply.Params).GetString("pid"))
|
||||
} else {
|
||||
_ = conn.Close()
|
||||
return errors.New("error: the process is already running")
|
||||
}
|
||||
}
|
||||
|
||||
// 新的监听任务
|
||||
listener, err := net.Listen("unix", path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
events.On(events.EventQuit, func() {
|
||||
remotelogs.Println("NODE", "quit unix sock")
|
||||
_ = listener.Close()
|
||||
})
|
||||
|
||||
// 启动监听
|
||||
go func() {
|
||||
for {
|
||||
_, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
this.sock.OnCommand(func(cmd *gosock.Command) {
|
||||
switch cmd.Code {
|
||||
case "pid":
|
||||
_ = cmd.Reply(&gosock.Command{
|
||||
Code: "pid",
|
||||
Params: map[string]interface{}{
|
||||
"pid": os.Getpid(),
|
||||
},
|
||||
})
|
||||
case "stop":
|
||||
_ = cmd.ReplyOk()
|
||||
|
||||
// 退出主进程
|
||||
events.Notify(events.EventQuit)
|
||||
os.Exit(0)
|
||||
case "quit":
|
||||
_ = cmd.ReplyOk()
|
||||
_ = this.sock.Close()
|
||||
|
||||
events.Notify(events.EventQuit)
|
||||
|
||||
// 监控连接数,如果连接数为0,则退出进程
|
||||
go func() {
|
||||
for {
|
||||
countActiveConnections := sharedListenerManager.TotalActiveConnections()
|
||||
if countActiveConnections <= 0 {
|
||||
os.Exit(0)
|
||||
return
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}()
|
||||
}
|
||||
})
|
||||
|
||||
err := this.sock.Listen()
|
||||
if err != nil {
|
||||
logs.Println("NODE", err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
events.On(events.EventQuit, func() {
|
||||
logs.Println("NODE", "quit unix sock")
|
||||
_ = this.sock.Close()
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/shirou/gopsutil/cpu"
|
||||
"github.com/shirou/gopsutil/disk"
|
||||
"golang.org/x/sys/unix"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -67,6 +68,8 @@ func (this *NodeStatusExecutor) update() {
|
||||
status.ConnectionCount = sharedListenerManager.TotalActiveConnections()
|
||||
status.CacheTotalDiskSize = caches.SharedManager.TotalDiskSize()
|
||||
status.CacheTotalMemorySize = caches.SharedManager.TotalMemorySize()
|
||||
status.TrafficInBytes = inTrafficBytes
|
||||
status.TrafficOutBytes = outTrafficBytes
|
||||
|
||||
// 记录监控数据
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemConnections, maps.Map{
|
||||
@@ -80,6 +83,7 @@ func (this *NodeStatusExecutor) update() {
|
||||
this.updateMem(status)
|
||||
this.updateLoad(status)
|
||||
this.updateDisk(status)
|
||||
this.updateCacheSpace(status)
|
||||
status.UpdatedAt = time.Now().Unix()
|
||||
|
||||
// 发送数据
|
||||
@@ -211,3 +215,25 @@ func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
|
||||
"maxUsage": status.DiskMaxUsage,
|
||||
})
|
||||
}
|
||||
|
||||
// 缓存空间
|
||||
func (this *NodeStatusExecutor) updateCacheSpace(status *nodeconfigs.NodeStatus) {
|
||||
var result = []maps.Map{}
|
||||
cachePaths := caches.SharedManager.FindAllCachePaths()
|
||||
for _, path := range cachePaths {
|
||||
var stat unix.Statfs_t
|
||||
err := unix.Statfs(path, &stat)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
result = append(result, maps.Map{
|
||||
"path": path,
|
||||
"total": stat.Blocks * uint64(stat.Bsize),
|
||||
"avail": stat.Bavail * uint64(stat.Bsize),
|
||||
"used": (stat.Blocks - stat.Bavail) * uint64(stat.Bsize),
|
||||
})
|
||||
}
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemCacheDir, maps.Map{
|
||||
"dirs": result,
|
||||
})
|
||||
}
|
||||
|
||||
12
internal/nodes/origin_state.go
Normal file
12
internal/nodes/origin_state.go
Normal file
@@ -0,0 +1,12 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
|
||||
type OriginState struct {
|
||||
CountFails int64
|
||||
UpdatedAt int64
|
||||
Config *serverconfigs.OriginConfig
|
||||
ReverseProxy *serverconfigs.ReverseProxyConfig
|
||||
}
|
||||
174
internal/nodes/origin_state_manager.go
Normal file
174
internal/nodes/origin_state_manager.go
Normal file
@@ -0,0 +1,174 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var SharedOriginStateManager = NewOriginStateManager()
|
||||
|
||||
func init() {
|
||||
events.On(events.EventLoaded, func() {
|
||||
go SharedOriginStateManager.Start()
|
||||
})
|
||||
}
|
||||
|
||||
// OriginStateManager 源站状态管理
|
||||
type OriginStateManager struct {
|
||||
stateMap map[int64]*OriginState // originId => *OriginState
|
||||
|
||||
ticker *time.Ticker
|
||||
locker sync.RWMutex
|
||||
}
|
||||
|
||||
// NewOriginStateManager 获取新管理对象
|
||||
func NewOriginStateManager() *OriginStateManager {
|
||||
return &OriginStateManager{
|
||||
stateMap: map[int64]*OriginState{},
|
||||
ticker: time.NewTicker(60 * time.Second),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动
|
||||
func (this *OriginStateManager) Start() {
|
||||
events.On(events.EventReload, func() {
|
||||
this.locker.Lock()
|
||||
this.stateMap = map[int64]*OriginState{}
|
||||
this.locker.Unlock()
|
||||
})
|
||||
|
||||
if Tea.IsTesting() {
|
||||
this.ticker = time.NewTicker(10 * time.Second)
|
||||
}
|
||||
for range this.ticker.C {
|
||||
err := this.Loop()
|
||||
if err != nil {
|
||||
remotelogs.Error("ORIGIN_MANAGER", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Loop 单次循环检查
|
||||
func (this *OriginStateManager) Loop() error {
|
||||
if sharedNodeConfig == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var currentStates = []*OriginState{}
|
||||
this.locker.Lock()
|
||||
for originId, state := range this.stateMap {
|
||||
// 检查Origin是否正在使用
|
||||
config := sharedNodeConfig.FindOrigin(originId)
|
||||
if config == nil || !config.IsOn {
|
||||
delete(this.stateMap, originId)
|
||||
continue
|
||||
}
|
||||
state.Config = config
|
||||
currentStates = append(currentStates, state)
|
||||
}
|
||||
this.locker.Unlock()
|
||||
|
||||
if len(currentStates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var count = len(currentStates)
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(count)
|
||||
for _, state := range currentStates {
|
||||
go func(state *OriginState) {
|
||||
defer wg.Done()
|
||||
conn, err := OriginConnect(state.Config, "")
|
||||
if err == nil {
|
||||
_ = conn.Close()
|
||||
|
||||
// 已经恢复正常
|
||||
this.locker.Lock()
|
||||
state.Config.IsOk = true
|
||||
delete(this.stateMap, state.Config.Id)
|
||||
this.locker.Unlock()
|
||||
|
||||
var reverseProxy = state.ReverseProxy
|
||||
if reverseProxy != nil {
|
||||
reverseProxy.ResetScheduling()
|
||||
}
|
||||
}
|
||||
}(state)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fail 添加失败的源站
|
||||
func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, reverseProxy *serverconfigs.ReverseProxyConfig, callback func()) {
|
||||
if origin == nil {
|
||||
return
|
||||
}
|
||||
this.locker.Lock()
|
||||
state, ok := this.stateMap[origin.Id]
|
||||
var timestamp = time.Now().Unix()
|
||||
if ok {
|
||||
if state.UpdatedAt < timestamp-300 { // N 分钟之后重新计数
|
||||
state.CountFails = 0
|
||||
state.Config.IsOk = true
|
||||
}
|
||||
|
||||
state.CountFails++
|
||||
state.Config = origin
|
||||
state.ReverseProxy = reverseProxy
|
||||
state.UpdatedAt = timestamp
|
||||
|
||||
if origin.IsOk {
|
||||
origin.IsOk = state.CountFails > 5 // 超过 N 次之后认为是异常
|
||||
|
||||
if !origin.IsOk {
|
||||
if callback != nil {
|
||||
callback()
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
this.stateMap[origin.Id] = &OriginState{
|
||||
CountFails: 1,
|
||||
Config: origin,
|
||||
ReverseProxy: reverseProxy,
|
||||
UpdatedAt: timestamp,
|
||||
}
|
||||
origin.IsOk = true
|
||||
}
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
// Success 添加成功的源站
|
||||
func (this *OriginStateManager) Success(origin *serverconfigs.OriginConfig, callback func()) {
|
||||
if origin == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !origin.IsOk {
|
||||
if callback != nil {
|
||||
defer callback()
|
||||
}
|
||||
}
|
||||
|
||||
origin.IsOk = true
|
||||
this.locker.Lock()
|
||||
delete(this.stateMap, origin.Id)
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
// IsAvailable 检查是否正常
|
||||
func (this *OriginStateManager) IsAvailable(originId int64) bool {
|
||||
this.locker.RLock()
|
||||
_, ok := this.stateMap[originId]
|
||||
this.locker.RUnlock()
|
||||
|
||||
return !ok
|
||||
}
|
||||
15
internal/nodes/origin_state_manager_test.go
Normal file
15
internal/nodes/origin_state_manager_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestOriginManager_Loop(t *testing.T) {
|
||||
var manager = NewOriginStateManager()
|
||||
err := manager.Loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Log(manager.stateMap)
|
||||
}
|
||||
@@ -16,41 +16,44 @@ func OriginConnect(origin *serverconfigs.OriginConfig, remoteAddr string) (net.C
|
||||
}
|
||||
|
||||
// 支持TOA的连接
|
||||
toaConfig := sharedTOAManager.Config()
|
||||
if toaConfig != nil && toaConfig.IsOn {
|
||||
retries := 3
|
||||
for i := 1; i <= retries; i++ {
|
||||
port := int(toaConfig.RandLocalPort())
|
||||
err := sharedTOAManager.SendMsg("add:" + strconv.Itoa(port) + ":" + remoteAddr)
|
||||
if err != nil {
|
||||
remotelogs.Error("TOA", "add failed: "+err.Error())
|
||||
} else {
|
||||
dialer := net.Dialer{
|
||||
Timeout: origin.ConnTimeoutDuration(),
|
||||
LocalAddr: &net.TCPAddr{
|
||||
Port: port,
|
||||
},
|
||||
}
|
||||
var conn net.Conn
|
||||
switch origin.Addr.Protocol {
|
||||
case "", serverconfigs.ProtocolTCP, serverconfigs.ProtocolHTTP:
|
||||
// TODO 支持TCP4/TCP6
|
||||
// TODO 支持指定特定网卡
|
||||
// TODO Addr支持端口范围,如果有多个端口时,随机一个端口使用
|
||||
conn, err = dialer.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange)
|
||||
case serverconfigs.ProtocolTLS, serverconfigs.ProtocolHTTPS:
|
||||
// TODO 支持TCP4/TCP6
|
||||
// TODO 支持指定特定网卡
|
||||
// TODO Addr支持端口范围,如果有多个端口时,随机一个端口使用
|
||||
// TODO 支持使用证书
|
||||
conn, err = tls.DialWithDialer(&dialer, "tcp", origin.Addr.Host+":"+origin.Addr.PortRange, &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
}
|
||||
// 这个条件很重要,如果没有传递remoteAddr,表示不使用TOA
|
||||
if len(remoteAddr) > 0 {
|
||||
toaConfig := sharedTOAManager.Config()
|
||||
if toaConfig != nil && toaConfig.IsOn {
|
||||
retries := 3
|
||||
for i := 1; i <= retries; i++ {
|
||||
port := int(toaConfig.RandLocalPort())
|
||||
err := sharedTOAManager.SendMsg("add:" + strconv.Itoa(port) + ":" + remoteAddr)
|
||||
if err != nil {
|
||||
remotelogs.Error("TOA", "add failed: "+err.Error())
|
||||
} else {
|
||||
dialer := net.Dialer{
|
||||
Timeout: origin.ConnTimeoutDuration(),
|
||||
LocalAddr: &net.TCPAddr{
|
||||
Port: port,
|
||||
},
|
||||
}
|
||||
var conn net.Conn
|
||||
switch origin.Addr.Protocol {
|
||||
case "", serverconfigs.ProtocolTCP, serverconfigs.ProtocolHTTP:
|
||||
// TODO 支持TCP4/TCP6
|
||||
// TODO 支持指定特定网卡
|
||||
// TODO Addr支持端口范围,如果有多个端口时,随机一个端口使用
|
||||
conn, err = dialer.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange)
|
||||
case serverconfigs.ProtocolTLS, serverconfigs.ProtocolHTTPS:
|
||||
// TODO 支持TCP4/TCP6
|
||||
// TODO 支持指定特定网卡
|
||||
// TODO Addr支持端口范围,如果有多个端口时,随机一个端口使用
|
||||
// TODO 支持使用证书
|
||||
conn, err = tls.DialWithDialer(&dialer, "tcp", origin.Addr.Host+":"+origin.Addr.PortRange, &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
}
|
||||
|
||||
// TODO 需要在合适的时机删除TOA记录
|
||||
if err == nil || i == retries {
|
||||
return conn, err
|
||||
// TODO 需要在合适的时机删除TOA记录
|
||||
if err == nil || i == retries {
|
||||
return conn, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
@@ -8,8 +10,12 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -20,7 +26,7 @@ func init() {
|
||||
})
|
||||
}
|
||||
|
||||
// API节点同步任务
|
||||
// SyncAPINodesTask API节点同步任务
|
||||
type SyncAPINodesTask struct {
|
||||
}
|
||||
|
||||
@@ -74,6 +80,12 @@ func (this *SyncAPINodesTask) Loop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 测试是否有API节点可用
|
||||
hasOk := this.testEndpoints(newEndpoints)
|
||||
if !hasOk {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 修改RPC对象配置
|
||||
config.RPC.Endpoints = newEndpoints
|
||||
err = rpcClient.UpdateConfig(config)
|
||||
@@ -95,3 +107,47 @@ func (this *SyncAPINodesTask) isSame(endpoints1 []string, endpoints2 []string) b
|
||||
sort.Strings(endpoints2)
|
||||
return strings.Join(endpoints1, "&") == strings.Join(endpoints2, "&")
|
||||
}
|
||||
|
||||
func (this *SyncAPINodesTask) testEndpoints(endpoints []string) bool {
|
||||
if len(endpoints) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
var wg = sync.WaitGroup{}
|
||||
wg.Add(len(endpoints))
|
||||
|
||||
var ok = false
|
||||
|
||||
for _, endpoint := range endpoints {
|
||||
go func(endpoint string) {
|
||||
defer wg.Done()
|
||||
|
||||
u, err := url.Parse(endpoint)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer func() {
|
||||
cancel()
|
||||
}()
|
||||
var conn *grpc.ClientConn
|
||||
if u.Scheme == "http" {
|
||||
conn, err = grpc.DialContext(ctx, u.Host, grpc.WithInsecure(), grpc.WithBlock())
|
||||
} else if u.Scheme == "https" {
|
||||
conn, err = grpc.DialContext(ctx, u.Host, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})), grpc.WithBlock())
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = conn.Close()
|
||||
|
||||
ok = true
|
||||
}(endpoint)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
|
||||
package nodes
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"net"
|
||||
)
|
||||
|
||||
// TrafficListener 用于统计流量的网络监听
|
||||
type TrafficListener struct {
|
||||
@@ -18,6 +21,17 @@ func (this *TrafficListener) Accept() (net.Conn, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 是否在WAF名单中
|
||||
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
if err == nil {
|
||||
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, ip) && waf.SharedIPBlackLIst.Contains(waf.IPTypeAll, ip) {
|
||||
go func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
return NewTrafficConn(conn), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -11,20 +11,20 @@ import (
|
||||
|
||||
var sharedWAFManager = NewWAFManager()
|
||||
|
||||
// WAF管理器
|
||||
// WAFManager WAF管理器
|
||||
type WAFManager struct {
|
||||
mapping map[int64]*waf.WAF // policyId => WAF
|
||||
locker sync.RWMutex
|
||||
}
|
||||
|
||||
// 获取新对象
|
||||
// NewWAFManager 获取新对象
|
||||
func NewWAFManager() *WAFManager {
|
||||
return &WAFManager{
|
||||
mapping: map[int64]*waf.WAF{},
|
||||
}
|
||||
}
|
||||
|
||||
// 更新策略
|
||||
// UpdatePolicies 更新策略
|
||||
func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallPolicy) {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
@@ -44,7 +44,7 @@ func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallP
|
||||
this.mapping = m
|
||||
}
|
||||
|
||||
// 查找WAF
|
||||
// FindWAF 查找WAF
|
||||
func (this *WAFManager) FindWAF(policyId int64) *waf.WAF {
|
||||
this.locker.RLock()
|
||||
w, _ := this.mapping[policyId]
|
||||
@@ -78,14 +78,15 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
|
||||
// rule sets
|
||||
for _, set := range group.Sets {
|
||||
s := &waf.RuleSet{
|
||||
Id: strconv.FormatInt(set.Id, 10),
|
||||
Code: set.Code,
|
||||
IsOn: set.IsOn,
|
||||
Name: set.Name,
|
||||
Description: set.Description,
|
||||
Connector: set.Connector,
|
||||
Action: set.Action,
|
||||
ActionOptions: set.ActionOptions,
|
||||
Id: strconv.FormatInt(set.Id, 10),
|
||||
Code: set.Code,
|
||||
IsOn: set.IsOn,
|
||||
Name: set.Name,
|
||||
Description: set.Description,
|
||||
Connector: set.Connector,
|
||||
}
|
||||
for _, a := range set.Actions {
|
||||
s.AddAction(a.Code, a.Options)
|
||||
}
|
||||
|
||||
// rules
|
||||
@@ -132,14 +133,16 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
|
||||
// rule sets
|
||||
for _, set := range group.Sets {
|
||||
s := &waf.RuleSet{
|
||||
Id: strconv.FormatInt(set.Id, 10),
|
||||
Code: set.Code,
|
||||
IsOn: set.IsOn,
|
||||
Name: set.Name,
|
||||
Description: set.Description,
|
||||
Connector: set.Connector,
|
||||
Action: set.Action,
|
||||
ActionOptions: set.ActionOptions,
|
||||
Id: strconv.FormatInt(set.Id, 10),
|
||||
Code: set.Code,
|
||||
IsOn: set.IsOn,
|
||||
Name: set.Name,
|
||||
Description: set.Description,
|
||||
Connector: set.Connector,
|
||||
}
|
||||
|
||||
for _, a := range set.Actions {
|
||||
s.AddAction(a.Code, a.Options)
|
||||
}
|
||||
|
||||
// rules
|
||||
@@ -164,10 +167,11 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
|
||||
|
||||
// action
|
||||
if policy.BlockOptions != nil {
|
||||
w.ActionBlock = &waf.BlockAction{
|
||||
w.DefaultBlockAction = &waf.BlockAction{
|
||||
StatusCode: policy.BlockOptions.StatusCode,
|
||||
Body: policy.BlockOptions.Body,
|
||||
URL: "",
|
||||
URL: policy.BlockOptions.URL,
|
||||
Timeout: policy.BlockOptions.Timeout,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -109,6 +109,14 @@ func (this *RPCClient) ServerDailyStatRPC() pb.ServerDailyStatServiceClient {
|
||||
return pb.NewServerDailyStatServiceClient(this.pickConn())
|
||||
}
|
||||
|
||||
func (this *RPCClient) MetricStatRPC() pb.MetricStatServiceClient {
|
||||
return pb.NewMetricStatServiceClient(this.pickConn())
|
||||
}
|
||||
|
||||
func (this *RPCClient) FirewallService() pb.FirewallServiceClient {
|
||||
return pb.NewFirewallServiceClient(this.pickConn())
|
||||
}
|
||||
|
||||
// Context 节点上下文信息
|
||||
func (this *RPCClient) Context() context.Context {
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/monitor"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
timeutil "github.com/iwind/TeaGo/utils/time"
|
||||
"github.com/mssola/user_agent"
|
||||
@@ -30,6 +34,8 @@ type HTTPRequestStatManager struct {
|
||||
browserMap map[string]int64 // serverId@browser@version => count
|
||||
|
||||
dailyFirewallRuleGroupMap map[string]int64 // serverId@firewallRuleGroupId@action => count
|
||||
|
||||
totalAttackRequests int64
|
||||
}
|
||||
|
||||
// NewHTTPRequestStatManager 获取新对象
|
||||
@@ -38,16 +44,29 @@ func NewHTTPRequestStatManager() *HTTPRequestStatManager {
|
||||
ipChan: make(chan string, 10_000), // TODO 将来可以配置容量
|
||||
userAgentChan: make(chan string, 10_000), // TODO 将来可以配置容量
|
||||
firewallRuleGroupChan: make(chan string, 10_000), // TODO 将来可以配置容量
|
||||
cityMap: map[string]int64{},
|
||||
providerMap: map[string]int64{},
|
||||
systemMap: map[string]int64{},
|
||||
browserMap: map[string]int64{},
|
||||
cityMap: map[string]int64{},
|
||||
providerMap: map[string]int64{},
|
||||
systemMap: map[string]int64{},
|
||||
browserMap: map[string]int64{},
|
||||
dailyFirewallRuleGroupMap: map[string]int64{},
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动
|
||||
func (this *HTTPRequestStatManager) Start() {
|
||||
// 上传请求总数
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
if this.totalAttackRequests > 0 {
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemAttackRequests, maps.Map{"total": this.totalAttackRequests})
|
||||
this.totalAttackRequests = 0
|
||||
}
|
||||
}
|
||||
}()
|
||||
}()
|
||||
|
||||
loopTicker := time.NewTicker(1 * time.Second)
|
||||
uploadTicker := time.NewTicker(30 * time.Minute)
|
||||
if Tea.IsTesting() {
|
||||
@@ -114,14 +133,19 @@ func (this *HTTPRequestStatManager) AddUserAgent(serverId int64, userAgent strin
|
||||
}
|
||||
|
||||
// AddFirewallRuleGroupId 添加防火墙拦截动作
|
||||
func (this *HTTPRequestStatManager) AddFirewallRuleGroupId(serverId int64, firewallRuleGroupId int64, action string) {
|
||||
func (this *HTTPRequestStatManager) AddFirewallRuleGroupId(serverId int64, firewallRuleGroupId int64, actions []*waf.ActionConfig) {
|
||||
if firewallRuleGroupId <= 0 {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case this.firewallRuleGroupChan <- strconv.FormatInt(serverId, 10) + "@" + strconv.FormatInt(firewallRuleGroupId, 10) + "@" + action:
|
||||
default:
|
||||
// 超出容量我们就丢弃
|
||||
|
||||
this.totalAttackRequests++
|
||||
|
||||
for _, action := range actions {
|
||||
select {
|
||||
case this.firewallRuleGroupChan <- strconv.FormatInt(serverId, 10) + "@" + strconv.FormatInt(firewallRuleGroupId, 10) + "@" + action.Code:
|
||||
default:
|
||||
// 超出容量我们就丢弃
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,7 +165,7 @@ Loop:
|
||||
ip := ipString[atIndex+1:]
|
||||
if iplibrary.SharedLibrary != nil {
|
||||
result, err := iplibrary.SharedLibrary.Lookup(ip)
|
||||
if err == nil {
|
||||
if err == nil && result != nil {
|
||||
this.cityMap[serverId+"@"+result.Country+"@"+result.Province+"@"+result.City] ++
|
||||
|
||||
if len(result.ISP) > 0 {
|
||||
|
||||
@@ -4,11 +4,15 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/monitor"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -20,19 +24,25 @@ type TrafficItem struct {
|
||||
CachedBytes int64
|
||||
CountRequests int64
|
||||
CountCachedRequests int64
|
||||
CountAttackRequests int64
|
||||
AttackBytes int64
|
||||
}
|
||||
|
||||
// TrafficStatManager 区域流量统计
|
||||
type TrafficStatManager struct {
|
||||
itemMap map[string]*TrafficItem // [timestamp serverId] => bytes
|
||||
itemMap map[string]*TrafficItem // [timestamp serverId] => *TrafficItem
|
||||
domainsMap map[string]*TrafficItem // timestamp @ serverId @ domain => *TrafficItem
|
||||
locker sync.Mutex
|
||||
configFunc func() *nodeconfigs.NodeConfig
|
||||
|
||||
totalRequests int64
|
||||
}
|
||||
|
||||
// NewTrafficStatManager 获取新对象
|
||||
func NewTrafficStatManager() *TrafficStatManager {
|
||||
manager := &TrafficStatManager{
|
||||
itemMap: map[string]*TrafficItem{},
|
||||
itemMap: map[string]*TrafficItem{},
|
||||
domainsMap: map[string]*TrafficItem{},
|
||||
}
|
||||
|
||||
return manager
|
||||
@@ -42,6 +52,20 @@ func NewTrafficStatManager() *TrafficStatManager {
|
||||
func (this *TrafficStatManager) Start(configFunc func() *nodeconfigs.NodeConfig) {
|
||||
this.configFunc = configFunc
|
||||
|
||||
// 上传请求总数
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
if this.totalRequests > 0 {
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemRequests, maps.Map{"total": this.totalRequests})
|
||||
this.totalRequests = 0
|
||||
}
|
||||
}
|
||||
}()
|
||||
}()
|
||||
|
||||
// 上传统计数据
|
||||
duration := 5 * time.Minute
|
||||
if Tea.IsTesting() {
|
||||
// 测试环境缩短上传时间,方便我们调试
|
||||
@@ -62,15 +86,19 @@ func (this *TrafficStatManager) Start(configFunc func() *nodeconfigs.NodeConfig)
|
||||
}
|
||||
|
||||
// Add 添加流量
|
||||
func (this *TrafficStatManager) Add(serverId int64, bytes int64, cachedBytes int64, countRequests int64, countCachedRequests int64) {
|
||||
func (this *TrafficStatManager) Add(serverId int64, domain string, bytes int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64) {
|
||||
if bytes == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
this.totalRequests++
|
||||
|
||||
timestamp := utils.UnixTime() / 300 * 300
|
||||
|
||||
key := strconv.FormatInt(timestamp, 10) + strconv.FormatInt(serverId, 10)
|
||||
this.locker.Lock()
|
||||
|
||||
// 总的流量
|
||||
item, ok := this.itemMap[key]
|
||||
if !ok {
|
||||
item = &TrafficItem{}
|
||||
@@ -80,6 +108,23 @@ func (this *TrafficStatManager) Add(serverId int64, bytes int64, cachedBytes int
|
||||
item.CachedBytes += cachedBytes
|
||||
item.CountRequests += countRequests
|
||||
item.CountCachedRequests += countCachedRequests
|
||||
item.CountAttackRequests += countAttacks
|
||||
item.AttackBytes += attackBytes
|
||||
|
||||
// 单个域名流量
|
||||
var domainKey = strconv.FormatInt(timestamp, 10) + "@" + strconv.FormatInt(serverId, 10) + "@" + domain
|
||||
domainItem, ok := this.domainsMap[domainKey]
|
||||
if !ok {
|
||||
domainItem = &TrafficItem{}
|
||||
this.domainsMap[domainKey] = domainItem
|
||||
}
|
||||
domainItem.Bytes += bytes
|
||||
domainItem.CachedBytes += cachedBytes
|
||||
domainItem.CountRequests += countRequests
|
||||
domainItem.CountCachedRequests += countCachedRequests
|
||||
domainItem.CountAttackRequests += countAttacks
|
||||
domainItem.AttackBytes += attackBytes
|
||||
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
@@ -96,12 +141,15 @@ func (this *TrafficStatManager) Upload() error {
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
m := this.itemMap
|
||||
itemMap := this.itemMap
|
||||
domainMap := this.domainsMap
|
||||
this.itemMap = map[string]*TrafficItem{}
|
||||
this.domainsMap = map[string]*TrafficItem{}
|
||||
this.locker.Unlock()
|
||||
|
||||
pbStats := []*pb.ServerDailyStat{}
|
||||
for key, item := range m {
|
||||
// 服务统计
|
||||
var pbServerStats = []*pb.ServerDailyStat{}
|
||||
for key, item := range itemMap {
|
||||
timestamp, err := strconv.ParseInt(key[:10], 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -111,19 +159,45 @@ func (this *TrafficStatManager) Upload() error {
|
||||
return err
|
||||
}
|
||||
|
||||
pbStats = append(pbStats, &pb.ServerDailyStat{
|
||||
pbServerStats = append(pbServerStats, &pb.ServerDailyStat{
|
||||
ServerId: serverId,
|
||||
RegionId: config.RegionId,
|
||||
Bytes: item.Bytes,
|
||||
CachedBytes: item.CachedBytes,
|
||||
CountRequests: item.CountRequests,
|
||||
CountCachedRequests: item.CountCachedRequests,
|
||||
CountAttackRequests: item.CountAttackRequests,
|
||||
AttackBytes: item.AttackBytes,
|
||||
CreatedAt: timestamp,
|
||||
})
|
||||
}
|
||||
if len(pbStats) == 0 {
|
||||
if len(pbServerStats) == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err = client.ServerDailyStatRPC().UploadServerDailyStats(client.Context(), &pb.UploadServerDailyStatsRequest{Stats: pbStats})
|
||||
|
||||
// 域名统计
|
||||
var pbDomainStats = []*pb.UploadServerDailyStatsRequest_DomainStat{}
|
||||
for key, item := range domainMap {
|
||||
var pieces = strings.SplitN(key, "@", 3)
|
||||
if len(pieces) != 3 {
|
||||
continue
|
||||
}
|
||||
pbDomainStats = append(pbDomainStats, &pb.UploadServerDailyStatsRequest_DomainStat{
|
||||
ServerId: types.Int64(pieces[1]),
|
||||
Domain: pieces[2],
|
||||
Bytes: item.Bytes,
|
||||
CachedBytes: item.CachedBytes,
|
||||
CountRequests: item.CountRequests,
|
||||
CountCachedRequests: item.CountCachedRequests,
|
||||
CountAttackRequests: item.CountAttackRequests,
|
||||
AttackBytes: item.AttackBytes,
|
||||
CreatedAt: types.Int64(pieces[0]),
|
||||
})
|
||||
}
|
||||
|
||||
_, err = client.ServerDailyStatRPC().UploadServerDailyStats(client.Context(), &pb.UploadServerDailyStatsRequest{
|
||||
Stats: pbServerStats,
|
||||
DomainStats: pbDomainStats,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
func TestTrafficStatManager_Add(t *testing.T) {
|
||||
manager := NewTrafficStatManager()
|
||||
for i := 0; i < 100; i++ {
|
||||
manager.Add(1, 10, 1, 0)
|
||||
manager.Add(1, "goedge.cn", 1, 0, 0, 0)
|
||||
}
|
||||
t.Log(manager.itemMap)
|
||||
}
|
||||
@@ -16,7 +16,7 @@ func TestTrafficStatManager_Add(t *testing.T) {
|
||||
func TestTrafficStatManager_Upload(t *testing.T) {
|
||||
manager := NewTrafficStatManager()
|
||||
for i := 0; i < 100; i++ {
|
||||
manager.Add(1, 10, 1, 0)
|
||||
manager.Add(1, "goedge.cn", 1, 0, 0, 0)
|
||||
}
|
||||
err := manager.Upload()
|
||||
if err != nil {
|
||||
@@ -30,6 +30,6 @@ func BenchmarkTrafficStatManager_Add(b *testing.B) {
|
||||
|
||||
manager := NewTrafficStatManager()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.Add(1, 1024, 1, 0)
|
||||
manager.Add(1, "goedge.cn", 1024, 1, 0, 0)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ func (this *Piece) Add(key uint64, item *Item) () {
|
||||
func (this *Piece) IncreaseInt64(key uint64, delta int64, expiredAt int64) (result int64) {
|
||||
this.locker.Lock()
|
||||
item, ok := this.m[key]
|
||||
if ok {
|
||||
if ok && item.expiredAt > time.Now().Unix() {
|
||||
result = types.Int64(item.Value) + delta
|
||||
item.Value = result
|
||||
item.expiredAt = expiredAt
|
||||
|
||||
159
internal/utils/encrypt.go
Normal file
159
internal/utils/encrypt.go
Normal file
@@ -0,0 +1,159 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
)
|
||||
|
||||
var (
|
||||
simpleEncryptMagicKey = rands.HexString(32)
|
||||
)
|
||||
|
||||
func init() {
|
||||
events.On(events.EventReload, func() {
|
||||
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
|
||||
if nodeConfig != nil {
|
||||
simpleEncryptMagicKey = stringutil.Md5(nodeConfig.NodeId + "@" + nodeConfig.Secret)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// SimpleEncrypt 加密特殊信息
|
||||
func SimpleEncrypt(data []byte) []byte {
|
||||
var method = &AES256CFBMethod{}
|
||||
err := method.Init([]byte(simpleEncryptMagicKey), []byte(simpleEncryptMagicKey[:16]))
|
||||
if err != nil {
|
||||
logs.Println("[SimpleEncrypt]" + err.Error())
|
||||
return data
|
||||
}
|
||||
|
||||
dst, err := method.Encrypt(data)
|
||||
if err != nil {
|
||||
logs.Println("[SimpleEncrypt]" + err.Error())
|
||||
return data
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// SimpleDecrypt 解密特殊信息
|
||||
func SimpleDecrypt(data []byte) []byte {
|
||||
var method = &AES256CFBMethod{}
|
||||
err := method.Init([]byte(simpleEncryptMagicKey), []byte(simpleEncryptMagicKey[:16]))
|
||||
if err != nil {
|
||||
logs.Println("[MagicKeyEncode]" + err.Error())
|
||||
return data
|
||||
}
|
||||
|
||||
src, err := method.Decrypt(data)
|
||||
if err != nil {
|
||||
logs.Println("[MagicKeyEncode]" + err.Error())
|
||||
return data
|
||||
}
|
||||
return src
|
||||
}
|
||||
|
||||
func SimpleEncryptMap(m maps.Map) (base64String string, err error) {
|
||||
mJSON, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
data := SimpleEncrypt(mJSON)
|
||||
return base64.StdEncoding.EncodeToString(data), nil
|
||||
}
|
||||
|
||||
func SimpleDecryptMap(base64String string) (maps.Map, error) {
|
||||
data, err := base64.StdEncoding.DecodeString(base64String)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mJSON := SimpleDecrypt(data)
|
||||
var result = maps.Map{}
|
||||
err = json.Unmarshal(mJSON, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type AES256CFBMethod struct {
|
||||
block cipher.Block
|
||||
iv []byte
|
||||
}
|
||||
|
||||
func (this *AES256CFBMethod) Init(key, iv []byte) error {
|
||||
// 判断key是否为32长度
|
||||
l := len(key)
|
||||
if l > 32 {
|
||||
key = key[:32]
|
||||
} else if l < 32 {
|
||||
key = append(key, bytes.Repeat([]byte{' '}, 32-l)...)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.block = block
|
||||
|
||||
// 判断iv长度
|
||||
l2 := len(iv)
|
||||
if l2 > aes.BlockSize {
|
||||
iv = iv[:aes.BlockSize]
|
||||
} else if l2 < aes.BlockSize {
|
||||
iv = append(iv, bytes.Repeat([]byte{' '}, aes.BlockSize-l2)...)
|
||||
}
|
||||
this.iv = iv
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *AES256CFBMethod) Encrypt(src []byte) (dst []byte, err error) {
|
||||
if len(src) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
r := recover()
|
||||
if r != nil {
|
||||
err = errors.New("encrypt failed")
|
||||
}
|
||||
}()
|
||||
|
||||
dst = make([]byte, len(src))
|
||||
|
||||
encrypter := cipher.NewCFBEncrypter(this.block, this.iv)
|
||||
encrypter.XORKeyStream(dst, src)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (this *AES256CFBMethod) Decrypt(dst []byte) (src []byte, err error) {
|
||||
if len(dst) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
r := recover()
|
||||
if r != nil {
|
||||
err = errors.New("decrypt failed")
|
||||
}
|
||||
}()
|
||||
|
||||
src = make([]byte, len(dst))
|
||||
decrypter := cipher.NewCFBDecrypter(this.block, this.iv)
|
||||
decrypter.XORKeyStream(src, dst)
|
||||
|
||||
return
|
||||
}
|
||||
52
internal/utils/encrypt_test.go
Normal file
52
internal/utils/encrypt_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSimpleEncrypt(t *testing.T) {
|
||||
var arr = []string{"Hello", "World", "People"}
|
||||
for _, s := range arr {
|
||||
var value = []byte(s)
|
||||
encoded := SimpleEncrypt(value)
|
||||
t.Log(encoded, string(encoded))
|
||||
decoded := SimpleDecrypt(encoded)
|
||||
t.Log(decoded, string(decoded))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSimpleEncrypt_Concurrent(t *testing.T) {
|
||||
wg := sync.WaitGroup{}
|
||||
var arr = []string{"Hello", "World", "People"}
|
||||
wg.Add(len(arr))
|
||||
for _, s := range arr {
|
||||
go func(s string) {
|
||||
defer wg.Done()
|
||||
t.Log(string(SimpleDecrypt(SimpleEncrypt([]byte(s)))))
|
||||
}(s)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestSimpleEncryptMap(t *testing.T) {
|
||||
var m = maps.Map{
|
||||
"s": "Hello",
|
||||
"i": 20,
|
||||
"b": true,
|
||||
}
|
||||
encodedResult, err := SimpleEncryptMap(m)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("result:", encodedResult)
|
||||
|
||||
decodedResult, err := SimpleDecryptMap(encodedResult)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(decodedResult)
|
||||
}
|
||||
@@ -12,6 +12,7 @@ type List struct {
|
||||
itemsMap map[int64]int64 // itemId => timestamp
|
||||
|
||||
locker sync.Mutex
|
||||
ticker *time.Ticker
|
||||
}
|
||||
|
||||
func NewList() *List {
|
||||
@@ -21,10 +22,7 @@ func NewList() *List {
|
||||
}
|
||||
}
|
||||
|
||||
func (this *List) Add(itemId int64, expiredAt int64) {
|
||||
if expiredAt <= time.Now().Unix() {
|
||||
return
|
||||
}
|
||||
func (this *List) Add(itemId int64, expiresAt int64) {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
@@ -34,17 +32,17 @@ func (this *List) Add(itemId int64, expiredAt int64) {
|
||||
this.removeItem(itemId)
|
||||
}
|
||||
|
||||
expireItemMap, ok := this.expireMap[expiredAt]
|
||||
expireItemMap, ok := this.expireMap[expiresAt]
|
||||
if ok {
|
||||
expireItemMap[itemId] = true
|
||||
} else {
|
||||
expireItemMap = ItemMap{
|
||||
itemId: true,
|
||||
}
|
||||
this.expireMap[expiredAt] = expireItemMap
|
||||
this.expireMap[expiresAt] = expireItemMap
|
||||
}
|
||||
|
||||
this.itemsMap[itemId] = expiredAt
|
||||
this.itemsMap[itemId] = expiresAt
|
||||
}
|
||||
|
||||
func (this *List) Remove(itemId int64) {
|
||||
@@ -64,21 +62,22 @@ func (this *List) GC(timestamp int64, callback func(itemId int64)) {
|
||||
}
|
||||
|
||||
func (this *List) StartGC(callback func(itemId int64)) {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
this.ticker = time.NewTicker(1 * time.Second)
|
||||
lastTimestamp := int64(0)
|
||||
for range ticker.C {
|
||||
for range this.ticker.C {
|
||||
timestamp := time.Now().Unix()
|
||||
if lastTimestamp == 0 {
|
||||
lastTimestamp = timestamp - 3600
|
||||
}
|
||||
|
||||
// 防止死循环
|
||||
if lastTimestamp > timestamp {
|
||||
continue
|
||||
}
|
||||
|
||||
for i := lastTimestamp; i <= timestamp; i++ {
|
||||
this.GC(timestamp, callback)
|
||||
if timestamp >= lastTimestamp {
|
||||
for i := lastTimestamp; i <= timestamp; i++ {
|
||||
this.GC(i, callback)
|
||||
}
|
||||
} else {
|
||||
for i := timestamp; i <= lastTimestamp; i++ {
|
||||
this.GC(i, callback)
|
||||
}
|
||||
}
|
||||
|
||||
// 这样做是为了防止系统时钟突变
|
||||
|
||||
@@ -58,6 +58,10 @@ func TestList_Start_GC(t *testing.T) {
|
||||
list.Add(2, time.Now().Unix()+1)
|
||||
list.Add(3, time.Now().Unix()+2)
|
||||
list.Add(4, time.Now().Unix()+5)
|
||||
list.Add(5, time.Now().Unix()+5)
|
||||
list.Add(6, time.Now().Unix()+6)
|
||||
list.Add(7, time.Now().Unix()+6)
|
||||
list.Add(8, time.Now().Unix()+6)
|
||||
|
||||
go func() {
|
||||
list.StartGC(func(itemId int64) {
|
||||
@@ -66,7 +70,7 @@ func TestList_Start_GC(t *testing.T) {
|
||||
})
|
||||
}()
|
||||
|
||||
time.Sleep(10 * time.Second)
|
||||
time.Sleep(20 * time.Second)
|
||||
}
|
||||
|
||||
func TestList_ManyItems(t *testing.T) {
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 将IP转换为整型
|
||||
// IP2Long 将IP转换为整型
|
||||
// 注意IPv6没有顺序
|
||||
func IP2Long(ip string) uint64 {
|
||||
if len(ip) == 0 {
|
||||
|
||||
35
internal/utils/jsonutils/map.go
Normal file
35
internal/utils/jsonutils/map.go
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package jsonutils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
)
|
||||
|
||||
func MapToObject(m maps.Map, ptr interface{}) error {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
mJSON, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(mJSON, ptr)
|
||||
}
|
||||
|
||||
func ObjectToMap(ptr interface{}) (maps.Map, error) {
|
||||
if ptr == nil {
|
||||
return maps.Map{}, nil
|
||||
}
|
||||
ptrJSON, err := json.Marshal(ptr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result = maps.Map{}
|
||||
err = json.Unmarshal(ptrJSON, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
46
internal/utils/jsonutils/map_test.go
Normal file
46
internal/utils/jsonutils/map_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package jsonutils
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMapToObject(t *testing.T) {
|
||||
a := assert.NewAssertion(t)
|
||||
|
||||
type typeA struct {
|
||||
B int `json:"b"`
|
||||
C bool `json:"c"`
|
||||
}
|
||||
|
||||
{
|
||||
var obj = &typeA{B: 1, C: true}
|
||||
m, err := ObjectToMap(obj)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
PrintT(m, t)
|
||||
a.IsTrue(m.GetInt("b") == 1)
|
||||
a.IsTrue(m.GetBool("c") == true)
|
||||
}
|
||||
|
||||
{
|
||||
var obj = &typeA{}
|
||||
err := MapToObject(maps.Map{
|
||||
"b": 1024,
|
||||
"c": true,
|
||||
}, obj)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if obj == nil {
|
||||
t.Fatal("obj should not be nil")
|
||||
}
|
||||
a.IsTrue(obj.B == 1024)
|
||||
a.IsTrue(obj.C == true)
|
||||
PrintT(obj, t)
|
||||
}
|
||||
}
|
||||
17
internal/utils/jsonutils/utils.go
Normal file
17
internal/utils/jsonutils/utils.go
Normal file
@@ -0,0 +1,17 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package jsonutils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func PrintT(obj interface{}, t *testing.T) {
|
||||
data, err := json.MarshalIndent(obj, "", " ")
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
} else {
|
||||
t.Log(string(data))
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,23 @@ import (
|
||||
type AllowAction struct {
|
||||
}
|
||||
|
||||
func (this *AllowAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
func (this *AllowAction) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *AllowAction) Code() string {
|
||||
return ActionAllow
|
||||
}
|
||||
|
||||
func (this *AllowAction) IsAttack() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *AllowAction) WillChange() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *AllowAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
// do nothing
|
||||
return true
|
||||
}
|
||||
|
||||
21
internal/waf/action_base.go
Normal file
21
internal/waf/action_base.go
Normal file
@@ -0,0 +1,21 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package waf
|
||||
|
||||
import "net/http"
|
||||
|
||||
type BaseAction struct {
|
||||
}
|
||||
|
||||
// CloseConn 关闭连接
|
||||
func (this *BaseAction) CloseConn(writer http.ResponseWriter) error {
|
||||
// 断开连接
|
||||
hijack, ok := writer.(http.Hijacker)
|
||||
if ok {
|
||||
conn, _, err := hijack.Hijack()
|
||||
if err == nil {
|
||||
return conn.Close()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -23,12 +23,48 @@ type BlockAction struct {
|
||||
StatusCode int `yaml:"statusCode" json:"statusCode"`
|
||||
Body string `yaml:"body" json:"body"` // supports HTML
|
||||
URL string `yaml:"url" json:"url"`
|
||||
Timeout int32 `yaml:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
func (this *BlockAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
func (this *BlockAction) Init(waf *WAF) error {
|
||||
if waf.DefaultBlockAction != nil {
|
||||
if this.StatusCode <= 0 {
|
||||
this.StatusCode = waf.DefaultBlockAction.StatusCode
|
||||
}
|
||||
if len(this.Body) == 0 {
|
||||
this.Body = waf.DefaultBlockAction.Body
|
||||
}
|
||||
if len(this.URL) == 0 {
|
||||
this.URL = waf.DefaultBlockAction.URL
|
||||
}
|
||||
if this.Timeout <= 0 {
|
||||
this.Timeout = waf.DefaultBlockAction.Timeout
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *BlockAction) Code() string {
|
||||
return ActionBlock
|
||||
}
|
||||
|
||||
func (this *BlockAction) IsAttack() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *BlockAction) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
if this.Timeout > 0 {
|
||||
// 加入到黑名单
|
||||
SharedIPBlackLIst.Add(IPTypeAll, request.WAFRemoteIP(), time.Now().Unix()+int64(this.Timeout))
|
||||
}
|
||||
|
||||
if writer != nil {
|
||||
// if status code eq 444, we close the connection
|
||||
if this.StatusCode == 444 {
|
||||
// close the connection
|
||||
defer func() {
|
||||
hijack, ok := writer.(http.Hijacker)
|
||||
if ok {
|
||||
conn, _, _ := hijack.Hijack()
|
||||
@@ -37,7 +73,7 @@ func (this *BlockAction) Perform(waf *WAF, request *requests.Request, writer htt
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// output response
|
||||
if this.StatusCode > 0 {
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -13,27 +16,63 @@ var captchaSalt = stringutil.Rand(32)
|
||||
|
||||
const (
|
||||
CaptchaSeconds = 600 // 10 minutes
|
||||
CaptchaPath = "/WAF/VERIFY/CAPTCHA"
|
||||
)
|
||||
|
||||
type CaptchaAction struct {
|
||||
Life int32 `yaml:"life" json:"life"`
|
||||
Language string `yaml:"language" json:"language"` // 语言,zh-CN, en-US ...
|
||||
AddToWhiteList bool `yaml:"addToWhiteList" json:"addToWhiteList"` // 是否加入到白名单
|
||||
}
|
||||
|
||||
func (this *CaptchaAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
// TEAWEB_CAPTCHA:
|
||||
cookie, err := request.Cookie("TEAWEB_WAF_CAPTCHA")
|
||||
if err == nil && cookie != nil && len(cookie.Value) > 32 {
|
||||
m := cookie.Value[:32]
|
||||
timestamp := cookie.Value[32:]
|
||||
if stringutil.Md5(captchaSalt+timestamp) == m && time.Now().Unix() < types.Int64(timestamp) { // verify md5
|
||||
return true
|
||||
func (this *CaptchaAction) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *CaptchaAction) Code() string {
|
||||
return ActionCaptcha
|
||||
}
|
||||
|
||||
func (this *CaptchaAction) IsAttack() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *CaptchaAction) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
// 是否在白名单中
|
||||
if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) {
|
||||
return true
|
||||
}
|
||||
|
||||
refURL := request.WAFRaw().URL.String()
|
||||
|
||||
// 覆盖配置
|
||||
if strings.HasPrefix(refURL, CaptchaPath) {
|
||||
info := request.WAFRaw().URL.Query().Get("info")
|
||||
if len(info) > 0 {
|
||||
m, err := utils.SimpleDecryptMap(info)
|
||||
if err == nil && m != nil {
|
||||
refURL = m.GetString("url")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
refURL := request.URL.String()
|
||||
if len(request.Referer()) > 0 {
|
||||
refURL = request.Referer()
|
||||
var captchaConfig = maps.Map{
|
||||
"action": this,
|
||||
"timestamp": time.Now().Unix(),
|
||||
"url": refURL,
|
||||
"setId": set.Id,
|
||||
}
|
||||
http.Redirect(writer, request.Raw(), "/WAFCAPTCHA?url="+url.QueryEscape(refURL), http.StatusTemporaryRedirect)
|
||||
info, err := utils.SimpleEncryptMap(captchaConfig)
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_CAPTCHA_ACTION", "encode captcha config failed: "+err.Error())
|
||||
return true
|
||||
}
|
||||
|
||||
http.Redirect(writer, request.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
13
internal/waf/action_category.go
Normal file
13
internal/waf/action_category.go
Normal file
@@ -0,0 +1,13 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package waf
|
||||
|
||||
import "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
|
||||
type ActionCategory = string
|
||||
|
||||
const (
|
||||
ActionCategoryAllow ActionCategory = firewallconfigs.HTTPFirewallActionCategoryAllow
|
||||
ActionCategoryBlock ActionCategory = firewallconfigs.HTTPFirewallActionCategoryBlock
|
||||
ActionCategoryVerify ActionCategory = firewallconfigs.HTTPFirewallActionCategoryVerify
|
||||
)
|
||||
10
internal/waf/action_config.go
Normal file
10
internal/waf/action_config.go
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package waf
|
||||
|
||||
import "github.com/iwind/TeaGo/maps"
|
||||
|
||||
type ActionConfig struct {
|
||||
Code string `yaml:"code" json:"code"`
|
||||
Options maps.Map `yaml:"options" json:"options"`
|
||||
}
|
||||
@@ -2,11 +2,12 @@ package waf
|
||||
|
||||
import "reflect"
|
||||
|
||||
// action definition
|
||||
// ActionDefinition action definition
|
||||
type ActionDefinition struct {
|
||||
Name string
|
||||
Code ActionString
|
||||
Description string
|
||||
Category string // category: block, verify, allow
|
||||
Instance ActionInterface
|
||||
Type reflect.Type
|
||||
}
|
||||
|
||||
73
internal/waf/action_get_302.go
Normal file
73
internal/waf/action_get_302.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
Get302Path = "/WAF/VERIFY/GET"
|
||||
)
|
||||
|
||||
// Get302Action
|
||||
// 原理: origin url --> 302 verify url --> origin url
|
||||
// TODO 将来支持meta refresh验证
|
||||
type Get302Action struct {
|
||||
BaseAction
|
||||
|
||||
Life int32 `yaml:"life" json:"life"`
|
||||
}
|
||||
|
||||
func (this *Get302Action) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Get302Action) Code() string {
|
||||
return ActionGet302
|
||||
}
|
||||
|
||||
func (this *Get302Action) IsAttack() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *Get302Action) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
// 仅限于Get
|
||||
if request.WAFRaw().Method != http.MethodGet {
|
||||
return true
|
||||
}
|
||||
|
||||
// 是否已经在白名单中
|
||||
if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) {
|
||||
return true
|
||||
}
|
||||
|
||||
var m = maps.Map{
|
||||
"url": request.WAFRaw().URL.String(),
|
||||
"timestamp": time.Now().Unix(),
|
||||
"life": this.Life,
|
||||
"setId": set.Id,
|
||||
}
|
||||
info, err := utils.SimpleEncryptMap(m)
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_GET_302_ACTION", "encode info failed: "+err.Error())
|
||||
return true
|
||||
}
|
||||
|
||||
http.Redirect(writer, request.WAFRaw(), Get302Path+"?info="+url.QueryEscape(info), http.StatusFound)
|
||||
|
||||
// 关闭连接
|
||||
if request.WAFRaw().ProtoMajor == 1 {
|
||||
_ = this.CloseConn(writer)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -10,13 +10,29 @@ type GoGroupAction struct {
|
||||
GroupId string `yaml:"groupId" json:"groupId"`
|
||||
}
|
||||
|
||||
func (this *GoGroupAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
group := waf.FindRuleGroup(this.GroupId)
|
||||
if group == nil || !group.IsOn {
|
||||
func (this *GoGroupAction) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *GoGroupAction) Code() string {
|
||||
return ActionGoGroup
|
||||
}
|
||||
|
||||
func (this *GoGroupAction) IsAttack() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *GoGroupAction) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *GoGroupAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
nextGroup := waf.FindRuleGroup(this.GroupId)
|
||||
if nextGroup == nil || !nextGroup.IsOn {
|
||||
return true
|
||||
}
|
||||
|
||||
b, set, err := group.MatchRequest(request)
|
||||
b, nextSet, err := nextGroup.MatchRequest(request)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
return true
|
||||
@@ -26,9 +42,5 @@ func (this *GoGroupAction) Perform(waf *WAF, request *requests.Request, writer h
|
||||
return true
|
||||
}
|
||||
|
||||
actionObject := FindActionInstance(set.Action, set.ActionOptions)
|
||||
if actionObject == nil {
|
||||
return true
|
||||
}
|
||||
return actionObject.Perform(waf, request, writer)
|
||||
return nextSet.PerformActions(waf, nextGroup, request, writer)
|
||||
}
|
||||
|
||||
@@ -11,17 +11,33 @@ type GoSetAction struct {
|
||||
SetId string `yaml:"setId" json:"setId"`
|
||||
}
|
||||
|
||||
func (this *GoSetAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
group := waf.FindRuleGroup(this.GroupId)
|
||||
if group == nil || !group.IsOn {
|
||||
func (this *GoSetAction) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *GoSetAction) Code() string {
|
||||
return ActionGoSet
|
||||
}
|
||||
|
||||
func (this *GoSetAction) IsAttack() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *GoSetAction) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *GoSetAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
nextGroup := waf.FindRuleGroup(this.GroupId)
|
||||
if nextGroup == nil || !nextGroup.IsOn {
|
||||
return true
|
||||
}
|
||||
set := group.FindRuleSet(this.SetId)
|
||||
if set == nil || !set.IsOn {
|
||||
nextSet := nextGroup.FindRuleSet(this.SetId)
|
||||
if nextSet == nil || !nextSet.IsOn {
|
||||
return true
|
||||
}
|
||||
|
||||
b, err := set.MatchRequest(request)
|
||||
b, err := nextSet.MatchRequest(request)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
return true
|
||||
@@ -29,9 +45,5 @@ func (this *GoSetAction) Perform(waf *WAF, request *requests.Request, writer htt
|
||||
if !b {
|
||||
return true
|
||||
}
|
||||
actionObject := FindActionInstance(set.Action, set.ActionOptions)
|
||||
if actionObject == nil {
|
||||
return true
|
||||
}
|
||||
return actionObject.Perform(waf, request, writer)
|
||||
return nextSet.PerformActions(waf, nextGroup, request, writer)
|
||||
}
|
||||
|
||||
25
internal/waf/action_interface.go
Normal file
25
internal/waf/action_interface.go
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ActionInterface interface {
|
||||
// Init 初始化
|
||||
Init(waf *WAF) error
|
||||
|
||||
// Code 代号
|
||||
Code() string
|
||||
|
||||
// IsAttack 是否为拦截攻击动作
|
||||
IsAttack() bool
|
||||
|
||||
// WillChange determine if the action will change the request
|
||||
WillChange() bool
|
||||
|
||||
// Perform perform the action
|
||||
Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool)
|
||||
}
|
||||
@@ -8,6 +8,22 @@ import (
|
||||
type LogAction struct {
|
||||
}
|
||||
|
||||
func (this *LogAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
func (this *LogAction) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *LogAction) Code() string {
|
||||
return ActionLog
|
||||
}
|
||||
|
||||
func (this *LogAction) IsAttack() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *LogAction) WillChange() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *LogAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
86
internal/waf/action_notify.go
Normal file
86
internal/waf/action_notify.go
Normal file
@@ -0,0 +1,86 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type notifyTask struct {
|
||||
ServerId int64
|
||||
HttpFirewallPolicyId int64
|
||||
HttpFirewallRuleGroupId int64
|
||||
HttpFirewallRuleSetId int64
|
||||
CreatedAt int64
|
||||
}
|
||||
|
||||
var notifyChan = make(chan *notifyTask, 128)
|
||||
|
||||
func init() {
|
||||
events.On(events.EventLoaded, func() {
|
||||
go func() {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_NOTIFY_ACTION", "create rpc client failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
for task := range notifyChan {
|
||||
_, err = rpcClient.FirewallService().NotifyHTTPFirewallEvent(rpcClient.Context(), &pb.NotifyHTTPFirewallEventRequest{
|
||||
ServerId: task.ServerId,
|
||||
HttpFirewallPolicyId: task.HttpFirewallPolicyId,
|
||||
HttpFirewallRuleGroupId: task.HttpFirewallRuleGroupId,
|
||||
HttpFirewallRuleSetId: task.HttpFirewallRuleSetId,
|
||||
CreatedAt: task.CreatedAt,
|
||||
})
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_NOTIFY_ACTION", "notify failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
type NotifyAction struct {
|
||||
}
|
||||
|
||||
func (this *NotifyAction) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *NotifyAction) Code() string {
|
||||
return ActionNotify
|
||||
}
|
||||
|
||||
func (this *NotifyAction) IsAttack() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// WillChange determine if the action will change the request
|
||||
func (this *NotifyAction) WillChange() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Perform perform the action
|
||||
func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
select {
|
||||
case notifyChan <- ¬ifyTask{
|
||||
ServerId: request.WAFServerId(),
|
||||
HttpFirewallPolicyId: types.Int64(waf.Id),
|
||||
HttpFirewallRuleGroupId: types.Int64(group.Id),
|
||||
HttpFirewallRuleSetId: types.Int64(set.Id),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
}:
|
||||
default:
|
||||
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
89
internal/waf/action_post_307.go
Normal file
89
internal/waf/action_post_307.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Post307Action struct {
|
||||
Life int32 `yaml:"life" json:"life"`
|
||||
|
||||
BaseAction
|
||||
}
|
||||
|
||||
func (this *Post307Action) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Post307Action) Code() string {
|
||||
return ActionPost307
|
||||
}
|
||||
|
||||
func (this *Post307Action) IsAttack() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *Post307Action) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
var cookieName = "WAF_VALIDATOR_ID"
|
||||
|
||||
// 仅限于POST
|
||||
if request.WAFRaw().Method != http.MethodPost {
|
||||
return true
|
||||
}
|
||||
|
||||
// 是否已经在白名单中
|
||||
if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) {
|
||||
return true
|
||||
}
|
||||
|
||||
// 判断是否有Cookie
|
||||
cookie, err := request.WAFRaw().Cookie(cookieName)
|
||||
if err == nil && cookie != nil {
|
||||
m, err := utils.SimpleDecryptMap(cookie.Value)
|
||||
if err == nil && m.GetString("remoteIP") == request.WAFRemoteIP() && time.Now().Unix() < m.GetInt64("timestamp")+10 {
|
||||
var life = m.GetInt64("life")
|
||||
if life <= 0 {
|
||||
life = 600 // 默认10分钟
|
||||
}
|
||||
var setId = m.GetString("setId")
|
||||
SharedIPWhiteList.Add("set:"+setId, request.WAFRemoteIP(), time.Now().Unix()+life)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
var m = maps.Map{
|
||||
"timestamp": time.Now().Unix(),
|
||||
"life": this.Life,
|
||||
"setId": set.Id,
|
||||
"remoteIP": request.WAFRemoteIP(),
|
||||
}
|
||||
info, err := utils.SimpleEncryptMap(m)
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_POST_302_ACTION", "encode info failed: "+err.Error())
|
||||
return true
|
||||
}
|
||||
|
||||
// 设置Cookie
|
||||
http.SetCookie(writer, &http.Cookie{
|
||||
Name: cookieName,
|
||||
Path: "/",
|
||||
MaxAge: 10,
|
||||
Value: info,
|
||||
})
|
||||
|
||||
http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusTemporaryRedirect)
|
||||
|
||||
if request.WAFRaw().ProtoMajor == 1 {
|
||||
_ = this.CloseConn(writer)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
120
internal/waf/action_record_ip.go
Normal file
120
internal/waf/action_record_ip.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type recordIPTask struct {
|
||||
ip string
|
||||
listId int64
|
||||
expiredAt int64
|
||||
level string
|
||||
}
|
||||
|
||||
var recordIPTaskChan = make(chan *recordIPTask, 1024)
|
||||
|
||||
func init() {
|
||||
events.On(events.EventLoaded, func() {
|
||||
go func() {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_RECORD_IP_ACTION", "create rpc client failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
for task := range recordIPTaskChan {
|
||||
ipType := "ipv4"
|
||||
if strings.Contains(task.ip, ":") {
|
||||
ipType = "ipv6"
|
||||
}
|
||||
_, err = rpcClient.IPItemRPC().CreateIPItem(rpcClient.Context(), &pb.CreateIPItemRequest{
|
||||
IpListId: task.listId,
|
||||
IpFrom: task.ip,
|
||||
IpTo: "",
|
||||
ExpiredAt: task.expiredAt,
|
||||
Reason: "触发WAF规则自动加入",
|
||||
Type: ipType,
|
||||
EventLevel: task.level,
|
||||
})
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
type RecordIPAction struct {
|
||||
BaseAction
|
||||
|
||||
Type string `yaml:"type" json:"type"`
|
||||
IPListId int64 `yaml:"ipListId" json:"ipListId"`
|
||||
Level string `yaml:"level" json:"level"`
|
||||
Timeout int32 `yaml:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
func (this *RecordIPAction) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *RecordIPAction) Code() string {
|
||||
return ActionRecordIP
|
||||
}
|
||||
|
||||
func (this *RecordIPAction) IsAttack() bool {
|
||||
return this.Type == "black"
|
||||
}
|
||||
|
||||
func (this *RecordIPAction) WillChange() bool {
|
||||
return this.Type == "black"
|
||||
}
|
||||
|
||||
func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
// 是否在本地白名单中
|
||||
if SharedIPWhiteList.Contains("set:"+set.Id, set.Id) {
|
||||
return true
|
||||
}
|
||||
|
||||
// 先加入本地的黑名单
|
||||
timeout := this.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = 86400 // 1天
|
||||
}
|
||||
expiredAt := time.Now().Unix() + int64(timeout)
|
||||
|
||||
if this.Type == "black" {
|
||||
_ = this.CloseConn(writer)
|
||||
|
||||
SharedIPBlackLIst.Add(IPTypeAll, request.WAFRemoteIP(), expiredAt)
|
||||
} else {
|
||||
// 加入本地白名单
|
||||
timeout := this.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = 86400 // 1天
|
||||
}
|
||||
SharedIPWhiteList.Add("set:"+set.Id, request.WAFRemoteIP(), expiredAt)
|
||||
}
|
||||
|
||||
// 上报
|
||||
if this.IPListId > 0 {
|
||||
select {
|
||||
case recordIPTaskChan <- &recordIPTask{
|
||||
ip: request.WAFRemoteIP(),
|
||||
listId: this.IPListId,
|
||||
expiredAt: expiredAt,
|
||||
level: this.Level,
|
||||
}:
|
||||
default:
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return this.Type != "black"
|
||||
}
|
||||
30
internal/waf/action_tag.go
Normal file
30
internal/waf/action_tag.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type TagAction struct {
|
||||
Tags []string `yaml:"tags" json:"tags"`
|
||||
}
|
||||
|
||||
func (this *TagAction) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *TagAction) Code() string {
|
||||
return ActionTag
|
||||
}
|
||||
|
||||
func (this *TagAction) IsAttack() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *TagAction) WillChange() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *TagAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
return true
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ActionString = string
|
||||
|
||||
const (
|
||||
ActionLog = "log" // allow and log
|
||||
ActionBlock = "block" // block
|
||||
ActionCaptcha = "captcha" // block and show captcha
|
||||
ActionAllow = "allow" // allow
|
||||
ActionGoGroup = "go_group" // go to next rule group
|
||||
ActionGoSet = "go_set" // go to next rule set
|
||||
)
|
||||
|
||||
type ActionInterface interface {
|
||||
Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool)
|
||||
}
|
||||
88
internal/waf/action_types.go
Normal file
88
internal/waf/action_types.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package waf
|
||||
|
||||
import "reflect"
|
||||
|
||||
type ActionString = string
|
||||
|
||||
const (
|
||||
ActionLog ActionString = "log" // allow and log
|
||||
ActionBlock ActionString = "block" // block
|
||||
ActionCaptcha ActionString = "captcha" // block and show captcha
|
||||
ActionNotify ActionString = "notify" // 告警
|
||||
ActionGet302 ActionString = "get_302" // 针对GET的302重定向认证
|
||||
ActionPost307 ActionString = "post_307" // 针对POST的307重定向认证
|
||||
ActionRecordIP ActionString = "record_ip" // 记录IP
|
||||
ActionTag ActionString = "tag" // 标签
|
||||
ActionAllow ActionString = "allow" // allow
|
||||
ActionGoGroup ActionString = "go_group" // go to next rule group
|
||||
ActionGoSet ActionString = "go_set" // go to next rule set
|
||||
)
|
||||
|
||||
var AllActions = []*ActionDefinition{
|
||||
{
|
||||
Name: "阻止",
|
||||
Code: ActionBlock,
|
||||
Instance: new(BlockAction),
|
||||
Type: reflect.TypeOf(new(BlockAction)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "允许通过",
|
||||
Code: ActionAllow,
|
||||
Instance: new(AllowAction),
|
||||
Type: reflect.TypeOf(new(AllowAction)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "允许并记录日志",
|
||||
Code: ActionLog,
|
||||
Instance: new(LogAction),
|
||||
Type: reflect.TypeOf(new(LogAction)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "Captcha验证码",
|
||||
Code: ActionCaptcha,
|
||||
Instance: new(CaptchaAction),
|
||||
Type: reflect.TypeOf(new(CaptchaAction)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "告警",
|
||||
Code: ActionNotify,
|
||||
Instance: new(NotifyAction),
|
||||
Type: reflect.TypeOf(new(NotifyAction)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "GET 302",
|
||||
Code: ActionGet302,
|
||||
Instance: new(Get302Action),
|
||||
Type: reflect.TypeOf(new(Get302Action)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "POST 307",
|
||||
Code: ActionPost307,
|
||||
Instance: new(Post307Action),
|
||||
Type: reflect.TypeOf(new(Post307Action)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "记录IP",
|
||||
Code: ActionRecordIP,
|
||||
Instance: new(RecordIPAction),
|
||||
Type: reflect.TypeOf(new(RecordIPAction)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "标签",
|
||||
Code: ActionTag,
|
||||
Instance: new(TagAction),
|
||||
Type: reflect.TypeOf(new(TagAction)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "跳到下一个规则分组",
|
||||
Code: ActionGoGroup,
|
||||
Instance: new(GoGroupAction),
|
||||
Type: reflect.TypeOf(new(GoGroupAction)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "跳到下一个规则集",
|
||||
Code: ActionGoSet,
|
||||
Instance: new(GoSetAction),
|
||||
Type: reflect.TypeOf(new(GoSetAction)).Elem(),
|
||||
},
|
||||
}
|
||||
@@ -1,45 +1,12 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
var AllActions = []*ActionDefinition{
|
||||
{
|
||||
Name: "阻止",
|
||||
Code: ActionBlock,
|
||||
Instance: new(BlockAction),
|
||||
},
|
||||
{
|
||||
Name: "允许通过",
|
||||
Code: ActionAllow,
|
||||
Instance: new(AllowAction),
|
||||
},
|
||||
{
|
||||
Name: "允许并记录日志",
|
||||
Code: ActionLog,
|
||||
Instance: new(LogAction),
|
||||
},
|
||||
{
|
||||
Name: "Captcha验证码",
|
||||
Code: ActionCaptcha,
|
||||
Instance: new(CaptchaAction),
|
||||
},
|
||||
{
|
||||
Name: "跳到下一个规则分组",
|
||||
Code: ActionGoGroup,
|
||||
Instance: new(GoGroupAction),
|
||||
Type: reflect.TypeOf(new(GoGroupAction)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "跳到下一个规则集",
|
||||
Code: ActionGoSet,
|
||||
Instance: new(GoSetAction),
|
||||
Type: reflect.TypeOf(new(GoSetAction)).Elem(),
|
||||
},
|
||||
}
|
||||
|
||||
func FindActionInstance(action ActionString, options maps.Map) ActionInterface {
|
||||
for _, def := range AllActions {
|
||||
if def.Code == action {
|
||||
@@ -49,15 +16,13 @@ func FindActionInstance(action ActionString, options maps.Map) ActionInterface {
|
||||
instance := ptrValue.Interface().(ActionInterface)
|
||||
|
||||
if len(options) > 0 {
|
||||
count := def.Type.NumField()
|
||||
for i := 0; i < count; i++ {
|
||||
field := def.Type.Field(i)
|
||||
tag, ok := field.Tag.Lookup("yaml")
|
||||
if ok {
|
||||
v, ok := options[tag]
|
||||
if ok && reflect.TypeOf(v) == field.Type {
|
||||
ptrValue.Elem().FieldByName(field.Name).Set(reflect.ValueOf(v))
|
||||
}
|
||||
optionsJSON, err := json.Marshal(options)
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_FindActionInstance", "encode options to json failed: "+err.Error())
|
||||
} else {
|
||||
err = json.Unmarshal(optionsJSON, instance)
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_FindActionInstance", "decode options from json failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package waf
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"runtime"
|
||||
"testing"
|
||||
@@ -16,11 +17,20 @@ func TestFindActionInstance(t *testing.T) {
|
||||
t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil))
|
||||
t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
|
||||
t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
|
||||
t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b",}))
|
||||
t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b"}))
|
||||
|
||||
a.IsTrue(FindActionInstance(ActionGoSet, nil) != FindActionInstance(ActionGoSet, nil))
|
||||
}
|
||||
|
||||
func TestFindActionInstance_Options(t *testing.T) {
|
||||
//t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{}))
|
||||
//t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{}))
|
||||
//logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{}), t)
|
||||
logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{
|
||||
"timeout": 3600,
|
||||
}), t)
|
||||
}
|
||||
|
||||
func BenchmarkFindActionInstance(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
||||
@@ -3,29 +3,64 @@ package waf
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/jsonutils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/dchest/captcha"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var captchaValidator = &CaptchaValidator{}
|
||||
var captchaValidator = NewCaptchaValidator()
|
||||
|
||||
type CaptchaValidator struct {
|
||||
}
|
||||
|
||||
func (this *CaptchaValidator) Run(request *requests.Request, writer http.ResponseWriter) {
|
||||
if request.Method == http.MethodPost && len(request.FormValue("TEAWEB_WAF_CAPTCHA_ID")) > 0 {
|
||||
this.validate(request, writer)
|
||||
func NewCaptchaValidator() *CaptchaValidator {
|
||||
return &CaptchaValidator{}
|
||||
}
|
||||
|
||||
func (this *CaptchaValidator) Run(request requests.Request, writer http.ResponseWriter) {
|
||||
var info = request.WAFRaw().URL.Query().Get("info")
|
||||
if len(info) == 0 {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = writer.Write([]byte("invalid request"))
|
||||
return
|
||||
}
|
||||
m, err := utils.SimpleDecryptMap(info)
|
||||
if err != nil {
|
||||
_, _ = writer.Write([]byte("invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
timestamp := m.GetInt64("timestamp")
|
||||
if timestamp < time.Now().Unix()-600 { // 10分钟之后信息过期
|
||||
http.Redirect(writer, request.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
var actionConfig = &CaptchaAction{}
|
||||
err = jsonutils.MapToObject(m.GetMap("action"), actionConfig)
|
||||
if err != nil {
|
||||
http.Redirect(writer, request.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
var setId = m.GetInt64("setId")
|
||||
var originURL = m.GetString("url")
|
||||
|
||||
if request.WAFRaw().Method == http.MethodPost && len(request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")) > 0 {
|
||||
this.validate(actionConfig, setId, originURL, request, writer)
|
||||
} else {
|
||||
this.show(request, writer)
|
||||
this.show(actionConfig, request, writer)
|
||||
}
|
||||
}
|
||||
|
||||
func (this *CaptchaValidator) show(request *requests.Request, writer http.ResponseWriter) {
|
||||
func (this *CaptchaValidator) show(actionConfig *CaptchaAction, request requests.Request, writer http.ResponseWriter) {
|
||||
// show captcha
|
||||
captchaId := captcha.NewLen(6)
|
||||
buf := bytes.NewBuffer([]byte{})
|
||||
@@ -35,48 +70,86 @@ func (this *CaptchaValidator) show(request *requests.Request, writer http.Respon
|
||||
return
|
||||
}
|
||||
|
||||
var lang = actionConfig.Language
|
||||
if len(lang) == 0 {
|
||||
acceptLanguage := request.WAFRaw().Header.Get("Accept-Language")
|
||||
if len(acceptLanguage) > 0 {
|
||||
langIndex := strings.Index(acceptLanguage, ",")
|
||||
if langIndex > 0 {
|
||||
lang = acceptLanguage[:langIndex]
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(lang) == 0 {
|
||||
lang = "en-US"
|
||||
}
|
||||
|
||||
var msgTitle = ""
|
||||
var msgPrompt = ""
|
||||
var msgButtonTitle = ""
|
||||
|
||||
switch lang {
|
||||
case "en-US":
|
||||
msgTitle = "Verify Yourself"
|
||||
msgPrompt = "Input verify code above:"
|
||||
msgButtonTitle = "Verify Yourself"
|
||||
case "zh-CN":
|
||||
msgTitle = "身份验证"
|
||||
msgPrompt = "请输入上面的验证码"
|
||||
msgButtonTitle = "提交验证"
|
||||
default:
|
||||
msgTitle = "Verify Yourself"
|
||||
msgPrompt = "Input verify code above:"
|
||||
msgButtonTitle = "Verify Yourself"
|
||||
}
|
||||
|
||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
_, _ = writer.Write([]byte(`<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Verify Yourself</title>
|
||||
<title>` + msgTitle + `</title>
|
||||
<script type="text/javascript">
|
||||
if (window.addEventListener != null) {
|
||||
window.addEventListener("load", function () {
|
||||
document.getElementById("GOEDGE_WAF_CAPTCHA_CODE").focus()
|
||||
})
|
||||
}
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<form method="POST">
|
||||
<input type="hidden" name="TEAWEB_WAF_CAPTCHA_ID" value="` + captchaId + `"/>
|
||||
<input type="hidden" name="GOEDGE_WAF_CAPTCHA_ID" value="` + captchaId + `"/>
|
||||
<img src="data:image/png;base64, ` + base64.StdEncoding.EncodeToString(buf.Bytes()) + `"/>` + `
|
||||
<div>
|
||||
<p>Input verify code above:</p>
|
||||
<input type="text" name="TEAWEB_WAF_CAPTCHA_CODE" maxlength="6" size="18" autocomplete="off" z-index="1" style="font-size:16px;line-height:24px; letter-spacing: 15px; padding-left: 4px"/>
|
||||
<p>` + msgPrompt + `</p>
|
||||
<input type="text" name="GOEDGE_WAF_CAPTCHA_CODE" id="GOEDGE_WAF_CAPTCHA_CODE" maxlength="6" autocomplete="off" z-index="1" style="font-size:16px;line-height:24px; letter-spacing: 15px; padding-left: 4px; width: 160px"/>
|
||||
</div>
|
||||
<div>
|
||||
<button type="submit" onclick="window.location = '/webhook'" style="line-height:24px;margin-top:10px">Verify Yourself</button>
|
||||
<button type="submit" style="line-height:24px;margin-top:10px">` + msgButtonTitle + `</button>
|
||||
</div>
|
||||
</form>
|
||||
</body>
|
||||
</html>`))
|
||||
}
|
||||
|
||||
func (this *CaptchaValidator) validate(request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
captchaId := request.FormValue("TEAWEB_WAF_CAPTCHA_ID")
|
||||
func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, setId int64, originURL string, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
captchaId := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")
|
||||
if len(captchaId) > 0 {
|
||||
captchaCode := request.FormValue("TEAWEB_WAF_CAPTCHA_CODE")
|
||||
captchaCode := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_CODE")
|
||||
if captcha.VerifyString(captchaId, captchaCode) {
|
||||
// set cookie
|
||||
timestamp := fmt.Sprintf("%d", time.Now().Unix()+CaptchaSeconds)
|
||||
m := stringutil.Md5(captchaSalt + timestamp)
|
||||
http.SetCookie(writer, &http.Cookie{
|
||||
Name: "TEAWEB_WAF_CAPTCHA",
|
||||
Value: m + timestamp,
|
||||
MaxAge: CaptchaSeconds, // TODO 这个时间可以设置
|
||||
Path: "/", // all of dirs
|
||||
})
|
||||
var life = CaptchaSeconds
|
||||
if actionConfig.Life > 0 {
|
||||
life = types.Int(actionConfig.Life)
|
||||
}
|
||||
|
||||
rawURL := request.URL.Query().Get("url")
|
||||
http.Redirect(writer, request.Raw(), rawURL, http.StatusSeeOther)
|
||||
// 加入到白名单
|
||||
SharedIPWhiteList.Add("set:"+strconv.FormatInt(setId, 10), request.WAFRemoteIP(), time.Now().Unix()+int64(life)) // TODO
|
||||
|
||||
http.Redirect(writer, request.WAFRaw(), originURL, http.StatusSeeOther)
|
||||
|
||||
return false
|
||||
} else {
|
||||
http.Redirect(writer, request.Raw(), request.URL.String(), http.StatusSeeOther)
|
||||
http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusSeeOther)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,14 +5,12 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ${cc.arg}
|
||||
// CCCheckpoint ${cc.arg}
|
||||
// TODO implement more traffic rules
|
||||
type CCCheckpoint struct {
|
||||
Checkpoint
|
||||
@@ -32,7 +30,7 @@ func (this *CCCheckpoint) Start() {
|
||||
this.cache = ttlcache.NewCache()
|
||||
}
|
||||
|
||||
func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *CCCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = 0
|
||||
|
||||
if this.cache == nil {
|
||||
@@ -66,12 +64,12 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
|
||||
var key = ""
|
||||
switch userType {
|
||||
case "ip":
|
||||
key = this.ip(req)
|
||||
key = req.WAFRemoteIP()
|
||||
case "cookie":
|
||||
if len(userField) == 0 {
|
||||
key = this.ip(req)
|
||||
key = req.WAFRemoteIP()
|
||||
} else {
|
||||
cookie, _ := req.Cookie(userField)
|
||||
cookie, _ := req.WAFRaw().Cookie(userField)
|
||||
if cookie != nil {
|
||||
v := cookie.Value
|
||||
if userIndex > 0 && len(v) > userIndex {
|
||||
@@ -82,9 +80,9 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
|
||||
}
|
||||
case "get":
|
||||
if len(userField) == 0 {
|
||||
key = this.ip(req)
|
||||
key = req.WAFRemoteIP()
|
||||
} else {
|
||||
v := req.URL.Query().Get(userField)
|
||||
v := req.WAFRaw().URL.Query().Get(userField)
|
||||
if userIndex > 0 && len(v) > userIndex {
|
||||
v = v[userIndex:]
|
||||
}
|
||||
@@ -92,9 +90,9 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
|
||||
}
|
||||
case "post":
|
||||
if len(userField) == 0 {
|
||||
key = this.ip(req)
|
||||
key = req.WAFRemoteIP()
|
||||
} else {
|
||||
v := req.PostFormValue(userField)
|
||||
v := req.WAFRaw().PostFormValue(userField)
|
||||
if userIndex > 0 && len(v) > userIndex {
|
||||
v = v[userIndex:]
|
||||
}
|
||||
@@ -102,19 +100,19 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
|
||||
}
|
||||
case "header":
|
||||
if len(userField) == 0 {
|
||||
key = this.ip(req)
|
||||
key = req.WAFRemoteIP()
|
||||
} else {
|
||||
v := req.Header.Get(userField)
|
||||
v := req.WAFRaw().Header.Get(userField)
|
||||
if userIndex > 0 && len(v) > userIndex {
|
||||
v = v[userIndex:]
|
||||
}
|
||||
key = "USER@" + userType + "@" + userField + "@" + v
|
||||
}
|
||||
default:
|
||||
key = this.ip(req)
|
||||
key = req.WAFRemoteIP()
|
||||
}
|
||||
if len(key) == 0 {
|
||||
key = this.ip(req)
|
||||
key = req.WAFRemoteIP()
|
||||
}
|
||||
value = this.cache.IncreaseInt64(key, int64(1), time.Now().Unix()+period)
|
||||
}
|
||||
@@ -122,7 +120,7 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
|
||||
return
|
||||
}
|
||||
|
||||
func (this *CCCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *CCCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
@@ -210,38 +208,3 @@ func (this *CCCheckpoint) Stop() {
|
||||
this.cache = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (this *CCCheckpoint) ip(req *requests.Request) string {
|
||||
// X-Forwarded-For
|
||||
forwardedFor := req.Header.Get("X-Forwarded-For")
|
||||
if len(forwardedFor) > 0 {
|
||||
commaIndex := strings.Index(forwardedFor, ",")
|
||||
if commaIndex > 0 {
|
||||
return forwardedFor[:commaIndex]
|
||||
}
|
||||
return forwardedFor
|
||||
}
|
||||
|
||||
// Real-IP
|
||||
{
|
||||
realIP, ok := req.Header["X-Real-IP"]
|
||||
if ok && len(realIP) > 0 {
|
||||
return realIP[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Real-Ip
|
||||
{
|
||||
realIP, ok := req.Header["X-Real-Ip"]
|
||||
if ok && len(realIP) > 0 {
|
||||
return realIP[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Remote-Addr
|
||||
host, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err == nil {
|
||||
return host
|
||||
}
|
||||
return req.RemoteAddr
|
||||
}
|
||||
|
||||
48
internal/waf/checkpoints/cc2.go
Normal file
48
internal/waf/checkpoints/cc2.go
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package checkpoints
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ccCache = ttlcache.NewCache(ttlcache.NewPiecesOption(32))
|
||||
|
||||
// CC2Checkpoint 新的CC
|
||||
type CC2Checkpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *CC2Checkpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
var keys = options.GetSlice("keys")
|
||||
var keyValues = []string{}
|
||||
for _, key := range keys {
|
||||
keyValues = append(keyValues, req.Format(types.String(key)))
|
||||
}
|
||||
if len(keyValues) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var period = options.GetInt64("period")
|
||||
if period <= 0 {
|
||||
period = 60
|
||||
}
|
||||
|
||||
var threshold = options.GetInt64("threshold")
|
||||
if threshold <= 0 {
|
||||
threshold = 1000
|
||||
}
|
||||
|
||||
value = ccCache.IncreaseInt64("WAF-CC-"+strings.Join(keyValues, "@"), 1, time.Now().Unix()+period)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (this *CC2Checkpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
return
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package checkpoints
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
@@ -12,31 +13,31 @@ func TestCCCheckpoint_RequestValue(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := requests.NewRequest(raw)
|
||||
req.RemoteAddr = "127.0.0.1"
|
||||
req := requests.NewTestRequest(raw)
|
||||
req.WAFRaw().RemoteAddr = "127.0.0.1"
|
||||
|
||||
checkpoint := new(CCCheckpoint)
|
||||
checkpoint.Init()
|
||||
checkpoint.Start()
|
||||
|
||||
options := map[string]string{
|
||||
options := maps.Map{
|
||||
"period": "5",
|
||||
}
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
|
||||
req.RemoteAddr = "127.0.0.2"
|
||||
req.WAFRaw().RemoteAddr = "127.0.0.2"
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
|
||||
req.RemoteAddr = "127.0.0.1"
|
||||
req.WAFRaw().RemoteAddr = "127.0.0.1"
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
|
||||
req.RemoteAddr = "127.0.0.2"
|
||||
req.WAFRaw().RemoteAddr = "127.0.0.2"
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
|
||||
req.RemoteAddr = "127.0.0.2"
|
||||
req.WAFRaw().RemoteAddr = "127.0.0.2"
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
|
||||
req.RemoteAddr = "127.0.0.2"
|
||||
req.WAFRaw().RemoteAddr = "127.0.0.2"
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
}
|
||||
|
||||
@@ -5,32 +5,32 @@ import (
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
)
|
||||
|
||||
// Check Point
|
||||
// CheckpointInterface Check Point
|
||||
type CheckpointInterface interface {
|
||||
// initialize
|
||||
// Init initialize
|
||||
Init()
|
||||
|
||||
// is request?
|
||||
// IsRequest is request?
|
||||
IsRequest() bool
|
||||
|
||||
// is composed?
|
||||
// IsComposed is composed?
|
||||
IsComposed() bool
|
||||
|
||||
// get request value
|
||||
RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
|
||||
// RequestValue get request value
|
||||
RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
|
||||
|
||||
// get response value
|
||||
ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
|
||||
// ResponseValue get response value
|
||||
ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
|
||||
|
||||
// param option list
|
||||
// ParamOptions param option list
|
||||
ParamOptions() *ParamOptions
|
||||
|
||||
// options
|
||||
// Options options
|
||||
Options() []OptionInterface
|
||||
|
||||
// start
|
||||
// Start start
|
||||
Start()
|
||||
|
||||
// stop
|
||||
// Stop stop
|
||||
Stop()
|
||||
}
|
||||
|
||||
@@ -5,32 +5,34 @@ import (
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
)
|
||||
|
||||
// ${requestAll}
|
||||
// RequestAllCheckpoint ${requestAll}
|
||||
type RequestAllCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestAllCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
valueBytes := []byte{}
|
||||
if len(req.RequestURI) > 0 {
|
||||
valueBytes = append(valueBytes, req.RequestURI...)
|
||||
} else if req.URL != nil {
|
||||
valueBytes = append(valueBytes, req.URL.RequestURI()...)
|
||||
if len(req.WAFRaw().RequestURI) > 0 {
|
||||
valueBytes = append(valueBytes, req.WAFRaw().RequestURI...)
|
||||
} else if req.WAFRaw().URL != nil {
|
||||
valueBytes = append(valueBytes, req.WAFRaw().URL.RequestURI()...)
|
||||
}
|
||||
|
||||
if req.Body != nil {
|
||||
if req.WAFRaw().Body != nil {
|
||||
valueBytes = append(valueBytes, ' ')
|
||||
|
||||
if len(req.BodyData) == 0 {
|
||||
data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
|
||||
var bodyData = req.WAFGetCacheBody()
|
||||
if len(bodyData) == 0 {
|
||||
data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
|
||||
if err != nil {
|
||||
return "", err, nil
|
||||
}
|
||||
|
||||
req.BodyData = data
|
||||
req.RestoreBody(data)
|
||||
bodyData = data
|
||||
req.WAFSetCacheBody(data)
|
||||
req.WAFRestoreBody(data)
|
||||
}
|
||||
valueBytes = append(valueBytes, req.BodyData...)
|
||||
valueBytes = append(valueBytes, bodyData...)
|
||||
}
|
||||
|
||||
value = valueBytes
|
||||
@@ -38,7 +40,7 @@ func (this *RequestAllCheckpoint) RequestValue(req *requests.Request, param stri
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestAllCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestAllCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = ""
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
|
||||
@@ -18,7 +18,7 @@ func TestRequestAllCheckpoint_RequestValue(t *testing.T) {
|
||||
}
|
||||
|
||||
checkpoint := new(RequestAllCheckpoint)
|
||||
v, sysErr, userErr := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
|
||||
v, sysErr, userErr := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
|
||||
if sysErr != nil {
|
||||
t.Fatal(sysErr)
|
||||
}
|
||||
@@ -42,7 +42,7 @@ func TestRequestAllCheckpoint_RequestValue_Max(t *testing.T) {
|
||||
}
|
||||
|
||||
checkpoint := new(RequestBodyCheckpoint)
|
||||
value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
|
||||
value, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -65,6 +65,6 @@ func BenchmarkRequestAllCheckpoint_RequestValue(b *testing.B) {
|
||||
|
||||
checkpoint := new(RequestAllCheckpoint)
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = checkpoint.RequestValue(requests.NewRequest(req), "", nil)
|
||||
_, _, _ = checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,11 +9,11 @@ type RequestArgCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestArgCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
return req.URL.Query().Get(param), nil, nil
|
||||
func (this *RequestArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
return req.WAFRaw().URL.Query().Get(param), nil, nil
|
||||
}
|
||||
|
||||
func (this *RequestArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ func TestArgParam_RequestValue(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := requests.NewRequest(rawReq)
|
||||
req := requests.NewTestRequest(rawReq)
|
||||
|
||||
checkpoint := new(RequestArgCheckpoint)
|
||||
t.Log(checkpoint.RequestValue(req, "name", nil))
|
||||
|
||||
@@ -9,12 +9,12 @@ type RequestArgsCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestArgsCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.URL.RawQuery
|
||||
func (this *RequestArgsCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.WAFRaw().URL.RawQuery
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestArgsCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestArgsCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -5,31 +5,33 @@ import (
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
)
|
||||
|
||||
// ${requestBody}
|
||||
// RequestBodyCheckpoint ${requestBody}
|
||||
type RequestBodyCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestBodyCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if req.Body == nil {
|
||||
func (this *RequestBodyCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if req.WAFRaw().Body == nil {
|
||||
value = ""
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.BodyData) == 0 {
|
||||
data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
|
||||
var bodyData = req.WAFGetCacheBody()
|
||||
if len(bodyData) == 0 {
|
||||
data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
|
||||
if err != nil {
|
||||
return "", err, nil
|
||||
}
|
||||
|
||||
req.BodyData = data
|
||||
req.RestoreBody(data)
|
||||
bodyData = data
|
||||
req.WAFSetCacheBody(data)
|
||||
req.WAFRestoreBody(data)
|
||||
}
|
||||
|
||||
return req.BodyData, nil, nil
|
||||
return bodyData, nil, nil
|
||||
}
|
||||
|
||||
func (this *RequestBodyCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestBodyCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -11,19 +11,20 @@ import (
|
||||
)
|
||||
|
||||
func TestRequestBodyCheckpoint_RequestValue(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("123456")))
|
||||
rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("123456")))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var req = requests.NewTestRequest(rawReq)
|
||||
checkpoint := new(RequestBodyCheckpoint)
|
||||
t.Log(checkpoint.RequestValue(requests.NewRequest(req), "", nil))
|
||||
t.Log(checkpoint.RequestValue(req, "", nil))
|
||||
|
||||
body, err := ioutil.ReadAll(req.Body)
|
||||
body, err := ioutil.ReadAll(rawReq.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(string(body))
|
||||
t.Log(string(req.WAFGetCacheBody()))
|
||||
}
|
||||
|
||||
func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) {
|
||||
@@ -33,7 +34,7 @@ func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) {
|
||||
}
|
||||
|
||||
checkpoint := new(RequestBodyCheckpoint)
|
||||
value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
|
||||
value, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -9,12 +9,12 @@ type RequestContentTypeCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestContentTypeCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.Header.Get("Content-Type")
|
||||
func (this *RequestContentTypeCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.WAFRaw().Header.Get("Content-Type")
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestContentTypeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestContentTypeCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ type RequestCookieCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestCookieCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
cookie, err := req.Cookie(param)
|
||||
func (this *RequestCookieCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
cookie, err := req.WAFRaw().Cookie(param)
|
||||
if err != nil {
|
||||
value = ""
|
||||
return
|
||||
@@ -20,7 +20,7 @@ func (this *RequestCookieCheckpoint) RequestValue(req *requests.Request, param s
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestCookieCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestCookieCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -11,16 +11,16 @@ type RequestCookiesCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestCookiesCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestCookiesCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
var cookies = []string{}
|
||||
for _, cookie := range req.Cookies() {
|
||||
for _, cookie := range req.WAFRaw().Cookies() {
|
||||
cookies = append(cookies, url.QueryEscape(cookie.Name)+"="+url.QueryEscape(cookie.Value))
|
||||
}
|
||||
value = strings.Join(cookies, "&")
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestCookiesCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestCookiesCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -6,33 +6,35 @@ import (
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// ${requestForm.arg}
|
||||
// RequestFormArgCheckpoint ${requestForm.arg}
|
||||
type RequestFormArgCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestFormArgCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if req.Body == nil {
|
||||
func (this *RequestFormArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if req.WAFRaw().Body == nil {
|
||||
value = ""
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.BodyData) == 0 {
|
||||
data, err := req.ReadBody(32 * 1024 * 1024) // read 32m bytes
|
||||
var bodyData = req.WAFGetCacheBody()
|
||||
if len(bodyData) == 0 {
|
||||
data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
|
||||
if err != nil {
|
||||
return "", err, nil
|
||||
}
|
||||
|
||||
req.BodyData = data
|
||||
req.RestoreBody(data)
|
||||
bodyData = data
|
||||
req.WAFSetCacheBody(data)
|
||||
req.WAFRestoreBody(data)
|
||||
}
|
||||
|
||||
// TODO improve performance
|
||||
values, _ := url.ParseQuery(string(req.BodyData))
|
||||
values, _ := url.ParseQuery(string(bodyData))
|
||||
return values.Get(param), nil, nil
|
||||
}
|
||||
|
||||
func (this *RequestFormArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestFormArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -15,8 +15,8 @@ func TestRequestFormArgCheckpoint_RequestValue(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := requests.NewRequest(rawReq)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req := requests.NewTestRequest(rawReq)
|
||||
req.WAFRaw().Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
checkpoint := new(RequestFormArgCheckpoint)
|
||||
t.Log(checkpoint.RequestValue(req, "name", nil))
|
||||
@@ -24,7 +24,7 @@ func TestRequestFormArgCheckpoint_RequestValue(t *testing.T) {
|
||||
t.Log(checkpoint.RequestValue(req, "Hello", nil))
|
||||
t.Log(checkpoint.RequestValue(req, "encoded", nil))
|
||||
|
||||
body, err := ioutil.ReadAll(req.Body)
|
||||
body, err := ioutil.ReadAll(req.WAFRaw().Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ func (this *RequestGeneralHeaderLengthCheckpoint) IsComposed() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = false
|
||||
|
||||
headers := options.GetSlice("headers")
|
||||
@@ -25,7 +25,7 @@ func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Req
|
||||
length := options.GetInt("length")
|
||||
|
||||
for _, header := range headers {
|
||||
v := req.Header.Get(types.String(header))
|
||||
v := req.WAFRaw().Header.Get(types.String(header))
|
||||
if len(v) > length {
|
||||
value = true
|
||||
break
|
||||
@@ -35,6 +35,6 @@ func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Req
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestGeneralHeaderLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestGeneralHeaderLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -10,8 +10,8 @@ type RequestHeaderCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestHeaderCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
v, found := req.Header[param]
|
||||
func (this *RequestHeaderCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
v, found := req.WAFRaw().Header[param]
|
||||
if !found {
|
||||
value = ""
|
||||
return
|
||||
@@ -20,7 +20,7 @@ func (this *RequestHeaderCheckpoint) RequestValue(req *requests.Request, param s
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestHeaderCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestHeaderCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -11,9 +11,9 @@ type RequestHeadersCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestHeadersCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestHeadersCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
var headers = []string{}
|
||||
for k, v := range req.Header {
|
||||
for k, v := range req.WAFRaw().Header {
|
||||
for _, subV := range v {
|
||||
headers = append(headers, k+": "+subV)
|
||||
}
|
||||
@@ -23,7 +23,7 @@ func (this *RequestHeadersCheckpoint) RequestValue(req *requests.Request, param
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestHeadersCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestHeadersCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -9,12 +9,12 @@ type RequestHostCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestHostCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.Host
|
||||
func (this *RequestHostCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.WAFRaw().Host
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestHostCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestHostCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user