Compare commits

...

16 Commits

Author SHA1 Message Date
刘祥超
01cfccebbd 缓存结束后增加Content-Length对比 2021-07-20 19:48:08 +08:00
刘祥超
6bd7da5e6e build脚本增加arm编译 2021-07-20 19:01:49 +08:00
刘祥超
dcba9c2f3e 调整格式等 2021-07-20 19:01:37 +08:00
刘祥超
f38e80e82d 自动替换API节点时增加对新节点的测试 2021-07-20 18:17:25 +08:00
刘祥超
9bd38094c3 将WAF模板中的cc修改为cc2 2021-07-19 11:01:38 +08:00
刘祥超
7e37fc3b80 实现新的CC 2021-07-19 10:49:56 +08:00
刘祥超
d775dfeeaa WAF增加多个动作 2021-07-18 15:51:49 +08:00
刘祥超
0486f86898 增加攻击拦截统计 2021-07-13 11:04:38 +08:00
刘祥超
102157c893 节点看板增加缓存目录用量 2021-07-08 19:43:30 +08:00
刘祥超
dbd92368ae 文件缓存统计去除过期时间条件 2021-07-08 09:19:41 +08:00
刘祥超
9e418e73bf 节点状态上报流量 2021-07-07 15:18:01 +08:00
刘祥超
5d40eec163 实现节点看板(仅对企业版开放) 2021-07-06 20:06:57 +08:00
刘祥超
1a7a67238d 增加域名统计 2021-07-05 11:36:50 +08:00
刘祥超
6707437bae 指标数据增加总和数据 2021-07-01 10:39:56 +08:00
刘祥超
df7859387d 实现基础的统计指标 2021-06-30 20:01:00 +08:00
刘祥超
889c52330d 修改版本为0.2.5 2021-06-28 10:32:18 +08:00
126 changed files with 3362 additions and 962 deletions

View File

@@ -66,6 +66,10 @@ function build() {
CC_PATH="aarch64-linux-musl-gcc"
CXX_PATH="aarch64-linux-musl-g++"
fi
if [ "${ARCH}" == "arm" ]; then
CC_PATH="arm-linux-musleabi-gcc"
CXX_PATH="arm-linux-musleabi-g++"
fi
if [ "${ARCH}" == "mips64" ]; then
CC_PATH="mips64-linux-musl-gcc"
CXX_PATH="mips64-linux-musl-g++"

View File

@@ -1 +1 @@
index.*
*

4
go.mod
View File

@@ -12,10 +12,10 @@ require (
github.com/go-ole/go-ole v1.2.4 // indirect
github.com/go-yaml/yaml v2.1.0+incompatible
github.com/golang/protobuf v1.5.2
github.com/iwind/TeaGo v0.0.0-20210411134150-ddf57e240c2f
github.com/iwind/TeaGo v0.0.0-20210628135026-38575a4ab060
github.com/iwind/gofcgi v0.0.0-20210528023741-a92711d45f11
github.com/lionsoul2014/ip2region v2.2.0-release+incompatible
github.com/mattn/go-sqlite3 v1.14.7
github.com/mattn/go-sqlite3 v2.0.3+incompatible
github.com/mssola/user_agent v0.5.2
github.com/shirou/gopsutil v3.21.5+incompatible
github.com/tklauser/go-sysconf v0.3.6 // indirect

25
go.sum
View File

@@ -48,7 +48,6 @@ github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrU
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0=
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
@@ -57,27 +56,30 @@ github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5a
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/iwind/TeaGo v0.0.0-20210411134150-ddf57e240c2f h1:r2O8PONj/KiuZjJHVHn7KlCePUIjNtgAmvLfgRafQ8o=
github.com/iwind/TeaGo v0.0.0-20210411134150-ddf57e240c2f/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc=
github.com/iwind/TeaGo v0.0.0-20210628135026-38575a4ab060 h1:qdLtK4PDXxk2vMKkTWl5Fl9xqYuRCukzWAgJbLHdfOo=
github.com/iwind/TeaGo v0.0.0-20210628135026-38575a4ab060/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc=
github.com/iwind/gofcgi v0.0.0-20210528023741-a92711d45f11 h1:DaQjoWZhLNxjhIXedVg4/vFEtHkZhK4IjIwsWdyzBLg=
github.com/iwind/gofcgi v0.0.0-20210528023741-a92711d45f11/go.mod h1:JtbX20untAjUVjZs1ZBtq80f5rJWvwtQNRL6EnuYRnY=
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lionsoul2014/ip2region v2.2.0-release+incompatible h1:1qp9iks+69h7IGLazAplzS9Ca14HAxuD5c0rbFdPGy4=
github.com/lionsoul2014/ip2region v2.2.0-release+incompatible/go.mod h1:+ZBN7PBoh5gG6/y0ZQ85vJDBe21WnfbRrQQwTfliJJI=
github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA=
github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/mssola/user_agent v0.5.2 h1:CZkTUahjL1+OcZ5zv3kZr8QiJ8jy2H08vZIEkBeRbxo=
github.com/mssola/user_agent v0.5.2/go.mod h1:TTPno8LPY3wAIEKRpAtkdMT0f8SE24pLRGPahjCH4uw=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
@@ -94,8 +96,6 @@ github.com/opentracing/opentracing-go v1.1.1-0.20190913142402-a7454ce5950e/go.mo
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/shirou/gopsutil v2.20.9+incompatible h1:msXs2frUV+O/JLva9EDLpuJ84PrFsdCTCQex8PUdtkQ=
github.com/shirou/gopsutil v2.20.9+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/shirou/gopsutil v3.21.5+incompatible h1:OloQyEerMi7JUrXiNzy8wQ5XN+baemxSl12QgIzt0jc=
github.com/shirou/gopsutil v3.21.5+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ=
@@ -134,7 +134,6 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7 h1:AeiKBIuRw3UomYXSbLy0Mc2dDLfdtbT/IVn4keq83P0=
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e h1:XpT3nA5TvE525Ne3hInMh6+GETgn27Zfm9dxsThnX2Q=
@@ -155,10 +154,8 @@ golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200519105757-fe76b779f299 h1:DYfZAGf2WMFjMxbgTjaC+2HC7NkNAQs+6Q8b9WEB/F4=
golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210316164454-77fc1eacc6aa h1:ZYxPR6aca/uhfRJyaOAtflSHjJYiktO7QnJC5ut7iY4=
golang.org/x/sys v0.0.0-20210316164454-77fc1eacc6aa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -167,7 +164,6 @@ golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 h1:RqytpXGR1iVNX7psjB3ff8y7s
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -184,15 +180,14 @@ golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapK
golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20191009194640-548a555dbc03/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/genproto v0.0.0-20210617175327-b9e0b3197ced h1:c5geK1iMU3cDKtFrCVQIcjR3W+JOZMuhIyICMCTbtus=
google.golang.org/genproto v0.0.0-20210617175327-b9e0b3197ced/go.mod h1:SzzZ/N+nwJDaO1kznhnlzqS8ocJICar6hYhVyhi++24=
@@ -201,7 +196,6 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak=
google.golang.org/grpc v1.32.0 h1:zWTV+LMdc3kaiJMSTOFz2UgSBgx8RNQoTGiZu3fR9S0=
google.golang.org/grpc v1.32.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak=
google.golang.org/grpc v1.38.0 h1:/9BgsAsa5nWe26HqOlvlgJnqBuktYOLCgjCPqsa56W0=
google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM=
@@ -213,7 +207,6 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi
google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c=
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk=

View File

@@ -115,7 +115,7 @@ func (this *FileList) Init() error {
return err
}
this.statStmt, err = this.db.Prepare(`SELECT COUNT(*), IFNULL(SUM(headerSize+bodySize+metaSize), 0), IFNULL(SUM(headerSize+bodySize), 0) FROM "` + this.itemsTableName + `" WHERE expiredAt>?`)
this.statStmt, err = this.db.Prepare(`SELECT COUNT(*), IFNULL(SUM(headerSize+bodySize+metaSize), 0), IFNULL(SUM(headerSize+bodySize), 0) FROM "` + this.itemsTableName + `"`)
if err != nil {
return err
}
@@ -297,7 +297,7 @@ func (this *FileList) Stat(check func(hash string) bool) (*Stat, error) {
}
// 这里不设置过期时间、不使用 check 函数,目的是让查询更快速一些
row := this.statStmt.QueryRow(time.Now().Unix())
row := this.statStmt.QueryRow()
if row.Err() != nil {
return nil, row.Err()
}

View File

@@ -5,6 +5,7 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
"strconv"
"sync"
)
@@ -162,3 +163,25 @@ func (this *Manager) TotalMemorySize() int64 {
}
return total
}
// FindAllCachePaths 所有缓存路径
func (this *Manager) FindAllCachePaths() []string {
this.locker.Lock()
defer this.locker.Unlock()
var result = []string{}
for _, policy := range this.policyMap {
if policy.Type == serverconfigs.CachePolicyStorageFile {
if policy.Options != nil {
dir, ok := policy.Options["dir"]
if ok {
var dirString = types.String(dir)
if len(dirString) > 0 {
result = append(result, dirString)
}
}
}
}
}
return result
}

View File

@@ -1,7 +1,7 @@
package teaconst
const (
Version = "0.2.4"
Version = "0.2.5"
ProductName = "Edge Node"
ProcessName = "edge-node"

View File

@@ -6,7 +6,7 @@ import (
"sync"
)
// IP名单
// IPList IP名单
type IPList struct {
itemsMap map[int64]*IPItem // id => item
ipMap map[uint64][]int64 // ip => itemIds
@@ -96,7 +96,7 @@ func (this *IPList) Delete(itemId int64) {
this.isAll = len(this.ipMap[0]) > 0
}
// 判断是否包含某个IP
// Contains 判断是否包含某个IP
func (this *IPList) Contains(ip uint64) bool {
this.locker.RLock()
if this.isAll {
@@ -109,7 +109,7 @@ func (this *IPList) Contains(ip uint64) bool {
return ok
}
// 是否包含一组IP
// ContainsIPStrings 是否包含一组IP
func (this *IPList) ContainsIPStrings(ipStrings []string) (found bool, item *IPItem) {
if len(ipStrings) == 0 {
return

121
internal/metrics/manager.go Normal file
View File

@@ -0,0 +1,121 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package metrics
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"strconv"
"sync"
)
var SharedManager = NewManager()
type Manager struct {
tasks map[int64]*Task // itemId => *Task
categoryTasks map[string][]*Task // category => []*Task
locker sync.RWMutex
hasHTTPMetrics bool
hasTCPMetrics bool
hasUDPMetrics bool
}
func NewManager() *Manager {
return &Manager{
tasks: map[int64]*Task{},
categoryTasks: map[string][]*Task{},
}
}
func (this *Manager) Update(items []*serverconfigs.MetricItemConfig) {
this.locker.Lock()
defer this.locker.Unlock()
var newMap = map[int64]*serverconfigs.MetricItemConfig{}
for _, item := range items {
newMap[item.Id] = item
}
// 停用以前的 或 修改现在的
for itemId, task := range this.tasks {
newItem, ok := newMap[itemId]
if !ok || !newItem.IsOn { // 停用以前的
remotelogs.Println("METRIC_MANAGER", "stop task '"+strconv.FormatInt(itemId, 10)+"'")
err := task.Stop()
if err != nil {
remotelogs.Error("METRIC_MANAGER", "stop task '"+strconv.FormatInt(itemId, 10)+"' failed: "+err.Error())
}
delete(this.tasks, itemId)
} else { // 更新已存在的
if newItem.Version != task.item.Version {
remotelogs.Println("METRIC_MANAGER", "update task '"+strconv.FormatInt(itemId, 10)+"'")
task.item = newItem
}
}
}
// 启动新的
for _, newItem := range items {
if !newItem.IsOn {
continue
}
_, ok := this.tasks[newItem.Id]
if !ok {
remotelogs.Println("METRIC_MANAGER", "start task '"+strconv.FormatInt(newItem.Id, 10)+"'")
task := NewTask(newItem)
err := task.Init()
if err != nil {
remotelogs.Error("METRIC_MANAGER", "initialized task failed: "+err.Error())
continue
}
err = task.Start()
if err != nil {
remotelogs.Error("METRIC_MANAGER", "start task failed: "+err.Error())
continue
}
this.tasks[newItem.Id] = task
}
}
// 按分类存放
this.hasHTTPMetrics = false
this.hasTCPMetrics = false
this.hasUDPMetrics = false
this.categoryTasks = map[string][]*Task{}
for _, task := range this.tasks {
tasks := this.categoryTasks[task.item.Category]
tasks = append(tasks, task)
this.categoryTasks[task.item.Category] = tasks
switch task.item.Category {
case serverconfigs.MetricItemCategoryHTTP:
this.hasHTTPMetrics = true
case serverconfigs.MetricItemCategoryTCP:
this.hasTCPMetrics = true
case serverconfigs.MetricItemCategoryUDP:
this.hasUDPMetrics = true
}
}
}
// Add 添加数据
func (this *Manager) Add(obj MetricInterface) {
this.locker.RLock()
for _, task := range this.categoryTasks[obj.MetricCategory()] {
task.Add(obj)
}
this.locker.RUnlock()
}
func (this *Manager) HasHTTPMetrics() bool {
return this.hasHTTPMetrics
}
func (this *Manager) HasTCPMetrics() bool {
return this.hasTCPMetrics
}
func (this *Manager) HasUDPMetrics() bool {
return this.hasUDPMetrics
}

View File

@@ -0,0 +1,63 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package metrics
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"testing"
)
func TestNewManager(t *testing.T) {
var manager = NewManager()
{
manager.Update([]*serverconfigs.MetricItemConfig{})
for _, task := range manager.tasks {
t.Log(task.item.Id)
}
}
{
t.Log("====")
manager.Update([]*serverconfigs.MetricItemConfig{
{
Id: 1,
},
{
Id: 2,
},
{
Id: 3,
},
})
for _, task := range manager.tasks {
t.Log("task:", task.item.Id)
}
}
{
t.Log("====")
manager.Update([]*serverconfigs.MetricItemConfig{
{
Id: 1,
},
{
Id: 2,
},
})
for _, task := range manager.tasks {
t.Log("task:", task.item.Id)
}
}
{
t.Log("====")
manager.Update([]*serverconfigs.MetricItemConfig{
{
Id: 1,
Version: 1,
},
})
for _, task := range manager.tasks {
t.Log("task:", task.item.Id)
}
}
}

View File

@@ -0,0 +1,17 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package metrics
type MetricInterface interface {
// MetricKey 指标对象
MetricKey(key string) string
// MetricValue 指标值
MetricValue(value string) (result int64, ok bool)
// MetricServerId 服务ID
MetricServerId() int64
// MetricCategory 指标分类
MetricCategory() string
}

22
internal/metrics/stat.go Normal file
View File

@@ -0,0 +1,22 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package metrics
import (
"github.com/cespare/xxhash"
"strconv"
)
type Stat struct {
ServerId int64
Keys []string
Hash string
Value int64
Time string
keysData []byte
}
func (this *Stat) Sum(version int32, itemId int64) {
this.Hash = strconv.FormatUint(xxhash.Sum64String(strconv.FormatInt(this.ServerId, 10)+"@"+string(this.keysData)+"@"+this.Time+"@"+strconv.Itoa(int(version))+"@"+strconv.FormatInt(itemId, 10)), 10)
}

492
internal/metrics/task.go Normal file
View File

@@ -0,0 +1,492 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package metrics
import (
"database/sql"
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/Tea"
_ "github.com/mattn/go-sqlite3"
"os"
"strconv"
"strings"
"sync"
"time"
)
// Task 单个指标任务
// 数据库存储:
// data/
// metric.$ID.db
// stats
// id, keys, value, time, serverId, hash
// 原理:
// 添加或者有变更时 isUploaded = false
// 上传时检查 isUploaded 状态
// 只上传每个服务中排序最前面的 N 个数据
type Task struct {
item *serverconfigs.MetricItemConfig
isLoaded bool
db *sql.DB
statTableName string
statsChan chan *Stat
isStopped bool
cleanTicker *utils.Ticker
uploadTicker *utils.Ticker
cleanVersion int32
insertStatStmt *sql.Stmt
deleteByVersionStmt *sql.Stmt
deleteByExpiresTimeStmt *sql.Stmt
selectTopStmt *sql.Stmt
sumStmt *sql.Stmt
serverIdMap map[int64]bool // 所有的服务Ids
timeMap map[string]bool // time => bool
serverIdMapLocker sync.Mutex
}
// NewTask 获取新任务
func NewTask(item *serverconfigs.MetricItemConfig) *Task {
return &Task{
item: item,
statsChan: make(chan *Stat, 40960),
serverIdMap: map[int64]bool{},
timeMap: map[string]bool{},
}
}
// Init 初始化
func (this *Task) Init() error {
this.statTableName = "stats"
// 检查目录是否存在
var dir = Tea.Root + "/data"
_, err := os.Stat(dir)
if err != nil {
err = os.MkdirAll(dir, 0777)
if err != nil {
return err
}
remotelogs.Println("METRIC", "create data dir '"+dir+"'")
}
db, err := sql.Open("sqlite3", "file:"+dir+"/metric."+strconv.FormatInt(this.item.Id, 10)+".db?cache=shared&mode=rwc&_journal_mode=WAL")
if err != nil {
return err
}
db.SetMaxOpenConns(1)
this.db = db
//创建统计表
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.statTableName + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"hash" varchar(32),
"keys" varchar(1024),
"value" real DEFAULT 0,
"time" varchar(32),
"serverId" integer DEFAULT 0,
"version" integer DEFAULT 0,
"isUploaded" integer DEFAULT 0
);
CREATE INDEX IF NOT EXISTS "serverId"
ON "` + this.statTableName + `" (
"serverId" ASC,
"version" ASC
);
CREATE UNIQUE INDEX IF NOT EXISTS "hash"
ON "` + this.statTableName + `" (
"hash" ASC
);`)
if err != nil {
return err
}
// insert stat stmt
this.insertStatStmt, err = db.Prepare(`INSERT INTO "stats" ("serverId", "hash", "keys", "value", "time", "version", "isUploaded") VALUES (?, ?, ?, ?, ?, ?, 0) ON CONFLICT("hash") DO UPDATE SET "value"="value"+?, "isUploaded"=0`)
if err != nil {
return err
}
// delete by version
this.deleteByVersionStmt, err = db.Prepare(`DELETE FROM "` + this.statTableName + `" WHERE "version"<?`)
if err != nil {
return err
}
// delete by expires time
this.deleteByExpiresTimeStmt, err = db.Prepare(`DELETE FROM "` + this.statTableName + `" WHERE "time"<?`)
if err != nil {
return err
}
// select topN stmt
this.selectTopStmt, err = db.Prepare(`SELECT "id", "hash", "keys", "value", "isUploaded" FROM "` + this.statTableName + `" WHERE "serverId"=? AND "version"=? AND time=? ORDER BY "value" DESC LIMIT 100`)
if err != nil {
return err
}
// sum stmt
this.sumStmt, err = db.Prepare(`SELECT COUNT(*), IFNULL(SUM(value), 0) FROM "` + this.statTableName + `" WHERE "serverId"=? AND "version"=? AND time=?`)
if err != nil {
return err
}
// 所有的服务IDs
err = this.loadServerIdMap()
if err != nil {
return err
}
this.isLoaded = true
return nil
}
// Start 启动任务
func (this *Task) Start() error {
// 读取数据
go func() {
for stat := range this.statsChan {
if stat == nil {
return
}
err := this.InsertStat(stat)
if err != nil {
remotelogs.Error("METRIC", "insert stat failed: "+err.Error())
}
}
}()
// 清理
this.cleanTicker = utils.NewTicker(24 * time.Hour)
go func() {
for this.cleanTicker.Next() {
err := this.CleanExpired()
if err != nil {
remotelogs.Error("METRIC", "clean expired stats failed: "+err.Error())
}
}
}()
// 上传
this.uploadTicker = utils.NewTicker(this.item.UploadDuration())
go func() {
for this.uploadTicker.Next() {
err := this.Upload(1 * time.Second)
if err != nil {
remotelogs.Error("METRIC", "upload stats failed: "+err.Error())
}
}
}()
return nil
}
// Add 添加数据
func (this *Task) Add(obj MetricInterface) {
if this.isStopped || !this.isLoaded {
return
}
var keys = []string{}
for _, key := range this.item.Keys {
k := obj.MetricKey(key)
keys = append(keys, k)
}
v, ok := obj.MetricValue(this.item.Value)
if !ok {
return
}
var stat = &Stat{
ServerId: obj.MetricServerId(),
Keys: keys,
Value: v,
Time: this.item.CurrentTime(),
}
select {
case this.statsChan <- stat:
default:
// 丢弃
}
}
// Stop 停止任务
func (this *Task) Stop() error {
this.isStopped = true
if this.cleanTicker != nil {
this.cleanTicker.Stop()
}
if this.uploadTicker != nil {
this.uploadTicker.Stop()
}
_ = this.insertStatStmt.Close()
_ = this.deleteByVersionStmt.Close()
_ = this.deleteByExpiresTimeStmt.Close()
_ = this.selectTopStmt.Close()
_ = this.sumStmt.Close()
if this.statsChan != nil {
go func() {
// 延时关闭,防止关闭时写入
time.Sleep(5 * time.Second)
close(this.statsChan)
}()
}
if this.db != nil {
_ = this.db.Close()
}
return nil
}
// InsertStat 写入数据
func (this *Task) InsertStat(stat *Stat) error {
if this.isStopped {
return nil
}
if stat == nil {
return nil
}
this.serverIdMapLocker.Lock()
this.serverIdMap[stat.ServerId] = true
this.timeMap[stat.Time] = true
this.serverIdMapLocker.Unlock()
keyData, err := json.Marshal(stat.Keys)
if err != nil {
return err
}
stat.keysData = keyData
stat.Sum(this.item.Version, this.item.Id)
_, err = this.insertStatStmt.Exec(stat.ServerId, stat.Hash, stat.keysData, stat.Value, stat.Time, this.item.Version, stat.Value)
if err != nil {
return err
}
return nil
}
// CleanExpired 清理数据
func (this *Task) CleanExpired() error {
if this.isStopped {
return nil
}
// 清除低版本数据
if this.cleanVersion < this.item.Version {
_, err := this.deleteByVersionStmt.Exec(this.item.Version)
if err != nil {
return err
}
this.cleanVersion = this.item.Version
}
// 清除过期的数据
_, err := this.deleteByExpiresTimeStmt.Exec(this.item.LocalExpiresTime())
if err != nil {
return err
}
return nil
}
// Upload 上传数据
func (this *Task) Upload(pauseDuration time.Duration) error {
if this.isStopped {
return nil
}
this.serverIdMapLocker.Lock()
// 服务IDs
var serverIds []int64
for serverId := range this.serverIdMap {
serverIds = append(serverIds, serverId)
}
this.serverIdMap = map[int64]bool{} // 清空数据
// 时间
var times = []string{}
for t := range this.timeMap {
times = append(times, t)
}
this.timeMap = map[string]bool{} // 清空数据
this.serverIdMapLocker.Unlock()
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
for _, serverId := range serverIds {
for _, currentTime := range times {
idStrings, err := func(serverId int64, currentTime string) (ids []string, err error) {
rows, err := this.selectTopStmt.Query(serverId, this.item.Version, currentTime)
if err != nil {
return nil, err
}
var isClosed bool
defer func() {
if isClosed {
return
}
_ = rows.Close()
}()
var pbStats []*pb.UploadingMetricStat
for rows.Next() {
var pbStat = &pb.UploadingMetricStat{
}
// "id", "hash", "keys", "value", "isUploaded"
var isUploaded int
var keysData []byte
err = rows.Scan(&pbStat.Id, &pbStat.Hash, &keysData, &pbStat.Value, &isUploaded)
if err != nil {
return nil, err
}
if isUploaded == 1 {
continue
}
if len(keysData) > 0 {
err = json.Unmarshal(keysData, &pbStat.Keys)
if err != nil {
return nil, err
}
}
pbStats = append(pbStats, pbStat)
ids = append(ids, strconv.FormatInt(pbStat.Id, 10))
}
// 提前关闭
_ = rows.Close()
isClosed = true
// 上传
if len(pbStats) > 0 {
// 计算总和
count, total, err := this.sum(serverId, currentTime)
if err != nil {
return nil, err
}
_, err = rpcClient.MetricStatRPC().UploadMetricStats(rpcClient.Context(), &pb.UploadMetricStatsRequest{
MetricStats: pbStats,
Time: currentTime,
ServerId: serverId,
ItemId: this.item.Id,
Version: this.item.Version,
Count: count,
Total: float32(total),
})
if err != nil {
return nil, err
}
}
return
}(serverId, currentTime)
if err != nil {
return err
}
if len(idStrings) > 0 {
// 设置为已上传
_, err = this.db.Exec(`UPDATE "` + this.statTableName + `" SET isUploaded=1 WHERE id IN (` + strings.Join(idStrings, ",") + `)`)
if err != nil {
return err
}
}
}
// 休息一下,防止短时间内上传数据过多
if pauseDuration > 0 {
time.Sleep(pauseDuration)
}
}
return nil
}
// 加载服务ID
func (this *Task) loadServerIdMap() error {
{
rows, err := this.db.Query(`SELECT DISTINCT "serverId" FROM `+this.statTableName+" WHERE version=?", this.item.Version)
if err != nil {
return err
}
defer func() {
_ = rows.Close()
}()
var serverId int64
for rows.Next() {
err = rows.Scan(&serverId)
if err != nil {
return err
}
this.serverIdMapLocker.Lock()
this.serverIdMap[serverId] = true
this.serverIdMapLocker.Unlock()
}
}
{
rows, err := this.db.Query(`SELECT DISTINCT "time" FROM `+this.statTableName+" WHERE version=?", this.item.Version)
if err != nil {
return err
}
defer func() {
_ = rows.Close()
}()
var timeString string
for rows.Next() {
err = rows.Scan(&timeString)
if err != nil {
return err
}
this.serverIdMapLocker.Lock()
this.timeMap[timeString] = true
this.serverIdMapLocker.Unlock()
}
}
return nil
}
// 计算数量和综合
func (this *Task) sum(serverId int64, time string) (count int64, total float64, err error) {
rows, err := this.sumStmt.Query(serverId, this.item.Version, time)
if err != nil {
return 0, 0, err
}
defer func() {
_ = rows.Close()
}()
if rows.Next() {
err = rows.Scan(&count, &total)
if err != nil {
return 0, 0, err
}
}
return
}

View File

@@ -0,0 +1,210 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package metrics_test
import (
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/metrics"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/rands"
"testing"
"time"
)
type testObj struct {
ip string
}
func (this *testObj) MetricKey(key string) string {
return this.ip
}
func (this *testObj) MetricValue(value string) (int64, bool) {
return 1, true
}
func (this *testObj) MetricServerId() int64 {
return int64(rands.Int(1, 100))
}
func (this *testObj) MetricCategory() string {
return "http"
}
func TestTask_Init(t *testing.T) {
var task = metrics.NewTask(&serverconfigs.MetricItemConfig{
Id: 1,
IsOn: false,
Category: "",
Period: 0,
PeriodUnit: "",
Keys: nil,
Value: "",
})
err := task.Init()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = task.Stop()
}()
t.Log("ok")
}
func TestTask_Add(t *testing.T) {
var task = metrics.NewTask(&serverconfigs.MetricItemConfig{
Id: 1,
IsOn: false,
Category: "",
Period: 1,
PeriodUnit: serverconfigs.MetricItemPeriodUnitDay,
Keys: []string{"${remoteAddr}"},
Value: "${countRequest}",
})
err := task.Init()
if err != nil {
t.Fatal(err)
}
err = task.Start()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = task.Stop()
}()
task.Add(&testObj{ip: "127.0.0.2"})
time.Sleep(1 * time.Second) // waiting for inserting
}
func TestTask_Add_Many(t *testing.T) {
var task = metrics.NewTask(&serverconfigs.MetricItemConfig{
Id: 1,
IsOn: false,
Category: "",
Period: 1,
PeriodUnit: serverconfigs.MetricItemPeriodUnitDay,
Keys: []string{"${remoteAddr}"},
Value: "${countRequest}",
Version: 1,
})
err := task.Init()
if err != nil {
t.Fatal(err)
}
err = task.Start()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = task.Stop()
}()
for i := 0; i < 4_000_000; i++ {
task.Add(&testObj{
ip: fmt.Sprintf("%d.%d.%d.%d", rands.Int(0, 255), rands.Int(0, 255), rands.Int(0, 255), rands.Int(0, 255)),
})
if i%10000 == 0 {
time.Sleep(1 * time.Second)
}
}
}
func TestTask_InsertStat(t *testing.T) {
var item = &serverconfigs.MetricItemConfig{
Id: 1,
IsOn: false,
Category: "",
Period: 1,
PeriodUnit: serverconfigs.MetricItemPeriodUnitDay,
Keys: []string{"${remoteAddr}"},
Value: "${countRequest}",
Version: 1,
}
var task = metrics.NewTask(item)
err := task.Init()
if err != nil {
t.Fatal(err)
}
err = task.Start()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = task.Stop()
}()
err = task.InsertStat(&metrics.Stat{
ServerId: 1,
Keys: []string{"127.0.0.1"},
Hash: "",
Value: 1,
Time: item.CurrentTime(),
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestTask_CleanExpired(t *testing.T) {
var task = metrics.NewTask(&serverconfigs.MetricItemConfig{
Id: 1,
IsOn: false,
Category: "",
Period: 1,
PeriodUnit: serverconfigs.MetricItemPeriodUnitDay,
Keys: []string{"${remoteAddr}"},
Value: "${countRequest}",
Version: 1,
})
err := task.Init()
if err != nil {
t.Fatal(err)
}
err = task.Start()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = task.Stop()
}()
err = task.CleanExpired()
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestTask_Upload(t *testing.T) {
var task = metrics.NewTask(&serverconfigs.MetricItemConfig{
Id: 1,
IsOn: false,
Category: "",
Period: 1,
PeriodUnit: serverconfigs.MetricItemPeriodUnitDay,
Keys: []string{"${remoteAddr}"},
Value: "${countRequest}",
Version: 1,
})
err := task.Init()
if err != nil {
t.Fatal(err)
}
err = task.Start()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = task.Stop()
}()
err = task.Upload(0)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -72,7 +72,8 @@ func (this *ValueQueue) Loop() error {
CreatedAt: value.CreatedAt,
})
if err != nil {
return err
remotelogs.Error("MONITOR", err.Error())
continue
}
}
return nil

View File

@@ -7,6 +7,7 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/metrics"
"github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/types"
@@ -66,12 +67,16 @@ type HTTPRequest struct {
cacheRef *serverconfigs.HTTPCacheRef // 缓存设置
cacheKey string // 缓存使用的Key
isCached bool // 是否已经被缓存
isAttack bool // 是否是攻击请求
bodyData []byte // 读取的Body内容
// WAF相关
firewallPolicyId int64
firewallRuleGroupId int64
firewallRuleSetId int64
firewallRuleId int64
firewallActions []string
tags []string
logAttrs map[string]string
@@ -242,11 +247,20 @@ func (this *HTTPRequest) doEnd() {
// TODO 增加是否开启开关
if this.Server != nil {
if this.isCached {
stats.SharedTrafficStatManager.Add(this.Server.Id, this.writer.sentBodyBytes, this.writer.sentBodyBytes, 1, 1)
stats.SharedTrafficStatManager.Add(this.Server.Id, this.Host, this.writer.sentBodyBytes, this.writer.sentBodyBytes, 1, 1, 0, 0)
} else {
stats.SharedTrafficStatManager.Add(this.Server.Id, this.writer.sentBodyBytes, 0, 1, 0)
if this.isAttack {
stats.SharedTrafficStatManager.Add(this.Server.Id, this.Host, this.writer.sentBodyBytes, 0, 1, 0, 1, this.writer.sentBodyBytes)
} else {
stats.SharedTrafficStatManager.Add(this.Server.Id, this.Host, this.writer.sentBodyBytes, 0, 1, 0, 0, 0)
}
}
}
// 指标
if metrics.SharedManager.HasHTTPMetrics() {
this.doMetricsResponse()
}
}
// RawURI 原始的请求URI
@@ -500,6 +514,12 @@ func (this *HTTPRequest) Format(source string) string {
return this.requestRemoteUser()
case "requestURI", "requestUri":
return this.rawURI
case "requestURL":
var scheme = "http"
if this.IsHTTPS {
scheme = "https"
}
return scheme + "://" + this.Host + this.rawURI
case "requestPath":
return this.requestPath()
case "requestPathExtension":
@@ -1186,5 +1206,10 @@ func (this *HTTPRequest) canIgnore(err error) bool {
return true
}
// HTTP内部错误
if strings.HasPrefix(err.Error(), "http:") || strings.HasPrefix(err.Error(), "http2:") {
return true
}
return false
}

View File

@@ -128,6 +128,8 @@ func (this *HTTPRequest) log() {
FirewallRuleGroupId: this.firewallRuleGroupId,
FirewallRuleSetId: this.firewallRuleSetId,
FirewallRuleId: this.firewallRuleId,
FirewallActions: this.firewallActions,
Tags: this.tags,
Attrs: this.logAttrs,
}

View File

@@ -0,0 +1,58 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/metrics"
)
// 指标统计 - 响应
// 只需要在结束时调用指标进行统计
func (this *HTTPRequest) doMetricsResponse() {
metrics.SharedManager.Add(this)
}
func (this *HTTPRequest) MetricKey(key string) string {
return this.Format(key)
}
func (this *HTTPRequest) MetricValue(value string) (result int64, ok bool) {
// TODO 需要忽略健康检查的请求,但是同时也要防止攻击者模拟健康检查
switch value {
case "${countRequest}":
return 1, true
case "${countTrafficOut}":
// 这里不包括Header长度
return this.writer.SentBodyBytes(), true
case "${countTrafficIn}":
var hl int64 = 0 // header length
for k, values := range this.RawReq.Header {
for _, v := range values {
hl += int64(len(k) + len(v) + 2 /** k: v **/)
}
}
return this.RawReq.ContentLength + hl, true
case "${countConnection}":
metricNewConnMapLocker.Lock()
_, ok := metricNewConnMap[this.RawReq.RemoteAddr]
if ok {
delete(metricNewConnMap, this.RawReq.RemoteAddr)
}
metricNewConnMapLocker.Unlock()
if ok {
return 1, true
} else {
return 0, false
}
}
return 0, false
}
func (this *HTTPRequest) MetricServerId() int64 {
return this.Server.Id
}
func (this *HTTPRequest) MetricCategory() string {
return serverconfigs.MetricItemCategoryHTTP
}

View File

@@ -7,6 +7,8 @@ func (this *HTTPRequest) doStat() {
if this.Server == nil {
return
}
// 内置的统计
stats.SharedHTTPRequestStatManager.AddRemoteAddr(this.Server.Id, this.requestRemoteAddr())
stats.SharedHTTPRequestStatManager.AddUserAgent(this.Server.Id, this.requestHeader("User-Agent"))
}

View File

@@ -1,6 +1,7 @@
package nodes
import (
"bytes"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
@@ -8,6 +9,8 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
"io"
"io/ioutil"
"net/http"
)
@@ -152,23 +155,36 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
if w == nil {
return
}
goNext, ruleGroup, ruleSet, err := w.MatchRequest(this.RawReq, this.writer)
w.OnAction(func(action waf.ActionInterface) (goNext bool) {
switch action.Code() {
case waf.ActionTag:
this.tags = action.(*waf.TagAction).Tags
}
return true
})
goNext, ruleGroup, ruleSet, err := w.MatchRequest(this, this.writer)
if err != nil {
remotelogs.Error("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error())
return
}
if ruleSet != nil {
if ruleSet.Action != waf.ActionAllow {
if ruleSet.HasSpecialActions() {
this.firewallPolicyId = firewallPolicy.Id
this.firewallRuleGroupId = types.Int64(ruleGroup.Id)
this.firewallRuleSetId = types.Int64(ruleSet.Id)
if ruleSet.HasAttackActions() {
this.isAttack = true
}
// 添加统计
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Action)
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Actions)
}
this.logAttrs["waf.action"] = ruleSet.Action
this.firewallActions = ruleSet.ActionCodes()
}
return !goNext, false
@@ -204,24 +220,79 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi
return
}
goNext, ruleGroup, ruleSet, err := w.MatchResponse(this.RawReq, resp, this.writer)
w.OnAction(func(action waf.ActionInterface) (goNext bool) {
switch action.Code() {
case waf.ActionTag:
this.tags = action.(*waf.TagAction).Tags
}
return true
})
goNext, ruleGroup, ruleSet, err := w.MatchResponse(this, resp, this.writer)
if err != nil {
remotelogs.Error("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error())
return
}
if ruleSet != nil {
if ruleSet.Action != waf.ActionAllow {
if ruleSet.HasSpecialActions() {
this.firewallPolicyId = firewallPolicy.Id
this.firewallRuleGroupId = types.Int64(ruleGroup.Id)
this.firewallRuleSetId = types.Int64(ruleSet.Id)
if ruleSet.HasAttackActions() {
this.isAttack = true
}
// 添加统计
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Action)
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Actions)
}
this.logAttrs["waf.action"] = ruleSet.Action
this.firewallActions = ruleSet.ActionCodes()
}
return !goNext
}
// WAFRaw 原始请求
func (this *HTTPRequest) WAFRaw() *http.Request {
return this.RawReq
}
// WAFRemoteIP 客户端IP
func (this *HTTPRequest) WAFRemoteIP() string {
return this.requestRemoteAddr()
}
// WAFGetCacheBody 获取缓存中的Body
func (this *HTTPRequest) WAFGetCacheBody() []byte {
return this.bodyData
}
// WAFSetCacheBody 设置Body
func (this *HTTPRequest) WAFSetCacheBody(body []byte) {
this.bodyData = body
}
// WAFReadBody 读取Body
func (this *HTTPRequest) WAFReadBody(max int64) (data []byte, err error) {
if this.RawReq.ContentLength > 0 {
data, err = ioutil.ReadAll(io.LimitReader(this.RawReq.Body, max))
}
return
}
// WAFRestoreBody 恢复Body
func (this *HTTPRequest) WAFRestoreBody(data []byte) {
if len(data) > 0 {
rawReader := bytes.NewBuffer(data)
buf := make([]byte, 1024)
_, _ = io.CopyBuffer(rawReader, this.RawReq.Body, buf)
this.RawReq.Body = ioutil.NopCloser(rawReader)
}
}
// WAFServerId 服务ID
func (this *HTTPRequest) WAFServerId() int64 {
return this.Server.Id
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
"net"
"net/http"
"strings"
@@ -223,17 +224,29 @@ func (this *HTTPWriter) Close() {
// cache writer
if this.cacheWriter != nil {
if this.isOk {
err := this.cacheWriter.Close()
if err == nil {
this.cacheStorage.AddToList(&caches.Item{
Type: this.cacheWriter.ItemType(),
Key: this.cacheWriter.Key(),
ExpiredAt: this.cacheWriter.ExpiredAt(),
HeaderSize: this.cacheWriter.HeaderSize(),
BodySize: this.cacheWriter.BodySize(),
Host: this.req.Host,
ServerId: this.req.Server.Id,
})
// 对比Content-Length
contentLengthString := this.Header().Get("Content-Length")
if len(contentLengthString) > 0 {
contentLength := types.Int64(contentLengthString)
if contentLength != this.cacheWriter.BodySize() {
this.isOk = false
_ = this.cacheWriter.Discard()
}
}
if this.isOk {
err := this.cacheWriter.Close()
if err == nil {
this.cacheStorage.AddToList(&caches.Item{
Type: this.cacheWriter.ItemType(),
Key: this.cacheWriter.Key(),
ExpiredAt: this.cacheWriter.ExpiredAt(),
HeaderSize: this.cacheWriter.HeaderSize(),
BodySize: this.cacheWriter.BodySize(),
Host: this.req.Host,
ServerId: this.req.Server.Id,
})
}
}
} else {
_ = this.cacheWriter.Discard()

View File

@@ -7,7 +7,7 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
http2 "golang.org/x/net/http2"
"golang.org/x/net/http2"
"sync"
)

View File

@@ -9,11 +9,14 @@ import (
"net"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
)
var httpErrorLogger = log.New(io.Discard, "", 0)
var metricNewConnMap = map[string]bool{} // remoteAddr => bool
var metricNewConnMapLocker = &sync.Mutex{}
type HTTPListener struct {
BaseListener
@@ -39,15 +42,27 @@ func (this *HTTPListener) Serve() error {
this.httpServer = &http.Server{
Addr: this.addr,
Handler: handler,
ReadHeaderTimeout: 3 * time.Second, // TODO 改成可以配置
ReadHeaderTimeout: 2 * time.Second, // TODO 改成可以配置
IdleTimeout: 2 * time.Minute, // TODO 改成可以配置
ErrorLog: httpErrorLogger,
ConnState: func(conn net.Conn, state http.ConnState) {
switch state {
case http.StateNew:
atomic.AddInt64(&this.countActiveConnections, 1)
// 为指标存储连接信息
if sharedNodeConfig.HasHTTPConnectionMetrics() {
metricNewConnMapLocker.Lock()
metricNewConnMap[conn.RemoteAddr().String()] = true
metricNewConnMapLocker.Unlock()
}
case http.StateClosed:
atomic.AddInt64(&this.countActiveConnections, -1)
// 移除指标存储连接信息
metricNewConnMapLocker.Lock()
delete(metricNewConnMap, conn.RemoteAddr().String())
metricNewConnMapLocker.Unlock()
}
},
}

View File

@@ -80,7 +80,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
}
// 记录流量
stats.SharedTrafficStatManager.Add(firstServer.Id, int64(n), 0, 0, 0)
stats.SharedTrafficStatManager.Add(firstServer.Id, "", int64(n), 0, 0, 0, 0, 0)
}
if err != nil {
closer()

View File

@@ -164,7 +164,7 @@ func NewUDPConn(serverId int64, addr net.Addr, proxyConn *net.UDPConn, serverCon
}
// 记录流量
stats.SharedTrafficStatManager.Add(serverId, int64(n), 0, 0, 0)
stats.SharedTrafficStatManager.Add(serverId, "", int64(n), 0, 0, 0, 0, 0)
}
if err != nil {
conn.isOk = false

View File

@@ -12,6 +12,7 @@ import (
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/metrics"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/stats"
@@ -396,6 +397,8 @@ func (this *Node) syncConfig() error {
iplibrary.SharedActionManager.UpdateActions(nodeConfig.FirewallActions)
sharedNodeConfig = nodeConfig
metrics.SharedManager.Update(nodeConfig.MetricItems)
// 发送事件
events.Notify(events.EventReload)

View File

@@ -15,6 +15,7 @@ import (
"github.com/iwind/TeaGo/maps"
"github.com/shirou/gopsutil/cpu"
"github.com/shirou/gopsutil/disk"
"golang.org/x/sys/unix"
"os"
"runtime"
"strings"
@@ -67,6 +68,8 @@ func (this *NodeStatusExecutor) update() {
status.ConnectionCount = sharedListenerManager.TotalActiveConnections()
status.CacheTotalDiskSize = caches.SharedManager.TotalDiskSize()
status.CacheTotalMemorySize = caches.SharedManager.TotalMemorySize()
status.TrafficInBytes = inTrafficBytes
status.TrafficOutBytes = outTrafficBytes
// 记录监控数据
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemConnections, maps.Map{
@@ -80,6 +83,7 @@ func (this *NodeStatusExecutor) update() {
this.updateMem(status)
this.updateLoad(status)
this.updateDisk(status)
this.updateCacheSpace(status)
status.UpdatedAt = time.Now().Unix()
// 发送数据
@@ -211,3 +215,25 @@ func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
"maxUsage": status.DiskMaxUsage,
})
}
// 缓存空间
func (this *NodeStatusExecutor) updateCacheSpace(status *nodeconfigs.NodeStatus) {
var result = []maps.Map{}
cachePaths := caches.SharedManager.FindAllCachePaths()
for _, path := range cachePaths {
var stat unix.Statfs_t
err := unix.Statfs(path, &stat)
if err != nil {
return
}
result = append(result, maps.Map{
"path": path,
"total": stat.Blocks * uint64(stat.Bsize),
"avail": stat.Bavail * uint64(stat.Bsize),
"used": (stat.Blocks - stat.Bavail) * uint64(stat.Bsize),
})
}
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemCacheDir, maps.Map{
"dirs": result,
})
}

View File

@@ -1,6 +1,8 @@
package nodes
import (
"context"
"crypto/tls"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/configs"
"github.com/TeaOSLab/EdgeNode/internal/events"
@@ -8,8 +10,12 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/logs"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"net/url"
"sort"
"strings"
"sync"
"time"
)
@@ -20,7 +26,7 @@ func init() {
})
}
// API节点同步任务
// SyncAPINodesTask API节点同步任务
type SyncAPINodesTask struct {
}
@@ -74,6 +80,12 @@ func (this *SyncAPINodesTask) Loop() error {
return nil
}
// 测试是否有API节点可用
hasOk := this.testEndpoints(newEndpoints)
if !hasOk {
return nil
}
// 修改RPC对象配置
config.RPC.Endpoints = newEndpoints
err = rpcClient.UpdateConfig(config)
@@ -95,3 +107,47 @@ func (this *SyncAPINodesTask) isSame(endpoints1 []string, endpoints2 []string) b
sort.Strings(endpoints2)
return strings.Join(endpoints1, "&") == strings.Join(endpoints2, "&")
}
func (this *SyncAPINodesTask) testEndpoints(endpoints []string) bool {
if len(endpoints) == 0 {
return false
}
var wg = sync.WaitGroup{}
wg.Add(len(endpoints))
var ok = false
for _, endpoint := range endpoints {
go func(endpoint string) {
defer wg.Done()
u, err := url.Parse(endpoint)
if err != nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer func() {
cancel()
}()
var conn *grpc.ClientConn
if u.Scheme == "http" {
conn, err = grpc.DialContext(ctx, u.Host, grpc.WithInsecure(), grpc.WithBlock())
} else if u.Scheme == "https" {
conn, err = grpc.DialContext(ctx, u.Host, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
InsecureSkipVerify: true,
})), grpc.WithBlock())
}
if err != nil {
return
}
_ = conn.Close()
ok = true
}(endpoint)
}
wg.Wait()
return ok
}

View File

@@ -2,7 +2,10 @@
package nodes
import "net"
import (
"github.com/TeaOSLab/EdgeNode/internal/waf"
"net"
)
// TrafficListener 用于统计流量的网络监听
type TrafficListener struct {
@@ -18,6 +21,17 @@ func (this *TrafficListener) Accept() (net.Conn, error) {
if err != nil {
return nil, err
}
// 是否在WAF名单中
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err == nil {
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, ip) && waf.SharedIPBlackLIst.Contains(waf.IPTypeAll, ip) {
go func() {
_ = conn.Close()
}()
return conn, nil
}
}
return NewTrafficConn(conn), nil
}

View File

@@ -11,20 +11,20 @@ import (
var sharedWAFManager = NewWAFManager()
// WAF管理器
// WAFManager WAF管理器
type WAFManager struct {
mapping map[int64]*waf.WAF // policyId => WAF
locker sync.RWMutex
}
// 获取新对象
// NewWAFManager 获取新对象
func NewWAFManager() *WAFManager {
return &WAFManager{
mapping: map[int64]*waf.WAF{},
}
}
// 更新策略
// UpdatePolicies 更新策略
func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallPolicy) {
this.locker.Lock()
defer this.locker.Unlock()
@@ -44,7 +44,7 @@ func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallP
this.mapping = m
}
// 查找WAF
// FindWAF 查找WAF
func (this *WAFManager) FindWAF(policyId int64) *waf.WAF {
this.locker.RLock()
w, _ := this.mapping[policyId]
@@ -78,14 +78,15 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
// rule sets
for _, set := range group.Sets {
s := &waf.RuleSet{
Id: strconv.FormatInt(set.Id, 10),
Code: set.Code,
IsOn: set.IsOn,
Name: set.Name,
Description: set.Description,
Connector: set.Connector,
Action: set.Action,
ActionOptions: set.ActionOptions,
Id: strconv.FormatInt(set.Id, 10),
Code: set.Code,
IsOn: set.IsOn,
Name: set.Name,
Description: set.Description,
Connector: set.Connector,
}
for _, a := range set.Actions {
s.AddAction(a.Code, a.Options)
}
// rules
@@ -132,14 +133,16 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
// rule sets
for _, set := range group.Sets {
s := &waf.RuleSet{
Id: strconv.FormatInt(set.Id, 10),
Code: set.Code,
IsOn: set.IsOn,
Name: set.Name,
Description: set.Description,
Connector: set.Connector,
Action: set.Action,
ActionOptions: set.ActionOptions,
Id: strconv.FormatInt(set.Id, 10),
Code: set.Code,
IsOn: set.IsOn,
Name: set.Name,
Description: set.Description,
Connector: set.Connector,
}
for _, a := range set.Actions {
s.AddAction(a.Code, a.Options)
}
// rules
@@ -164,10 +167,11 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
// action
if policy.BlockOptions != nil {
w.ActionBlock = &waf.BlockAction{
w.DefaultBlockAction = &waf.BlockAction{
StatusCode: policy.BlockOptions.StatusCode,
Body: policy.BlockOptions.Body,
URL: "",
URL: policy.BlockOptions.URL,
Timeout: policy.BlockOptions.Timeout,
}
}

View File

@@ -109,6 +109,14 @@ func (this *RPCClient) ServerDailyStatRPC() pb.ServerDailyStatServiceClient {
return pb.NewServerDailyStatServiceClient(this.pickConn())
}
func (this *RPCClient) MetricStatRPC() pb.MetricStatServiceClient {
return pb.NewMetricStatServiceClient(this.pickConn())
}
func (this *RPCClient) FirewallService() pb.FirewallServiceClient {
return pb.NewFirewallServiceClient(this.pickConn())
}
// Context 节点上下文信息
func (this *RPCClient) Context() context.Context {
ctx := context.Background()

View File

@@ -1,12 +1,16 @@
package stats
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/monitor"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
"github.com/mssola/user_agent"
@@ -30,6 +34,8 @@ type HTTPRequestStatManager struct {
browserMap map[string]int64 // serverId@browser@version => count
dailyFirewallRuleGroupMap map[string]int64 // serverId@firewallRuleGroupId@action => count
totalAttackRequests int64
}
// NewHTTPRequestStatManager 获取新对象
@@ -38,16 +44,29 @@ func NewHTTPRequestStatManager() *HTTPRequestStatManager {
ipChan: make(chan string, 10_000), // TODO 将来可以配置容量
userAgentChan: make(chan string, 10_000), // TODO 将来可以配置容量
firewallRuleGroupChan: make(chan string, 10_000), // TODO 将来可以配置容量
cityMap: map[string]int64{},
providerMap: map[string]int64{},
systemMap: map[string]int64{},
browserMap: map[string]int64{},
cityMap: map[string]int64{},
providerMap: map[string]int64{},
systemMap: map[string]int64{},
browserMap: map[string]int64{},
dailyFirewallRuleGroupMap: map[string]int64{},
}
}
// Start 启动
func (this *HTTPRequestStatManager) Start() {
// 上传请求总数
go func() {
ticker := time.NewTicker(1 * time.Minute)
go func() {
for range ticker.C {
if this.totalAttackRequests > 0 {
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemAttackRequests, maps.Map{"total": this.totalAttackRequests})
this.totalAttackRequests = 0
}
}
}()
}()
loopTicker := time.NewTicker(1 * time.Second)
uploadTicker := time.NewTicker(30 * time.Minute)
if Tea.IsTesting() {
@@ -114,14 +133,19 @@ func (this *HTTPRequestStatManager) AddUserAgent(serverId int64, userAgent strin
}
// AddFirewallRuleGroupId 添加防火墙拦截动作
func (this *HTTPRequestStatManager) AddFirewallRuleGroupId(serverId int64, firewallRuleGroupId int64, action string) {
func (this *HTTPRequestStatManager) AddFirewallRuleGroupId(serverId int64, firewallRuleGroupId int64, actions []*waf.ActionConfig) {
if firewallRuleGroupId <= 0 {
return
}
select {
case this.firewallRuleGroupChan <- strconv.FormatInt(serverId, 10) + "@" + strconv.FormatInt(firewallRuleGroupId, 10) + "@" + action:
default:
// 超出容量我们就丢弃
this.totalAttackRequests++
for _, action := range actions {
select {
case this.firewallRuleGroupChan <- strconv.FormatInt(serverId, 10) + "@" + strconv.FormatInt(firewallRuleGroupId, 10) + "@" + action.Code:
default:
// 超出容量我们就丢弃
}
}
}

View File

@@ -4,11 +4,15 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/monitor"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"strconv"
"strings"
"sync"
"time"
)
@@ -20,19 +24,25 @@ type TrafficItem struct {
CachedBytes int64
CountRequests int64
CountCachedRequests int64
CountAttackRequests int64
AttackBytes int64
}
// TrafficStatManager 区域流量统计
type TrafficStatManager struct {
itemMap map[string]*TrafficItem // [timestamp serverId] => bytes
itemMap map[string]*TrafficItem // [timestamp serverId] => *TrafficItem
domainsMap map[string]*TrafficItem // timestamp @ serverId @ domain => *TrafficItem
locker sync.Mutex
configFunc func() *nodeconfigs.NodeConfig
totalRequests int64
}
// NewTrafficStatManager 获取新对象
func NewTrafficStatManager() *TrafficStatManager {
manager := &TrafficStatManager{
itemMap: map[string]*TrafficItem{},
itemMap: map[string]*TrafficItem{},
domainsMap: map[string]*TrafficItem{},
}
return manager
@@ -42,6 +52,20 @@ func NewTrafficStatManager() *TrafficStatManager {
func (this *TrafficStatManager) Start(configFunc func() *nodeconfigs.NodeConfig) {
this.configFunc = configFunc
// 上传请求总数
go func() {
ticker := time.NewTicker(1 * time.Minute)
go func() {
for range ticker.C {
if this.totalRequests > 0 {
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemRequests, maps.Map{"total": this.totalRequests})
this.totalRequests = 0
}
}
}()
}()
// 上传统计数据
duration := 5 * time.Minute
if Tea.IsTesting() {
// 测试环境缩短上传时间,方便我们调试
@@ -62,15 +86,19 @@ func (this *TrafficStatManager) Start(configFunc func() *nodeconfigs.NodeConfig)
}
// Add 添加流量
func (this *TrafficStatManager) Add(serverId int64, bytes int64, cachedBytes int64, countRequests int64, countCachedRequests int64) {
func (this *TrafficStatManager) Add(serverId int64, domain string, bytes int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64) {
if bytes == 0 {
return
}
this.totalRequests++
timestamp := utils.UnixTime() / 300 * 300
key := strconv.FormatInt(timestamp, 10) + strconv.FormatInt(serverId, 10)
this.locker.Lock()
// 总的流量
item, ok := this.itemMap[key]
if !ok {
item = &TrafficItem{}
@@ -80,6 +108,23 @@ func (this *TrafficStatManager) Add(serverId int64, bytes int64, cachedBytes int
item.CachedBytes += cachedBytes
item.CountRequests += countRequests
item.CountCachedRequests += countCachedRequests
item.CountAttackRequests += countAttacks
item.AttackBytes += attackBytes
// 单个域名流量
var domainKey = strconv.FormatInt(timestamp, 10) + "@" + strconv.FormatInt(serverId, 10) + "@" + domain
domainItem, ok := this.domainsMap[domainKey]
if !ok {
domainItem = &TrafficItem{}
this.domainsMap[domainKey] = domainItem
}
domainItem.Bytes += bytes
domainItem.CachedBytes += cachedBytes
domainItem.CountRequests += countRequests
domainItem.CountCachedRequests += countCachedRequests
domainItem.CountAttackRequests += countAttacks
domainItem.AttackBytes += attackBytes
this.locker.Unlock()
}
@@ -96,12 +141,15 @@ func (this *TrafficStatManager) Upload() error {
}
this.locker.Lock()
m := this.itemMap
itemMap := this.itemMap
domainMap := this.domainsMap
this.itemMap = map[string]*TrafficItem{}
this.domainsMap = map[string]*TrafficItem{}
this.locker.Unlock()
pbStats := []*pb.ServerDailyStat{}
for key, item := range m {
// 服务统计
var pbServerStats = []*pb.ServerDailyStat{}
for key, item := range itemMap {
timestamp, err := strconv.ParseInt(key[:10], 10, 64)
if err != nil {
return err
@@ -111,19 +159,45 @@ func (this *TrafficStatManager) Upload() error {
return err
}
pbStats = append(pbStats, &pb.ServerDailyStat{
pbServerStats = append(pbServerStats, &pb.ServerDailyStat{
ServerId: serverId,
RegionId: config.RegionId,
Bytes: item.Bytes,
CachedBytes: item.CachedBytes,
CountRequests: item.CountRequests,
CountCachedRequests: item.CountCachedRequests,
CountAttackRequests: item.CountAttackRequests,
AttackBytes: item.AttackBytes,
CreatedAt: timestamp,
})
}
if len(pbStats) == 0 {
if len(pbServerStats) == 0 {
return nil
}
_, err = client.ServerDailyStatRPC().UploadServerDailyStats(client.Context(), &pb.UploadServerDailyStatsRequest{Stats: pbStats})
// 域名统计
var pbDomainStats = []*pb.UploadServerDailyStatsRequest_DomainStat{}
for key, item := range domainMap {
var pieces = strings.SplitN(key, "@", 3)
if len(pieces) != 3 {
continue
}
pbDomainStats = append(pbDomainStats, &pb.UploadServerDailyStatsRequest_DomainStat{
ServerId: types.Int64(pieces[1]),
Domain: pieces[2],
Bytes: item.Bytes,
CachedBytes: item.CachedBytes,
CountRequests: item.CountRequests,
CountCachedRequests: item.CountCachedRequests,
CountAttackRequests: item.CountAttackRequests,
AttackBytes: item.AttackBytes,
CreatedAt: types.Int64(pieces[0]),
})
}
_, err = client.ServerDailyStatRPC().UploadServerDailyStats(client.Context(), &pb.UploadServerDailyStatsRequest{
Stats: pbServerStats,
DomainStats: pbDomainStats,
})
return err
}

View File

@@ -8,7 +8,7 @@ import (
func TestTrafficStatManager_Add(t *testing.T) {
manager := NewTrafficStatManager()
for i := 0; i < 100; i++ {
manager.Add(1, 10, 1, 0)
manager.Add(1, "goedge.cn", 1, 0, 0, 0)
}
t.Log(manager.itemMap)
}
@@ -16,7 +16,7 @@ func TestTrafficStatManager_Add(t *testing.T) {
func TestTrafficStatManager_Upload(t *testing.T) {
manager := NewTrafficStatManager()
for i := 0; i < 100; i++ {
manager.Add(1, 10, 1, 0)
manager.Add(1, "goedge.cn", 1, 0, 0, 0)
}
err := manager.Upload()
if err != nil {
@@ -30,6 +30,6 @@ func BenchmarkTrafficStatManager_Add(b *testing.B) {
manager := NewTrafficStatManager()
for i := 0; i < b.N; i++ {
manager.Add(1, 1024, 1, 0)
manager.Add(1, "goedge.cn", 1024, 1, 0, 0)
}
}

View File

@@ -30,7 +30,7 @@ func (this *Piece) Add(key uint64, item *Item) () {
func (this *Piece) IncreaseInt64(key uint64, delta int64, expiredAt int64) (result int64) {
this.locker.Lock()
item, ok := this.m[key]
if ok {
if ok && item.expiredAt > time.Now().Unix() {
result = types.Int64(item.Value) + delta
item.Value = result
item.expiredAt = expiredAt

159
internal/utils/encrypt.go Normal file
View File

@@ -0,0 +1,159 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package utils
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/rands"
stringutil "github.com/iwind/TeaGo/utils/string"
)
var (
simpleEncryptMagicKey = rands.HexString(32)
)
func init() {
events.On(events.EventReload, func() {
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig != nil {
simpleEncryptMagicKey = stringutil.Md5(nodeConfig.NodeId + "@" + nodeConfig.Secret)
}
})
}
// SimpleEncrypt 加密特殊信息
func SimpleEncrypt(data []byte) []byte {
var method = &AES256CFBMethod{}
err := method.Init([]byte(simpleEncryptMagicKey), []byte(simpleEncryptMagicKey[:16]))
if err != nil {
logs.Println("[SimpleEncrypt]" + err.Error())
return data
}
dst, err := method.Encrypt(data)
if err != nil {
logs.Println("[SimpleEncrypt]" + err.Error())
return data
}
return dst
}
// SimpleDecrypt 解密特殊信息
func SimpleDecrypt(data []byte) []byte {
var method = &AES256CFBMethod{}
err := method.Init([]byte(simpleEncryptMagicKey), []byte(simpleEncryptMagicKey[:16]))
if err != nil {
logs.Println("[MagicKeyEncode]" + err.Error())
return data
}
src, err := method.Decrypt(data)
if err != nil {
logs.Println("[MagicKeyEncode]" + err.Error())
return data
}
return src
}
func SimpleEncryptMap(m maps.Map) (base64String string, err error) {
mJSON, err := json.Marshal(m)
if err != nil {
return "", err
}
data := SimpleEncrypt(mJSON)
return base64.StdEncoding.EncodeToString(data), nil
}
func SimpleDecryptMap(base64String string) (maps.Map, error) {
data, err := base64.StdEncoding.DecodeString(base64String)
if err != nil {
return nil, err
}
mJSON := SimpleDecrypt(data)
var result = maps.Map{}
err = json.Unmarshal(mJSON, &result)
if err != nil {
return nil, err
}
return result, nil
}
type AES256CFBMethod struct {
block cipher.Block
iv []byte
}
func (this *AES256CFBMethod) Init(key, iv []byte) error {
// 判断key是否为32长度
l := len(key)
if l > 32 {
key = key[:32]
} else if l < 32 {
key = append(key, bytes.Repeat([]byte{' '}, 32-l)...)
}
block, err := aes.NewCipher(key)
if err != nil {
return err
}
this.block = block
// 判断iv长度
l2 := len(iv)
if l2 > aes.BlockSize {
iv = iv[:aes.BlockSize]
} else if l2 < aes.BlockSize {
iv = append(iv, bytes.Repeat([]byte{' '}, aes.BlockSize-l2)...)
}
this.iv = iv
return nil
}
func (this *AES256CFBMethod) Encrypt(src []byte) (dst []byte, err error) {
if len(src) == 0 {
return
}
defer func() {
r := recover()
if r != nil {
err = errors.New("encrypt failed")
}
}()
dst = make([]byte, len(src))
encrypter := cipher.NewCFBEncrypter(this.block, this.iv)
encrypter.XORKeyStream(dst, src)
return
}
func (this *AES256CFBMethod) Decrypt(dst []byte) (src []byte, err error) {
if len(dst) == 0 {
return
}
defer func() {
r := recover()
if r != nil {
err = errors.New("decrypt failed")
}
}()
src = make([]byte, len(dst))
decrypter := cipher.NewCFBDecrypter(this.block, this.iv)
decrypter.XORKeyStream(src, dst)
return
}

View File

@@ -0,0 +1,52 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package utils
import (
"github.com/iwind/TeaGo/maps"
"sync"
"testing"
)
func TestSimpleEncrypt(t *testing.T) {
var arr = []string{"Hello", "World", "People"}
for _, s := range arr {
var value = []byte(s)
encoded := SimpleEncrypt(value)
t.Log(encoded, string(encoded))
decoded := SimpleDecrypt(encoded)
t.Log(decoded, string(decoded))
}
}
func TestSimpleEncrypt_Concurrent(t *testing.T) {
wg := sync.WaitGroup{}
var arr = []string{"Hello", "World", "People"}
wg.Add(len(arr))
for _, s := range arr {
go func(s string) {
defer wg.Done()
t.Log(string(SimpleDecrypt(SimpleEncrypt([]byte(s)))))
}(s)
}
wg.Wait()
}
func TestSimpleEncryptMap(t *testing.T) {
var m = maps.Map{
"s": "Hello",
"i": 20,
"b": true,
}
encodedResult, err := SimpleEncryptMap(m)
if err != nil {
t.Fatal(err)
}
t.Log("result:", encodedResult)
decodedResult, err := SimpleDecryptMap(encodedResult)
if err != nil {
t.Fatal(err)
}
t.Log(decodedResult)
}

View File

@@ -12,6 +12,7 @@ type List struct {
itemsMap map[int64]int64 // itemId => timestamp
locker sync.Mutex
ticker *time.Ticker
}
func NewList() *List {
@@ -21,10 +22,7 @@ func NewList() *List {
}
}
func (this *List) Add(itemId int64, expiredAt int64) {
if expiredAt <= time.Now().Unix() {
return
}
func (this *List) Add(itemId int64, expiresAt int64) {
this.locker.Lock()
defer this.locker.Unlock()
@@ -34,17 +32,17 @@ func (this *List) Add(itemId int64, expiredAt int64) {
this.removeItem(itemId)
}
expireItemMap, ok := this.expireMap[expiredAt]
expireItemMap, ok := this.expireMap[expiresAt]
if ok {
expireItemMap[itemId] = true
} else {
expireItemMap = ItemMap{
itemId: true,
}
this.expireMap[expiredAt] = expireItemMap
this.expireMap[expiresAt] = expireItemMap
}
this.itemsMap[itemId] = expiredAt
this.itemsMap[itemId] = expiresAt
}
func (this *List) Remove(itemId int64) {
@@ -64,21 +62,22 @@ func (this *List) GC(timestamp int64, callback func(itemId int64)) {
}
func (this *List) StartGC(callback func(itemId int64)) {
ticker := time.NewTicker(1 * time.Second)
this.ticker = time.NewTicker(1 * time.Second)
lastTimestamp := int64(0)
for range ticker.C {
for range this.ticker.C {
timestamp := time.Now().Unix()
if lastTimestamp == 0 {
lastTimestamp = timestamp - 3600
}
// 防止死循环
if lastTimestamp > timestamp {
continue
}
for i := lastTimestamp; i <= timestamp; i++ {
this.GC(timestamp, callback)
if timestamp >= lastTimestamp {
for i := lastTimestamp; i <= timestamp; i++ {
this.GC(i, callback)
}
} else {
for i := timestamp; i <= lastTimestamp; i++ {
this.GC(i, callback)
}
}
// 这样做是为了防止系统时钟突变

View File

@@ -58,6 +58,10 @@ func TestList_Start_GC(t *testing.T) {
list.Add(2, time.Now().Unix()+1)
list.Add(3, time.Now().Unix()+2)
list.Add(4, time.Now().Unix()+5)
list.Add(5, time.Now().Unix()+5)
list.Add(6, time.Now().Unix()+6)
list.Add(7, time.Now().Unix()+6)
list.Add(8, time.Now().Unix()+6)
go func() {
list.StartGC(func(itemId int64) {
@@ -66,7 +70,7 @@ func TestList_Start_GC(t *testing.T) {
})
}()
time.Sleep(10 * time.Second)
time.Sleep(20 * time.Second)
}
func TestList_ManyItems(t *testing.T) {

View File

@@ -8,7 +8,7 @@ import (
"strings"
)
// 将IP转换为整型
// IP2Long 将IP转换为整型
// 注意IPv6没有顺序
func IP2Long(ip string) uint64 {
if len(ip) == 0 {

View File

@@ -0,0 +1,35 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package jsonutils
import (
"encoding/json"
"github.com/iwind/TeaGo/maps"
)
func MapToObject(m maps.Map, ptr interface{}) error {
if m == nil {
return nil
}
mJSON, err := json.Marshal(m)
if err != nil {
return err
}
return json.Unmarshal(mJSON, ptr)
}
func ObjectToMap(ptr interface{}) (maps.Map, error) {
if ptr == nil {
return maps.Map{}, nil
}
ptrJSON, err := json.Marshal(ptr)
if err != nil {
return nil, err
}
var result = maps.Map{}
err = json.Unmarshal(ptrJSON, &result)
if err != nil {
return nil, err
}
return result, nil
}

View File

@@ -0,0 +1,46 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package jsonutils
import (
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/maps"
"testing"
)
func TestMapToObject(t *testing.T) {
a := assert.NewAssertion(t)
type typeA struct {
B int `json:"b"`
C bool `json:"c"`
}
{
var obj = &typeA{B: 1, C: true}
m, err := ObjectToMap(obj)
if err != nil {
t.Fatal(err)
}
PrintT(m, t)
a.IsTrue(m.GetInt("b") == 1)
a.IsTrue(m.GetBool("c") == true)
}
{
var obj = &typeA{}
err := MapToObject(maps.Map{
"b": 1024,
"c": true,
}, obj)
if err != nil {
t.Fatal(err)
}
if obj == nil {
t.Fatal("obj should not be nil")
}
a.IsTrue(obj.B == 1024)
a.IsTrue(obj.C == true)
PrintT(obj, t)
}
}

View File

@@ -0,0 +1,17 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package jsonutils
import (
"encoding/json"
"testing"
)
func PrintT(obj interface{}, t *testing.T) {
data, err := json.MarshalIndent(obj, "", " ")
if err != nil {
t.Log(err)
} else {
t.Log(string(data))
}
}

View File

@@ -8,7 +8,23 @@ import (
type AllowAction struct {
}
func (this *AllowAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
func (this *AllowAction) Init(waf *WAF) error {
return nil
}
func (this *AllowAction) Code() string {
return ActionAllow
}
func (this *AllowAction) IsAttack() bool {
return false
}
func (this *AllowAction) WillChange() bool {
return false
}
func (this *AllowAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
// do nothing
return true
}

View File

@@ -0,0 +1,21 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package waf
import "net/http"
type BaseAction struct {
}
// CloseConn 关闭连接
func (this *BaseAction) CloseConn(writer http.ResponseWriter) error {
// 断开连接
hijack, ok := writer.(http.Hijacker)
if ok {
conn, _, err := hijack.Hijack()
if err == nil {
return conn.Close()
}
}
return nil
}

View File

@@ -23,12 +23,48 @@ type BlockAction struct {
StatusCode int `yaml:"statusCode" json:"statusCode"`
Body string `yaml:"body" json:"body"` // supports HTML
URL string `yaml:"url" json:"url"`
Timeout int32 `yaml:"timeout" json:"timeout"`
}
func (this *BlockAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
func (this *BlockAction) Init(waf *WAF) error {
if waf.DefaultBlockAction != nil {
if this.StatusCode <= 0 {
this.StatusCode = waf.DefaultBlockAction.StatusCode
}
if len(this.Body) == 0 {
this.Body = waf.DefaultBlockAction.Body
}
if len(this.URL) == 0 {
this.URL = waf.DefaultBlockAction.URL
}
if this.Timeout <= 0 {
this.Timeout = waf.DefaultBlockAction.Timeout
}
}
return nil
}
func (this *BlockAction) Code() string {
return ActionBlock
}
func (this *BlockAction) IsAttack() bool {
return true
}
func (this *BlockAction) WillChange() bool {
return true
}
func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
if this.Timeout > 0 {
// 加入到黑名单
SharedIPBlackLIst.Add(IPTypeAll, request.WAFRemoteIP(), time.Now().Unix()+int64(this.Timeout))
}
if writer != nil {
// if status code eq 444, we close the connection
if this.StatusCode == 444 {
// close the connection
defer func() {
hijack, ok := writer.(http.Hijacker)
if ok {
conn, _, _ := hijack.Hijack()
@@ -37,7 +73,7 @@ func (this *BlockAction) Perform(waf *WAF, request *requests.Request, writer htt
return
}
}
}
}()
// output response
if this.StatusCode > 0 {

View File

@@ -1,11 +1,14 @@
package waf
import (
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/types"
"github.com/iwind/TeaGo/maps"
stringutil "github.com/iwind/TeaGo/utils/string"
"net/http"
"net/url"
"strings"
"time"
)
@@ -13,27 +16,63 @@ var captchaSalt = stringutil.Rand(32)
const (
CaptchaSeconds = 600 // 10 minutes
CaptchaPath = "/WAF/VERIFY/CAPTCHA"
)
type CaptchaAction struct {
Life int32 `yaml:"life" json:"life"`
Language string `yaml:"language" json:"language"` // 语言zh-CN, en-US ...
AddToWhiteList bool `yaml:"addToWhiteList" json:"addToWhiteList"` // 是否加入到白名单
}
func (this *CaptchaAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
// TEAWEB_CAPTCHA:
cookie, err := request.Cookie("TEAWEB_WAF_CAPTCHA")
if err == nil && cookie != nil && len(cookie.Value) > 32 {
m := cookie.Value[:32]
timestamp := cookie.Value[32:]
if stringutil.Md5(captchaSalt+timestamp) == m && time.Now().Unix() < types.Int64(timestamp) { // verify md5
return true
func (this *CaptchaAction) Init(waf *WAF) error {
return nil
}
func (this *CaptchaAction) Code() string {
return ActionCaptcha
}
func (this *CaptchaAction) IsAttack() bool {
return false
}
func (this *CaptchaAction) WillChange() bool {
return true
}
func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
// 是否在白名单中
if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) {
return true
}
refURL := request.WAFRaw().URL.String()
// 覆盖配置
if strings.HasPrefix(refURL, CaptchaPath) {
info := request.WAFRaw().URL.Query().Get("info")
if len(info) > 0 {
m, err := utils.SimpleDecryptMap(info)
if err == nil && m != nil {
refURL = m.GetString("url")
}
}
}
refURL := request.URL.String()
if len(request.Referer()) > 0 {
refURL = request.Referer()
var captchaConfig = maps.Map{
"action": this,
"timestamp": time.Now().Unix(),
"url": refURL,
"setId": set.Id,
}
http.Redirect(writer, request.Raw(), "/WAFCAPTCHA?url="+url.QueryEscape(refURL), http.StatusTemporaryRedirect)
info, err := utils.SimpleEncryptMap(captchaConfig)
if err != nil {
remotelogs.Error("WAF_CAPTCHA_ACTION", "encode captcha config failed: "+err.Error())
return true
}
http.Redirect(writer, request.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect)
return false
}

View File

@@ -0,0 +1,13 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package waf
import "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
type ActionCategory = string
const (
ActionCategoryAllow ActionCategory = firewallconfigs.HTTPFirewallActionCategoryAllow
ActionCategoryBlock ActionCategory = firewallconfigs.HTTPFirewallActionCategoryBlock
ActionCategoryVerify ActionCategory = firewallconfigs.HTTPFirewallActionCategoryVerify
)

View File

@@ -0,0 +1,10 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package waf
import "github.com/iwind/TeaGo/maps"
type ActionConfig struct {
Code string `yaml:"code" json:"code"`
Options maps.Map `yaml:"options" json:"options"`
}

View File

@@ -2,11 +2,12 @@ package waf
import "reflect"
// action definition
// ActionDefinition action definition
type ActionDefinition struct {
Name string
Code ActionString
Description string
Category string // category: block, verify, allow
Instance ActionInterface
Type reflect.Type
}

View File

@@ -0,0 +1,71 @@
package waf
import (
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/maps"
"net/http"
"net/url"
"time"
)
const (
Get302Path = "/WAF/VERIFY/GET"
)
// Get302Action
// 原理: origin url --> 302 verify url --> origin url
// TODO 将来支持meta refresh验证
type Get302Action struct {
BaseAction
Life int32 `yaml:"life" json:"life"`
}
func (this *Get302Action) Init(waf *WAF) error {
return nil
}
func (this *Get302Action) Code() string {
return ActionGet302
}
func (this *Get302Action) IsAttack() bool {
return false
}
func (this *Get302Action) WillChange() bool {
return true
}
func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
// 仅限于Get
if request.WAFRaw().Method != http.MethodGet {
return true
}
// 是否已经在白名单中
if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) {
return true
}
var m = maps.Map{
"url": request.WAFRaw().URL.String(),
"timestamp": time.Now().Unix(),
"life": this.Life,
"setId": set.Id,
}
info, err := utils.SimpleEncryptMap(m)
if err != nil {
remotelogs.Error("WAF_GET_302_ACTION", "encode info failed: "+err.Error())
return true
}
http.Redirect(writer, request.WAFRaw(), Get302Path+"?info="+url.QueryEscape(info), http.StatusFound)
// 关闭连接
_ = this.CloseConn(writer)
return true
}

View File

@@ -10,13 +10,29 @@ type GoGroupAction struct {
GroupId string `yaml:"groupId" json:"groupId"`
}
func (this *GoGroupAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
group := waf.FindRuleGroup(this.GroupId)
if group == nil || !group.IsOn {
func (this *GoGroupAction) Init(waf *WAF) error {
return nil
}
func (this *GoGroupAction) Code() string {
return ActionGoGroup
}
func (this *GoGroupAction) IsAttack() bool {
return false
}
func (this *GoGroupAction) WillChange() bool {
return true
}
func (this *GoGroupAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
nextGroup := waf.FindRuleGroup(this.GroupId)
if nextGroup == nil || !nextGroup.IsOn {
return true
}
b, set, err := group.MatchRequest(request)
b, nextSet, err := nextGroup.MatchRequest(request)
if err != nil {
logs.Error(err)
return true
@@ -26,9 +42,5 @@ func (this *GoGroupAction) Perform(waf *WAF, request *requests.Request, writer h
return true
}
actionObject := FindActionInstance(set.Action, set.ActionOptions)
if actionObject == nil {
return true
}
return actionObject.Perform(waf, request, writer)
return nextSet.PerformActions(waf, nextGroup, request, writer)
}

View File

@@ -11,17 +11,33 @@ type GoSetAction struct {
SetId string `yaml:"setId" json:"setId"`
}
func (this *GoSetAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
group := waf.FindRuleGroup(this.GroupId)
if group == nil || !group.IsOn {
func (this *GoSetAction) Init(waf *WAF) error {
return nil
}
func (this *GoSetAction) Code() string {
return ActionGoSet
}
func (this *GoSetAction) IsAttack() bool {
return false
}
func (this *GoSetAction) WillChange() bool {
return true
}
func (this *GoSetAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
nextGroup := waf.FindRuleGroup(this.GroupId)
if nextGroup == nil || !nextGroup.IsOn {
return true
}
set := group.FindRuleSet(this.SetId)
if set == nil || !set.IsOn {
nextSet := nextGroup.FindRuleSet(this.SetId)
if nextSet == nil || !nextSet.IsOn {
return true
}
b, err := set.MatchRequest(request)
b, err := nextSet.MatchRequest(request)
if err != nil {
logs.Error(err)
return true
@@ -29,9 +45,5 @@ func (this *GoSetAction) Perform(waf *WAF, request *requests.Request, writer htt
if !b {
return true
}
actionObject := FindActionInstance(set.Action, set.ActionOptions)
if actionObject == nil {
return true
}
return actionObject.Perform(waf, request, writer)
return nextSet.PerformActions(waf, nextGroup, request, writer)
}

View File

@@ -0,0 +1,25 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package waf
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/http"
)
type ActionInterface interface {
// Init 初始化
Init(waf *WAF) error
// Code 代号
Code() string
// IsAttack 是否为拦截攻击动作
IsAttack() bool
// WillChange determine if the action will change the request
WillChange() bool
// Perform perform the action
Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool)
}

View File

@@ -8,6 +8,22 @@ import (
type LogAction struct {
}
func (this *LogAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
func (this *LogAction) Init(waf *WAF) error {
return nil
}
func (this *LogAction) Code() string {
return ActionLog
}
func (this *LogAction) IsAttack() bool {
return false
}
func (this *LogAction) WillChange() bool {
return false
}
func (this *LogAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
return true
}

View File

@@ -0,0 +1,86 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package waf
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/types"
"net/http"
"time"
)
type notifyTask struct {
ServerId int64
HttpFirewallPolicyId int64
HttpFirewallRuleGroupId int64
HttpFirewallRuleSetId int64
CreatedAt int64
}
var notifyChan = make(chan *notifyTask, 128)
func init() {
events.On(events.EventLoaded, func() {
go func() {
rpcClient, err := rpc.SharedRPC()
if err != nil {
remotelogs.Error("WAF_NOTIFY_ACTION", "create rpc client failed: "+err.Error())
return
}
for task := range notifyChan {
_, err = rpcClient.FirewallService().NotifyHTTPFirewallEvent(rpcClient.Context(), &pb.NotifyHTTPFirewallEventRequest{
ServerId: task.ServerId,
HttpFirewallPolicyId: task.HttpFirewallPolicyId,
HttpFirewallRuleGroupId: task.HttpFirewallRuleGroupId,
HttpFirewallRuleSetId: task.HttpFirewallRuleSetId,
CreatedAt: task.CreatedAt,
})
if err != nil {
remotelogs.Error("WAF_NOTIFY_ACTION", "notify failed: "+err.Error())
}
}
}()
})
}
type NotifyAction struct {
}
func (this *NotifyAction) Init(waf *WAF) error {
return nil
}
func (this *NotifyAction) Code() string {
return ActionNotify
}
func (this *NotifyAction) IsAttack() bool {
return false
}
// WillChange determine if the action will change the request
func (this *NotifyAction) WillChange() bool {
return false
}
// Perform perform the action
func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
select {
case notifyChan <- &notifyTask{
ServerId: request.WAFServerId(),
HttpFirewallPolicyId: types.Int64(waf.Id),
HttpFirewallRuleGroupId: types.Int64(group.Id),
HttpFirewallRuleSetId: types.Int64(set.Id),
CreatedAt: time.Now().Unix(),
}:
default:
}
return true
}

View File

@@ -0,0 +1,88 @@
package waf
import (
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/maps"
"net/http"
"time"
)
type Post307Action struct {
Life int32 `yaml:"life" json:"life"`
BaseAction
}
func (this *Post307Action) Init(waf *WAF) error {
return nil
}
func (this *Post307Action) Code() string {
return ActionPost307
}
func (this *Post307Action) IsAttack() bool {
return false
}
func (this *Post307Action) WillChange() bool {
return true
}
func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
var cookieName = "WAF_VALIDATOR_ID"
// 仅限于POST
if request.WAFRaw().Method != http.MethodPost {
return true
}
// 是否已经在白名单中
if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) {
return true
}
// 判断是否有Cookie
cookie, err := request.WAFRaw().Cookie(cookieName)
if err == nil && cookie != nil {
m, err := utils.SimpleDecryptMap(cookie.Value)
if err == nil && m.GetString("remoteIP") == request.WAFRemoteIP() && time.Now().Unix() < m.GetInt64("timestamp")+10 {
var life = m.GetInt64("life")
if life <= 0 {
life = 600 // 默认10分钟
}
var setId = m.GetString("setId")
SharedIPWhiteList.Add("set:"+setId, request.WAFRemoteIP(), time.Now().Unix()+life)
return true
}
}
var m = maps.Map{
"timestamp": time.Now().Unix(),
"life": this.Life,
"setId": set.Id,
"remoteIP": request.WAFRemoteIP(),
}
info, err := utils.SimpleEncryptMap(m)
if err != nil {
remotelogs.Error("WAF_POST_302_ACTION", "encode info failed: "+err.Error())
return true
}
// 设置Cookie
http.SetCookie(writer, &http.Cookie{
Name: cookieName,
Path: "/",
MaxAge: 10,
Value: info,
})
http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusTemporaryRedirect)
// 关闭连接
_ = this.CloseConn(writer)
return true
}

View File

@@ -0,0 +1,120 @@
package waf
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/http"
"strings"
"time"
)
type recordIPTask struct {
ip string
listId int64
expiredAt int64
level string
}
var recordIPTaskChan = make(chan *recordIPTask, 1024)
func init() {
events.On(events.EventLoaded, func() {
go func() {
rpcClient, err := rpc.SharedRPC()
if err != nil {
remotelogs.Error("WAF_RECORD_IP_ACTION", "create rpc client failed: "+err.Error())
return
}
for task := range recordIPTaskChan {
ipType := "ipv4"
if strings.Contains(task.ip, ":") {
ipType = "ipv6"
}
_, err = rpcClient.IPItemRPC().CreateIPItem(rpcClient.Context(), &pb.CreateIPItemRequest{
IpListId: task.listId,
IpFrom: task.ip,
IpTo: "",
ExpiredAt: task.expiredAt,
Reason: "触发WAF规则自动加入",
Type: ipType,
EventLevel: task.level,
})
if err != nil {
remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error())
}
}
}()
})
}
type RecordIPAction struct {
BaseAction
Type string `yaml:"type" json:"type"`
IPListId int64 `yaml:"ipListId" json:"ipListId"`
Level string `yaml:"level" json:"level"`
Timeout int32 `yaml:"timeout" json:"timeout"`
}
func (this *RecordIPAction) Init(waf *WAF) error {
return nil
}
func (this *RecordIPAction) Code() string {
return ActionRecordIP
}
func (this *RecordIPAction) IsAttack() bool {
return this.Type == "black"
}
func (this *RecordIPAction) WillChange() bool {
return this.Type == "black"
}
func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
// 是否在本地白名单中
if SharedIPWhiteList.Contains("set:"+set.Id, set.Id) {
return true
}
// 先加入本地的黑名单
timeout := this.Timeout
if timeout <= 0 {
timeout = 86400 // 1天
}
expiredAt := time.Now().Unix() + int64(timeout)
if this.Type == "black" {
_ = this.CloseConn(writer)
SharedIPBlackLIst.Add(IPTypeAll, request.WAFRemoteIP(), expiredAt)
} else {
// 加入本地白名单
timeout := this.Timeout
if timeout <= 0 {
timeout = 86400 // 1天
}
SharedIPWhiteList.Add("set:"+set.Id, request.WAFRemoteIP(), expiredAt)
}
// 上报
if this.IPListId > 0 {
select {
case recordIPTaskChan <- &recordIPTask{
ip: request.WAFRemoteIP(),
listId: this.IPListId,
expiredAt: expiredAt,
level: this.Level,
}:
default:
}
}
return this.Type != "black"
}

View File

@@ -0,0 +1,30 @@
package waf
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/http"
)
type TagAction struct {
Tags []string `yaml:"tags" json:"tags"`
}
func (this *TagAction) Init(waf *WAF) error {
return nil
}
func (this *TagAction) Code() string {
return ActionTag
}
func (this *TagAction) IsAttack() bool {
return false
}
func (this *TagAction) WillChange() bool {
return false
}
func (this *TagAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
return true
}

View File

@@ -1,21 +0,0 @@
package waf
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/http"
)
type ActionString = string
const (
ActionLog = "log" // allow and log
ActionBlock = "block" // block
ActionCaptcha = "captcha" // block and show captcha
ActionAllow = "allow" // allow
ActionGoGroup = "go_group" // go to next rule group
ActionGoSet = "go_set" // go to next rule set
)
type ActionInterface interface {
Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool)
}

View File

@@ -0,0 +1,88 @@
package waf
import "reflect"
type ActionString = string
const (
ActionLog ActionString = "log" // allow and log
ActionBlock ActionString = "block" // block
ActionCaptcha ActionString = "captcha" // block and show captcha
ActionNotify ActionString = "notify" // 告警
ActionGet302 ActionString = "get_302" // 针对GET的302重定向认证
ActionPost307 ActionString = "post_307" // 针对POST的307重定向认证
ActionRecordIP ActionString = "record_ip" // 记录IP
ActionTag ActionString = "tag" // 标签
ActionAllow ActionString = "allow" // allow
ActionGoGroup ActionString = "go_group" // go to next rule group
ActionGoSet ActionString = "go_set" // go to next rule set
)
var AllActions = []*ActionDefinition{
{
Name: "阻止",
Code: ActionBlock,
Instance: new(BlockAction),
Type: reflect.TypeOf(new(BlockAction)).Elem(),
},
{
Name: "允许通过",
Code: ActionAllow,
Instance: new(AllowAction),
Type: reflect.TypeOf(new(AllowAction)).Elem(),
},
{
Name: "允许并记录日志",
Code: ActionLog,
Instance: new(LogAction),
Type: reflect.TypeOf(new(LogAction)).Elem(),
},
{
Name: "Captcha验证码",
Code: ActionCaptcha,
Instance: new(CaptchaAction),
Type: reflect.TypeOf(new(CaptchaAction)).Elem(),
},
{
Name: "告警",
Code: ActionNotify,
Instance: new(NotifyAction),
Type: reflect.TypeOf(new(NotifyAction)).Elem(),
},
{
Name: "GET 302",
Code: ActionGet302,
Instance: new(Get302Action),
Type: reflect.TypeOf(new(Get302Action)).Elem(),
},
{
Name: "POST 307",
Code: ActionPost307,
Instance: new(Post307Action),
Type: reflect.TypeOf(new(Post307Action)).Elem(),
},
{
Name: "记录IP",
Code: ActionRecordIP,
Instance: new(RecordIPAction),
Type: reflect.TypeOf(new(RecordIPAction)).Elem(),
},
{
Name: "标签",
Code: ActionTag,
Instance: new(TagAction),
Type: reflect.TypeOf(new(TagAction)).Elem(),
},
{
Name: "跳到下一个规则分组",
Code: ActionGoGroup,
Instance: new(GoGroupAction),
Type: reflect.TypeOf(new(GoGroupAction)).Elem(),
},
{
Name: "跳到下一个规则集",
Code: ActionGoSet,
Instance: new(GoSetAction),
Type: reflect.TypeOf(new(GoSetAction)).Elem(),
},
}

View File

@@ -1,45 +1,12 @@
package waf
import (
"encoding/json"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/maps"
"reflect"
)
var AllActions = []*ActionDefinition{
{
Name: "阻止",
Code: ActionBlock,
Instance: new(BlockAction),
},
{
Name: "允许通过",
Code: ActionAllow,
Instance: new(AllowAction),
},
{
Name: "允许并记录日志",
Code: ActionLog,
Instance: new(LogAction),
},
{
Name: "Captcha验证码",
Code: ActionCaptcha,
Instance: new(CaptchaAction),
},
{
Name: "跳到下一个规则分组",
Code: ActionGoGroup,
Instance: new(GoGroupAction),
Type: reflect.TypeOf(new(GoGroupAction)).Elem(),
},
{
Name: "跳到下一个规则集",
Code: ActionGoSet,
Instance: new(GoSetAction),
Type: reflect.TypeOf(new(GoSetAction)).Elem(),
},
}
func FindActionInstance(action ActionString, options maps.Map) ActionInterface {
for _, def := range AllActions {
if def.Code == action {
@@ -49,15 +16,13 @@ func FindActionInstance(action ActionString, options maps.Map) ActionInterface {
instance := ptrValue.Interface().(ActionInterface)
if len(options) > 0 {
count := def.Type.NumField()
for i := 0; i < count; i++ {
field := def.Type.Field(i)
tag, ok := field.Tag.Lookup("yaml")
if ok {
v, ok := options[tag]
if ok && reflect.TypeOf(v) == field.Type {
ptrValue.Elem().FieldByName(field.Name).Set(reflect.ValueOf(v))
}
optionsJSON, err := json.Marshal(options)
if err != nil {
remotelogs.Error("WAF_FindActionInstance", "encode options to json failed: "+err.Error())
} else {
err = json.Unmarshal(optionsJSON, instance)
if err != nil {
remotelogs.Error("WAF_FindActionInstance", "decode options from json failed: "+err.Error())
}
}
}

View File

@@ -2,6 +2,7 @@ package waf
import (
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/maps"
"runtime"
"testing"
@@ -16,11 +17,20 @@ func TestFindActionInstance(t *testing.T) {
t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil))
t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b",}))
t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b"}))
a.IsTrue(FindActionInstance(ActionGoSet, nil) != FindActionInstance(ActionGoSet, nil))
}
func TestFindActionInstance_Options(t *testing.T) {
//t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{}))
//t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{}))
//logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{}), t)
logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{
"timeout": 3600,
}), t)
}
func BenchmarkFindActionInstance(b *testing.B) {
runtime.GOMAXPROCS(1)
for i := 0; i < b.N; i++ {

View File

@@ -3,29 +3,64 @@ package waf
import (
"bytes"
"encoding/base64"
"fmt"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/jsonutils"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/dchest/captcha"
"github.com/iwind/TeaGo/logs"
stringutil "github.com/iwind/TeaGo/utils/string"
"github.com/iwind/TeaGo/types"
"net/http"
"strconv"
"strings"
"time"
)
var captchaValidator = &CaptchaValidator{}
var captchaValidator = NewCaptchaValidator()
type CaptchaValidator struct {
}
func (this *CaptchaValidator) Run(request *requests.Request, writer http.ResponseWriter) {
if request.Method == http.MethodPost && len(request.FormValue("TEAWEB_WAF_CAPTCHA_ID")) > 0 {
this.validate(request, writer)
func NewCaptchaValidator() *CaptchaValidator {
return &CaptchaValidator{}
}
func (this *CaptchaValidator) Run(request requests.Request, writer http.ResponseWriter) {
var info = request.WAFRaw().URL.Query().Get("info")
if len(info) == 0 {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid request"))
return
}
m, err := utils.SimpleDecryptMap(info)
if err != nil {
_, _ = writer.Write([]byte("invalid request"))
return
}
timestamp := m.GetInt64("timestamp")
if timestamp < time.Now().Unix()-600 { // 10分钟之后信息过期
http.Redirect(writer, request.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect)
return
}
var actionConfig = &CaptchaAction{}
err = jsonutils.MapToObject(m.GetMap("action"), actionConfig)
if err != nil {
http.Redirect(writer, request.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect)
return
}
var setId = m.GetInt64("setId")
var originURL = m.GetString("url")
if request.WAFRaw().Method == http.MethodPost && len(request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")) > 0 {
this.validate(actionConfig, setId, originURL, request, writer)
} else {
this.show(request, writer)
this.show(actionConfig, request, writer)
}
}
func (this *CaptchaValidator) show(request *requests.Request, writer http.ResponseWriter) {
func (this *CaptchaValidator) show(actionConfig *CaptchaAction, request requests.Request, writer http.ResponseWriter) {
// show captcha
captchaId := captcha.NewLen(6)
buf := bytes.NewBuffer([]byte{})
@@ -35,48 +70,86 @@ func (this *CaptchaValidator) show(request *requests.Request, writer http.Respon
return
}
var lang = actionConfig.Language
if len(lang) == 0 {
acceptLanguage := request.WAFRaw().Header.Get("Accept-Language")
if len(acceptLanguage) > 0 {
langIndex := strings.Index(acceptLanguage, ",")
if langIndex > 0 {
lang = acceptLanguage[:langIndex]
}
}
}
if len(lang) == 0 {
lang = "en-US"
}
var msgTitle = ""
var msgPrompt = ""
var msgButtonTitle = ""
switch lang {
case "en-US":
msgTitle = "Verify Yourself"
msgPrompt = "Input verify code above:"
msgButtonTitle = "Verify Yourself"
case "zh-CN":
msgTitle = "身份验证"
msgPrompt = "请输入上面的验证码"
msgButtonTitle = "提交验证"
default:
msgTitle = "Verify Yourself"
msgPrompt = "Input verify code above:"
msgButtonTitle = "Verify Yourself"
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
_, _ = writer.Write([]byte(`<!DOCTYPE html>
<html>
<head>
<title>Verify Yourself</title>
<title>` + msgTitle + `</title>
<script type="text/javascript">
if (window.addEventListener != null) {
window.addEventListener("load", function () {
document.getElementById("GOEDGE_WAF_CAPTCHA_CODE").focus()
})
}
</script>
</head>
<body>
<form method="POST">
<input type="hidden" name="TEAWEB_WAF_CAPTCHA_ID" value="` + captchaId + `"/>
<input type="hidden" name="GOEDGE_WAF_CAPTCHA_ID" value="` + captchaId + `"/>
<img src="data:image/png;base64, ` + base64.StdEncoding.EncodeToString(buf.Bytes()) + `"/>` + `
<div>
<p>Input verify code above:</p>
<input type="text" name="TEAWEB_WAF_CAPTCHA_CODE" maxlength="6" size="18" autocomplete="off" z-index="1" style="font-size:16px;line-height:24px; letter-spacing: 15px; padding-left: 4px"/>
<p>` + msgPrompt + `</p>
<input type="text" name="GOEDGE_WAF_CAPTCHA_CODE" id="GOEDGE_WAF_CAPTCHA_CODE" maxlength="6" autocomplete="off" z-index="1" style="font-size:16px;line-height:24px; letter-spacing: 15px; padding-left: 4px; width: 160px"/>
</div>
<div>
<button type="submit" onclick="window.location = '/webhook'" style="line-height:24px;margin-top:10px">Verify Yourself</button>
<button type="submit" style="line-height:24px;margin-top:10px">` + msgButtonTitle + `</button>
</div>
</form>
</body>
</html>`))
}
func (this *CaptchaValidator) validate(request *requests.Request, writer http.ResponseWriter) (allow bool) {
captchaId := request.FormValue("TEAWEB_WAF_CAPTCHA_ID")
func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, setId int64, originURL string, request requests.Request, writer http.ResponseWriter) (allow bool) {
captchaId := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")
if len(captchaId) > 0 {
captchaCode := request.FormValue("TEAWEB_WAF_CAPTCHA_CODE")
captchaCode := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_CODE")
if captcha.VerifyString(captchaId, captchaCode) {
// set cookie
timestamp := fmt.Sprintf("%d", time.Now().Unix()+CaptchaSeconds)
m := stringutil.Md5(captchaSalt + timestamp)
http.SetCookie(writer, &http.Cookie{
Name: "TEAWEB_WAF_CAPTCHA",
Value: m + timestamp,
MaxAge: CaptchaSeconds, // TODO 这个时间可以设置
Path: "/", // all of dirs
})
var life = CaptchaSeconds
if actionConfig.Life > 0 {
life = types.Int(actionConfig.Life)
}
rawURL := request.URL.Query().Get("url")
http.Redirect(writer, request.Raw(), rawURL, http.StatusSeeOther)
// 加入到白名单
SharedIPWhiteList.Add("set:"+strconv.FormatInt(setId, 10), request.WAFRemoteIP(), time.Now().Unix()+int64(life)) // TODO
http.Redirect(writer, request.WAFRaw(), originURL, http.StatusSeeOther)
return false
} else {
http.Redirect(writer, request.Raw(), request.URL.String(), http.StatusSeeOther)
http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusSeeOther)
}
}

View File

@@ -5,14 +5,12 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"net"
"regexp"
"strings"
"sync"
"time"
)
// ${cc.arg}
// CCCheckpoint ${cc.arg}
// TODO implement more traffic rules
type CCCheckpoint struct {
Checkpoint
@@ -32,7 +30,7 @@ func (this *CCCheckpoint) Start() {
this.cache = ttlcache.NewCache()
}
func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *CCCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = 0
if this.cache == nil {
@@ -66,12 +64,12 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
var key = ""
switch userType {
case "ip":
key = this.ip(req)
key = req.WAFRemoteIP()
case "cookie":
if len(userField) == 0 {
key = this.ip(req)
key = req.WAFRemoteIP()
} else {
cookie, _ := req.Cookie(userField)
cookie, _ := req.WAFRaw().Cookie(userField)
if cookie != nil {
v := cookie.Value
if userIndex > 0 && len(v) > userIndex {
@@ -82,9 +80,9 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
}
case "get":
if len(userField) == 0 {
key = this.ip(req)
key = req.WAFRemoteIP()
} else {
v := req.URL.Query().Get(userField)
v := req.WAFRaw().URL.Query().Get(userField)
if userIndex > 0 && len(v) > userIndex {
v = v[userIndex:]
}
@@ -92,9 +90,9 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
}
case "post":
if len(userField) == 0 {
key = this.ip(req)
key = req.WAFRemoteIP()
} else {
v := req.PostFormValue(userField)
v := req.WAFRaw().PostFormValue(userField)
if userIndex > 0 && len(v) > userIndex {
v = v[userIndex:]
}
@@ -102,19 +100,19 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
}
case "header":
if len(userField) == 0 {
key = this.ip(req)
key = req.WAFRemoteIP()
} else {
v := req.Header.Get(userField)
v := req.WAFRaw().Header.Get(userField)
if userIndex > 0 && len(v) > userIndex {
v = v[userIndex:]
}
key = "USER@" + userType + "@" + userField + "@" + v
}
default:
key = this.ip(req)
key = req.WAFRemoteIP()
}
if len(key) == 0 {
key = this.ip(req)
key = req.WAFRemoteIP()
}
value = this.cache.IncreaseInt64(key, int64(1), time.Now().Unix()+period)
}
@@ -122,7 +120,7 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
return
}
func (this *CCCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *CCCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
@@ -210,38 +208,3 @@ func (this *CCCheckpoint) Stop() {
this.cache = nil
}
}
func (this *CCCheckpoint) ip(req *requests.Request) string {
// X-Forwarded-For
forwardedFor := req.Header.Get("X-Forwarded-For")
if len(forwardedFor) > 0 {
commaIndex := strings.Index(forwardedFor, ",")
if commaIndex > 0 {
return forwardedFor[:commaIndex]
}
return forwardedFor
}
// Real-IP
{
realIP, ok := req.Header["X-Real-IP"]
if ok && len(realIP) > 0 {
return realIP[0]
}
}
// Real-Ip
{
realIP, ok := req.Header["X-Real-Ip"]
if ok && len(realIP) > 0 {
return realIP[0]
}
}
// Remote-Addr
host, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
return host
}
return req.RemoteAddr
}

View File

@@ -0,0 +1,48 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"strings"
"time"
)
var ccCache = ttlcache.NewCache(ttlcache.NewPiecesOption(32))
// CC2Checkpoint 新的CC
type CC2Checkpoint struct {
Checkpoint
}
func (this *CC2Checkpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
var keys = options.GetSlice("keys")
var keyValues = []string{}
for _, key := range keys {
keyValues = append(keyValues, req.Format(types.String(key)))
}
if len(keyValues) == 0 {
return
}
var period = options.GetInt64("period")
if period <= 0 {
period = 60
}
var threshold = options.GetInt64("threshold")
if threshold <= 0 {
threshold = 1000
}
value = ccCache.IncreaseInt64("WAF-CC-"+strings.Join(keyValues, "@"), 1, time.Now().Unix()+period)
return
}
func (this *CC2Checkpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
return
}

View File

@@ -2,6 +2,7 @@ package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/maps"
"net/http"
"testing"
)
@@ -12,31 +13,31 @@ func TestCCCheckpoint_RequestValue(t *testing.T) {
t.Fatal(err)
}
req := requests.NewRequest(raw)
req.RemoteAddr = "127.0.0.1"
req := requests.NewTestRequest(raw)
req.WAFRaw().RemoteAddr = "127.0.0.1"
checkpoint := new(CCCheckpoint)
checkpoint.Init()
checkpoint.Start()
options := map[string]string{
options := maps.Map{
"period": "5",
}
t.Log(checkpoint.RequestValue(req, "requests", options))
t.Log(checkpoint.RequestValue(req, "requests", options))
req.RemoteAddr = "127.0.0.2"
req.WAFRaw().RemoteAddr = "127.0.0.2"
t.Log(checkpoint.RequestValue(req, "requests", options))
req.RemoteAddr = "127.0.0.1"
req.WAFRaw().RemoteAddr = "127.0.0.1"
t.Log(checkpoint.RequestValue(req, "requests", options))
req.RemoteAddr = "127.0.0.2"
req.WAFRaw().RemoteAddr = "127.0.0.2"
t.Log(checkpoint.RequestValue(req, "requests", options))
req.RemoteAddr = "127.0.0.2"
req.WAFRaw().RemoteAddr = "127.0.0.2"
t.Log(checkpoint.RequestValue(req, "requests", options))
req.RemoteAddr = "127.0.0.2"
req.WAFRaw().RemoteAddr = "127.0.0.2"
t.Log(checkpoint.RequestValue(req, "requests", options))
}

View File

@@ -5,32 +5,32 @@ import (
"github.com/iwind/TeaGo/maps"
)
// Check Point
// CheckpointInterface Check Point
type CheckpointInterface interface {
// initialize
// Init initialize
Init()
// is request?
// IsRequest is request?
IsRequest() bool
// is composed?
// IsComposed is composed?
IsComposed() bool
// get request value
RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
// RequestValue get request value
RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
// get response value
ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
// ResponseValue get response value
ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
// param option list
// ParamOptions param option list
ParamOptions() *ParamOptions
// options
// Options options
Options() []OptionInterface
// start
// Start start
Start()
// stop
// Stop stop
Stop()
}

View File

@@ -5,32 +5,34 @@ import (
"github.com/iwind/TeaGo/maps"
)
// ${requestAll}
// RequestAllCheckpoint ${requestAll}
type RequestAllCheckpoint struct {
Checkpoint
}
func (this *RequestAllCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
valueBytes := []byte{}
if len(req.RequestURI) > 0 {
valueBytes = append(valueBytes, req.RequestURI...)
} else if req.URL != nil {
valueBytes = append(valueBytes, req.URL.RequestURI()...)
if len(req.WAFRaw().RequestURI) > 0 {
valueBytes = append(valueBytes, req.WAFRaw().RequestURI...)
} else if req.WAFRaw().URL != nil {
valueBytes = append(valueBytes, req.WAFRaw().URL.RequestURI()...)
}
if req.Body != nil {
if req.WAFRaw().Body != nil {
valueBytes = append(valueBytes, ' ')
if len(req.BodyData) == 0 {
data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
var bodyData = req.WAFGetCacheBody()
if len(bodyData) == 0 {
data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
if err != nil {
return "", err, nil
}
req.BodyData = data
req.RestoreBody(data)
bodyData = data
req.WAFSetCacheBody(data)
req.WAFRestoreBody(data)
}
valueBytes = append(valueBytes, req.BodyData...)
valueBytes = append(valueBytes, bodyData...)
}
value = valueBytes
@@ -38,7 +40,7 @@ func (this *RequestAllCheckpoint) RequestValue(req *requests.Request, param stri
return
}
func (this *RequestAllCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestAllCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = ""
if this.IsRequest() {
return this.RequestValue(req, param, options)

View File

@@ -18,7 +18,7 @@ func TestRequestAllCheckpoint_RequestValue(t *testing.T) {
}
checkpoint := new(RequestAllCheckpoint)
v, sysErr, userErr := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
v, sysErr, userErr := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
if sysErr != nil {
t.Fatal(sysErr)
}
@@ -42,7 +42,7 @@ func TestRequestAllCheckpoint_RequestValue_Max(t *testing.T) {
}
checkpoint := new(RequestBodyCheckpoint)
value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
value, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
if err != nil {
t.Fatal(err)
}
@@ -65,6 +65,6 @@ func BenchmarkRequestAllCheckpoint_RequestValue(b *testing.B) {
checkpoint := new(RequestAllCheckpoint)
for i := 0; i < b.N; i++ {
_, _, _ = checkpoint.RequestValue(requests.NewRequest(req), "", nil)
_, _, _ = checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
}
}

View File

@@ -9,11 +9,11 @@ type RequestArgCheckpoint struct {
Checkpoint
}
func (this *RequestArgCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
return req.URL.Query().Get(param), nil, nil
func (this *RequestArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
return req.WAFRaw().URL.Query().Get(param), nil, nil
}
func (this *RequestArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -12,7 +12,7 @@ func TestArgParam_RequestValue(t *testing.T) {
t.Fatal(err)
}
req := requests.NewRequest(rawReq)
req := requests.NewTestRequest(rawReq)
checkpoint := new(RequestArgCheckpoint)
t.Log(checkpoint.RequestValue(req, "name", nil))

View File

@@ -9,12 +9,12 @@ type RequestArgsCheckpoint struct {
Checkpoint
}
func (this *RequestArgsCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.URL.RawQuery
func (this *RequestArgsCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.WAFRaw().URL.RawQuery
return
}
func (this *RequestArgsCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestArgsCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -5,31 +5,33 @@ import (
"github.com/iwind/TeaGo/maps"
)
// ${requestBody}
// RequestBodyCheckpoint ${requestBody}
type RequestBodyCheckpoint struct {
Checkpoint
}
func (this *RequestBodyCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if req.Body == nil {
func (this *RequestBodyCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if req.WAFRaw().Body == nil {
value = ""
return
}
if len(req.BodyData) == 0 {
data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
var bodyData = req.WAFGetCacheBody()
if len(bodyData) == 0 {
data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
if err != nil {
return "", err, nil
}
req.BodyData = data
req.RestoreBody(data)
bodyData = data
req.WAFSetCacheBody(data)
req.WAFRestoreBody(data)
}
return req.BodyData, nil, nil
return bodyData, nil, nil
}
func (this *RequestBodyCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestBodyCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -11,19 +11,20 @@ import (
)
func TestRequestBodyCheckpoint_RequestValue(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("123456")))
rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("123456")))
if err != nil {
t.Fatal(err)
}
var req = requests.NewTestRequest(rawReq)
checkpoint := new(RequestBodyCheckpoint)
t.Log(checkpoint.RequestValue(requests.NewRequest(req), "", nil))
t.Log(checkpoint.RequestValue(req, "", nil))
body, err := ioutil.ReadAll(req.Body)
body, err := ioutil.ReadAll(rawReq.Body)
if err != nil {
t.Fatal(err)
}
t.Log(string(body))
t.Log(string(req.WAFGetCacheBody()))
}
func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) {
@@ -33,7 +34,7 @@ func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) {
}
checkpoint := new(RequestBodyCheckpoint)
value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
value, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -9,12 +9,12 @@ type RequestContentTypeCheckpoint struct {
Checkpoint
}
func (this *RequestContentTypeCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.Header.Get("Content-Type")
func (this *RequestContentTypeCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.WAFRaw().Header.Get("Content-Type")
return
}
func (this *RequestContentTypeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestContentTypeCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -9,8 +9,8 @@ type RequestCookieCheckpoint struct {
Checkpoint
}
func (this *RequestCookieCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
cookie, err := req.Cookie(param)
func (this *RequestCookieCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
cookie, err := req.WAFRaw().Cookie(param)
if err != nil {
value = ""
return
@@ -20,7 +20,7 @@ func (this *RequestCookieCheckpoint) RequestValue(req *requests.Request, param s
return
}
func (this *RequestCookieCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestCookieCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -11,16 +11,16 @@ type RequestCookiesCheckpoint struct {
Checkpoint
}
func (this *RequestCookiesCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestCookiesCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
var cookies = []string{}
for _, cookie := range req.Cookies() {
for _, cookie := range req.WAFRaw().Cookies() {
cookies = append(cookies, url.QueryEscape(cookie.Name)+"="+url.QueryEscape(cookie.Value))
}
value = strings.Join(cookies, "&")
return
}
func (this *RequestCookiesCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestCookiesCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -6,33 +6,35 @@ import (
"net/url"
)
// ${requestForm.arg}
// RequestFormArgCheckpoint ${requestForm.arg}
type RequestFormArgCheckpoint struct {
Checkpoint
}
func (this *RequestFormArgCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if req.Body == nil {
func (this *RequestFormArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if req.WAFRaw().Body == nil {
value = ""
return
}
if len(req.BodyData) == 0 {
data, err := req.ReadBody(32 * 1024 * 1024) // read 32m bytes
var bodyData = req.WAFGetCacheBody()
if len(bodyData) == 0 {
data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
if err != nil {
return "", err, nil
}
req.BodyData = data
req.RestoreBody(data)
bodyData = data
req.WAFSetCacheBody(data)
req.WAFRestoreBody(data)
}
// TODO improve performance
values, _ := url.ParseQuery(string(req.BodyData))
values, _ := url.ParseQuery(string(bodyData))
return values.Get(param), nil, nil
}
func (this *RequestFormArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestFormArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -15,8 +15,8 @@ func TestRequestFormArgCheckpoint_RequestValue(t *testing.T) {
t.Fatal(err)
}
req := requests.NewRequest(rawReq)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req := requests.NewTestRequest(rawReq)
req.WAFRaw().Header.Set("Content-Type", "application/x-www-form-urlencoded")
checkpoint := new(RequestFormArgCheckpoint)
t.Log(checkpoint.RequestValue(req, "name", nil))
@@ -24,7 +24,7 @@ func TestRequestFormArgCheckpoint_RequestValue(t *testing.T) {
t.Log(checkpoint.RequestValue(req, "Hello", nil))
t.Log(checkpoint.RequestValue(req, "encoded", nil))
body, err := ioutil.ReadAll(req.Body)
body, err := ioutil.ReadAll(req.WAFRaw().Body)
if err != nil {
t.Fatal(err)
}

View File

@@ -14,7 +14,7 @@ func (this *RequestGeneralHeaderLengthCheckpoint) IsComposed() bool {
return true
}
func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = false
headers := options.GetSlice("headers")
@@ -25,7 +25,7 @@ func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Req
length := options.GetInt("length")
for _, header := range headers {
v := req.Header.Get(types.String(header))
v := req.WAFRaw().Header.Get(types.String(header))
if len(v) > length {
value = true
break
@@ -35,6 +35,6 @@ func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Req
return
}
func (this *RequestGeneralHeaderLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestGeneralHeaderLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
return
}

View File

@@ -10,8 +10,8 @@ type RequestHeaderCheckpoint struct {
Checkpoint
}
func (this *RequestHeaderCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
v, found := req.Header[param]
func (this *RequestHeaderCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
v, found := req.WAFRaw().Header[param]
if !found {
value = ""
return
@@ -20,7 +20,7 @@ func (this *RequestHeaderCheckpoint) RequestValue(req *requests.Request, param s
return
}
func (this *RequestHeaderCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestHeaderCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -11,9 +11,9 @@ type RequestHeadersCheckpoint struct {
Checkpoint
}
func (this *RequestHeadersCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestHeadersCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
var headers = []string{}
for k, v := range req.Header {
for k, v := range req.WAFRaw().Header {
for _, subV := range v {
headers = append(headers, k+": "+subV)
}
@@ -23,7 +23,7 @@ func (this *RequestHeadersCheckpoint) RequestValue(req *requests.Request, param
return
}
func (this *RequestHeadersCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestHeadersCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -9,12 +9,12 @@ type RequestHostCheckpoint struct {
Checkpoint
}
func (this *RequestHostCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.Host
func (this *RequestHostCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.WAFRaw().Host
return
}
func (this *RequestHostCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestHostCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -12,8 +12,8 @@ func TestRequestHostCheckpoint_RequestValue(t *testing.T) {
t.Fatal(err)
}
req := requests.NewRequest(rawReq)
req.Header.Set("Host", "cloud.teaos.cn")
req := requests.NewTestRequest(rawReq)
req.WAFRaw().Header.Set("Host", "cloud.teaos.cn")
checkpoint := new(RequestHostCheckpoint)
t.Log(checkpoint.RequestValue(req, "", nil))

View File

@@ -8,24 +8,27 @@ import (
"strings"
)
// ${requestJSON.arg}
// RequestJSONArgCheckpoint ${requestJSON.arg}
type RequestJSONArgCheckpoint struct {
Checkpoint
}
func (this *RequestJSONArgCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if len(req.BodyData) == 0 {
data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
func (this *RequestJSONArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
var bodyData = req.WAFGetCacheBody()
if len(bodyData) == 0 {
data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
if err != nil {
return "", err, nil
}
req.BodyData = data
defer req.RestoreBody(data)
bodyData = data
req.WAFSetCacheBody(data)
defer req.WAFRestoreBody(data)
}
// TODO improve performance
var m interface{} = nil
err := json.Unmarshal(req.BodyData, &m)
err := json.Unmarshal(bodyData, &m)
if err != nil || m == nil {
return "", nil, err
}
@@ -37,7 +40,7 @@ func (this *RequestJSONArgCheckpoint) RequestValue(req *requests.Request, param
return "", nil, nil
}
func (this *RequestJSONArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestJSONArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -20,7 +20,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Map(t *testing.T) {
t.Fatal(err)
}
req := requests.NewRequest(rawReq)
req := requests.NewTestRequest(rawReq)
//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
checkpoint := new(RequestJSONArgCheckpoint)
@@ -31,7 +31,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Map(t *testing.T) {
t.Log(checkpoint.RequestValue(req, "books", nil))
t.Log(checkpoint.RequestValue(req, "books.1", nil))
body, err := ioutil.ReadAll(req.Body)
body, err := ioutil.ReadAll(req.WAFRaw().Body)
if err != nil {
t.Fatal(err)
}
@@ -50,7 +50,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Array(t *testing.T) {
t.Fatal(err)
}
req := requests.NewRequest(rawReq)
req := requests.NewTestRequest(rawReq)
//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
checkpoint := new(RequestJSONArgCheckpoint)
@@ -61,7 +61,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Array(t *testing.T) {
t.Log(checkpoint.RequestValue(req, "0.books", nil))
t.Log(checkpoint.RequestValue(req, "0.books.1", nil))
body, err := ioutil.ReadAll(req.Body)
body, err := ioutil.ReadAll(req.WAFRaw().Body)
if err != nil {
t.Fatal(err)
}
@@ -80,7 +80,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Error(t *testing.T) {
t.Fatal(err)
}
req := requests.NewRequest(rawReq)
req := requests.NewTestRequest(rawReq)
//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
checkpoint := new(RequestJSONArgCheckpoint)
@@ -91,7 +91,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Error(t *testing.T) {
t.Log(checkpoint.RequestValue(req, "0.books", nil))
t.Log(checkpoint.RequestValue(req, "0.books.1", nil))
body, err := ioutil.ReadAll(req.Body)
body, err := ioutil.ReadAll(req.WAFRaw().Body)
if err != nil {
t.Fatal(err)
}

View File

@@ -9,12 +9,12 @@ type RequestLengthCheckpoint struct {
Checkpoint
}
func (this *RequestLengthCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.ContentLength
func (this *RequestLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.WAFRaw().ContentLength
return
}
func (this *RequestLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -9,12 +9,12 @@ type RequestMethodCheckpoint struct {
Checkpoint
}
func (this *RequestMethodCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.Method
func (this *RequestMethodCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.WAFRaw().Method
return
}
func (this *RequestMethodCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestMethodCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -9,11 +9,11 @@ type RequestPathCheckpoint struct {
Checkpoint
}
func (this *RequestPathCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
return req.URL.Path, nil, nil
func (this *RequestPathCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
return req.WAFRaw().URL.Path, nil, nil
}
func (this *RequestPathCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestPathCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -12,7 +12,7 @@ func TestRequestPathCheckpoint_RequestValue(t *testing.T) {
t.Fatal(err)
}
req := requests.NewRequest(rawReq)
req := requests.NewTestRequest(rawReq)
checkpoint := new(RequestPathCheckpoint)
t.Log(checkpoint.RequestValue(req, "", nil))
}

View File

@@ -9,12 +9,12 @@ type RequestProtoCheckpoint struct {
Checkpoint
}
func (this *RequestProtoCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.Proto
func (this *RequestProtoCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.WAFRaw().Proto
return
}
func (this *RequestProtoCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestProtoCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -10,17 +10,17 @@ type RequestRawRemoteAddrCheckpoint struct {
Checkpoint
}
func (this *RequestRawRemoteAddrCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
host, _, err := net.SplitHostPort(req.RemoteAddr)
func (this *RequestRawRemoteAddrCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
host, _, err := net.SplitHostPort(req.WAFRaw().RemoteAddr)
if err == nil {
value = host
} else {
value = req.RemoteAddr
value = req.WAFRaw().RemoteAddr
}
return
}
func (this *RequestRawRemoteAddrCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestRawRemoteAddrCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -9,12 +9,12 @@ type RequestRefererCheckpoint struct {
Checkpoint
}
func (this *RequestRefererCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.Referer()
func (this *RequestRefererCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.WAFRaw().Referer()
return
}
func (this *RequestRefererCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestRefererCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -3,56 +3,18 @@ package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/maps"
"net"
"strings"
)
type RequestRemoteAddrCheckpoint struct {
Checkpoint
}
func (this *RequestRemoteAddrCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
// X-Forwarded-For
forwardedFor := req.Header.Get("X-Forwarded-For")
if len(forwardedFor) > 0 {
commaIndex := strings.Index(forwardedFor, ",")
if commaIndex > 0 {
value = forwardedFor[:commaIndex]
return
}
value = forwardedFor
return
}
// Real-IP
{
realIP, ok := req.Header["X-Real-IP"]
if ok && len(realIP) > 0 {
value = realIP[0]
return
}
}
// Real-Ip
{
realIP, ok := req.Header["X-Real-Ip"]
if ok && len(realIP) > 0 {
value = realIP[0]
return
}
}
// Remote-Addr
host, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
value = host
} else {
value = req.RemoteAddr
}
func (this *RequestRemoteAddrCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.WAFRemoteIP()
return
}
func (this *RequestRemoteAddrCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestRemoteAddrCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -11,8 +11,8 @@ type RequestRemotePortCheckpoint struct {
Checkpoint
}
func (this *RequestRemotePortCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
_, port, err := net.SplitHostPort(req.RemoteAddr)
func (this *RequestRemotePortCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
_, port, err := net.SplitHostPort(req.WAFRaw().RemoteAddr)
if err == nil {
value = types.Int(port)
} else {
@@ -21,7 +21,7 @@ func (this *RequestRemotePortCheckpoint) RequestValue(req *requests.Request, par
return
}
func (this *RequestRemotePortCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestRemotePortCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -9,8 +9,8 @@ type RequestRemoteUserCheckpoint struct {
Checkpoint
}
func (this *RequestRemoteUserCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
username, _, ok := req.BasicAuth()
func (this *RequestRemoteUserCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
username, _, ok := req.WAFRaw().BasicAuth()
if !ok {
value = ""
return
@@ -19,7 +19,7 @@ func (this *RequestRemoteUserCheckpoint) RequestValue(req *requests.Request, par
return
}
func (this *RequestRemoteUserCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestRemoteUserCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -9,12 +9,12 @@ type RequestSchemeCheckpoint struct {
Checkpoint
}
func (this *RequestSchemeCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.URL.Scheme
func (this *RequestSchemeCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.WAFRaw().URL.Scheme
return
}
func (this *RequestSchemeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestSchemeCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

View File

@@ -12,7 +12,7 @@ func TestRequestSchemeCheckpoint_RequestValue(t *testing.T) {
t.Fatal(err)
}
req := requests.NewRequest(rawReq)
req := requests.NewTestRequest(rawReq)
checkpoint := new(RequestSchemeCheckpoint)
t.Log(checkpoint.RequestValue(req, "", nil))
}

View File

@@ -11,63 +11,65 @@ import (
"strings"
)
// ${requestUpload.arg}
// RequestUploadCheckpoint ${requestUpload.arg}
type RequestUploadCheckpoint struct {
Checkpoint
}
func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestUploadCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = ""
if param == "minSize" || param == "maxSize" {
value = 0
}
if req.Method != http.MethodPost {
if req.WAFRaw().Method != http.MethodPost {
return
}
if req.Body == nil {
if req.WAFRaw().Body == nil {
return
}
if req.MultipartForm == nil {
if len(req.BodyData) == 0 {
data, err := req.ReadBody(32 * 1024 * 1024)
if req.WAFRaw().MultipartForm == nil {
var bodyData = req.WAFGetCacheBody()
if len(bodyData) == 0 {
data, err := req.WAFReadBody(32 * 1024 * 1024)
if err != nil {
sysErr = err
return
}
req.BodyData = data
defer req.RestoreBody(data)
bodyData = data
req.WAFSetCacheBody(data)
defer req.WAFRestoreBody(data)
}
oldBody := req.Body
req.Body = ioutil.NopCloser(bytes.NewBuffer(req.BodyData))
oldBody := req.WAFRaw().Body
req.WAFRaw().Body = ioutil.NopCloser(bytes.NewBuffer(bodyData))
err := req.ParseMultipartForm(32 * 1024 * 1024)
err := req.WAFRaw().ParseMultipartForm(32 * 1024 * 1024)
// 还原
req.Body = oldBody
req.WAFRaw().Body = oldBody
if err != nil {
userErr = err
return
}
if req.MultipartForm == nil {
if req.WAFRaw().MultipartForm == nil {
return
}
}
if param == "field" { // field
fields := []string{}
for field := range req.MultipartForm.File {
for field := range req.WAFRaw().MultipartForm.File {
fields = append(fields, field)
}
value = strings.Join(fields, ",")
} else if param == "minSize" { // minSize
minSize := int64(0)
for _, files := range req.MultipartForm.File {
for _, files := range req.WAFRaw().MultipartForm.File {
for _, file := range files {
if minSize == 0 || minSize > file.Size {
minSize = file.Size
@@ -77,7 +79,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s
value = minSize
} else if param == "maxSize" { // maxSize
maxSize := int64(0)
for _, files := range req.MultipartForm.File {
for _, files := range req.WAFRaw().MultipartForm.File {
for _, file := range files {
if maxSize < file.Size {
maxSize = file.Size
@@ -87,7 +89,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s
value = maxSize
} else if param == "name" { // name
names := []string{}
for _, files := range req.MultipartForm.File {
for _, files := range req.WAFRaw().MultipartForm.File {
for _, file := range files {
if !lists.ContainsString(names, file.Filename) {
names = append(names, file.Filename)
@@ -97,7 +99,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s
value = strings.Join(names, ",")
} else if param == "ext" { // ext
extensions := []string{}
for _, files := range req.MultipartForm.File {
for _, files := range req.WAFRaw().MultipartForm.File {
for _, file := range files {
if len(file.Filename) > 0 {
exit := strings.ToLower(filepath.Ext(file.Filename))
@@ -113,7 +115,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s
return
}
func (this *RequestUploadCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
func (this *RequestUploadCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}

Some files were not shown because too many files have changed in this diff Show More