Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
09dcf0d712 | ||
|
|
60aebd9306 | ||
|
|
04191d04d3 | ||
|
|
b80a5c525f | ||
|
|
265c1e5312 | ||
|
|
2723f705b6 | ||
|
|
b4cddd6341 | ||
|
|
5636a81d48 | ||
|
|
d8059960de | ||
|
|
17af4064af | ||
|
|
15f37d2c93 | ||
|
|
6dc3aa8cb7 | ||
|
|
900cccf2f1 | ||
|
|
1fec88dfc6 | ||
|
|
7da9363336 | ||
|
|
d82e633bba | ||
|
|
b363bbaafd | ||
|
|
92a20e3c9a | ||
|
|
5742dfb263 | ||
|
|
0ae63511d5 |
@@ -473,6 +473,7 @@ func (this *FileListDB) initTables(times int) error {
|
||||
{
|
||||
// expiredAt - 过期时间,用来判断有无过期
|
||||
// staleAt - 过时缓存最大时间,用来清理缓存
|
||||
// 不对 hash 增加 unique 参数,是尽可能避免产生 malformed 错误
|
||||
_, err := this.writeDB.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemsTableName + `" (
|
||||
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"hash" varchar(32),
|
||||
@@ -498,7 +499,7 @@ ON "` + this.itemsTableName + `" (
|
||||
"staleAt" ASC
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS "hash"
|
||||
CREATE INDEX IF NOT EXISTS "hash"
|
||||
ON "` + this.itemsTableName + `" (
|
||||
"hash" ASC
|
||||
);
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.org/x/text/language"
|
||||
"golang.org/x/text/message"
|
||||
"math"
|
||||
@@ -58,6 +59,7 @@ const (
|
||||
HotItemLifeSeconds int64 = 3600 // 热点数据生命周期
|
||||
FileToMemoryMaxSize = 32 * sizes.M // 可以从文件写入到内存的最大文件尺寸
|
||||
FileTmpSuffix = ".tmp"
|
||||
MinDiskSpace = 5 << 30 // 当前磁盘最小剩余空间
|
||||
)
|
||||
|
||||
var sharedWritingFileKeyMap = map[string]zero.Zero{} // key => bool
|
||||
@@ -90,6 +92,8 @@ type FileStorage struct {
|
||||
ignoreKeys *setutils.FixedSet
|
||||
|
||||
openFileCache *OpenFileCache
|
||||
|
||||
diskIsFull bool
|
||||
}
|
||||
|
||||
func NewFileStorage(policy *serverconfigs.HTTPCachePolicy) *FileStorage {
|
||||
@@ -287,6 +291,9 @@ func (this *FileStorage) Init() error {
|
||||
// open file cache
|
||||
this.initOpenFileCache()
|
||||
|
||||
// 检查磁盘空间
|
||||
this.checkDiskSpace()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -397,6 +404,11 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, siz
|
||||
return nil, ErrWritingUnavailable
|
||||
}
|
||||
|
||||
// 当前磁盘可用容量是否严重不足
|
||||
if this.diskIsFull {
|
||||
return nil, NewCapacityError("the disk is full")
|
||||
}
|
||||
|
||||
// 是否已忽略
|
||||
if this.ignoreKeys.Has(key) {
|
||||
return nil, ErrEntityTooLarge
|
||||
@@ -938,18 +950,25 @@ func (this *FileStorage) initList() error {
|
||||
|
||||
// 清理任务
|
||||
func (this *FileStorage) purgeLoop() {
|
||||
// 检查磁盘剩余空间
|
||||
this.checkDiskSpace()
|
||||
|
||||
// 计算是否应该开启LFU清理
|
||||
var capacityBytes = this.policy.CapacityBytes()
|
||||
var startLFU = false
|
||||
var usedPercent = float32(this.TotalDiskSize()*100) / float32(capacityBytes)
|
||||
var lfuFreePercent = this.policy.PersistenceLFUFreePercent
|
||||
if lfuFreePercent <= 0 {
|
||||
lfuFreePercent = 5
|
||||
}
|
||||
if capacityBytes > 0 {
|
||||
if lfuFreePercent < 100 {
|
||||
if usedPercent >= 100-lfuFreePercent {
|
||||
startLFU = true
|
||||
if this.diskIsFull {
|
||||
startLFU = true
|
||||
} else {
|
||||
var usedPercent = float32(this.TotalDiskSize()*100) / float32(capacityBytes)
|
||||
if capacityBytes > 0 {
|
||||
if lfuFreePercent < 100 {
|
||||
if usedPercent >= 100-lfuFreePercent {
|
||||
startLFU = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1327,3 +1346,15 @@ func (this *FileStorage) runMemoryStorageSafety(f func(memoryStorage *MemoryStor
|
||||
f(memoryStorage)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查磁盘剩余空间
|
||||
func (this *FileStorage) checkDiskSpace() {
|
||||
if this.options != nil && len(this.options.Dir) > 0 {
|
||||
var stat unix.Statfs_t
|
||||
err := unix.Statfs(this.options.Dir, &stat)
|
||||
if err == nil {
|
||||
var availableBytes = stat.Bavail * uint64(stat.Bsize)
|
||||
this.diskIsFull = availableBytes < MinDiskSpace
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package teaconst
|
||||
|
||||
const (
|
||||
Version = "0.5.4"
|
||||
Version = "0.5.6"
|
||||
|
||||
ProductName = "Edge Node"
|
||||
ProcessName = "edge-node"
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
@@ -74,10 +75,16 @@ func (this *IPTablesAction) runAction(action string, listType IPListType, item *
|
||||
}
|
||||
|
||||
func (this *IPTablesAction) runActionSingleIP(action string, listType IPListType, item *pb.IPItem) error {
|
||||
// 暂时不支持ipv6
|
||||
// TODO 将来支持ipv6
|
||||
if utils.IsIPv6(item.IpFrom) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if item.Type == "all" {
|
||||
return nil
|
||||
}
|
||||
path := this.config.Path
|
||||
var path = this.config.Path
|
||||
var err error
|
||||
if len(path) == 0 {
|
||||
path, err = exec.LookPath("iptables")
|
||||
@@ -88,6 +95,7 @@ func (this *IPTablesAction) runActionSingleIP(action string, listType IPListType
|
||||
this.iptablesNotFound = true
|
||||
return err
|
||||
}
|
||||
this.config.Path = path
|
||||
}
|
||||
iptablesAction := ""
|
||||
switch action {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
@@ -141,6 +142,12 @@ func (this *IPListManager) init() {
|
||||
}
|
||||
|
||||
func (this *IPListManager) loop() error {
|
||||
// 是否同步IP名单
|
||||
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
|
||||
if nodeConfig != nil && !nodeConfig.EnableIPLists {
|
||||
return nil
|
||||
}
|
||||
|
||||
for {
|
||||
hasNext, err := this.fetch()
|
||||
if err != nil {
|
||||
|
||||
@@ -1603,9 +1603,25 @@ func (this *HTTPRequest) fixRequestHeader(header http.Header) {
|
||||
header.Del(k)
|
||||
k = strings.ReplaceAll(k, "-Websocket-", "-WebSocket-")
|
||||
header[k] = v
|
||||
} else if k == "Www-Authenticate" {
|
||||
} else if strings.HasPrefix(k, "Sec-Ch") {
|
||||
header.Del(k)
|
||||
header["WWW-Authenticate"] = v
|
||||
k = strings.ReplaceAll(k, "Sec-Ch-Ua", "Sec-CH-UA")
|
||||
header[k] = v
|
||||
} else {
|
||||
switch k {
|
||||
case "Www-Authenticate":
|
||||
header.Del(k)
|
||||
header["WWW-Authenticate"] = v
|
||||
case "A-Im":
|
||||
header.Del(k)
|
||||
header["A-IM"] = v
|
||||
case "Content-Md5":
|
||||
header.Del(k)
|
||||
header["Content-MD5"] = v
|
||||
case "Sec-Gpc":
|
||||
header.Del(k)
|
||||
header["Content-GPC"] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,11 @@ import (
|
||||
|
||||
// 读取缓存
|
||||
func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
// 需要动态Upgrade的不缓存
|
||||
if len(this.RawReq.Header.Get("Upgrade")) > 0 {
|
||||
return
|
||||
}
|
||||
|
||||
this.cacheCanTryStale = false
|
||||
|
||||
var cachePolicy = this.ReqServer.HTTPCachePolicy
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -13,7 +17,7 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
|
||||
if this.web.MergeSlashes {
|
||||
urlPath = utils.CleanPath(urlPath)
|
||||
}
|
||||
fullURL := this.requestScheme() + "://" + this.ReqHost + urlPath
|
||||
var fullURL = this.requestScheme() + "://" + this.ReqHost + urlPath
|
||||
for _, u := range this.web.HostRedirects {
|
||||
if !u.IsOn {
|
||||
continue
|
||||
@@ -21,11 +25,50 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
|
||||
if !u.MatchRequest(this.Format) {
|
||||
continue
|
||||
}
|
||||
if u.MatchPrefix { // 匹配前缀
|
||||
if strings.HasPrefix(fullURL, u.BeforeURL) {
|
||||
afterURL := u.AfterURL
|
||||
if u.KeepRequestURI {
|
||||
afterURL += this.RawReq.URL.RequestURI()
|
||||
if len(u.Type) == 0 || u.Type == serverconfigs.HTTPHostRedirectTypeURL {
|
||||
if u.MatchPrefix { // 匹配前缀
|
||||
if strings.HasPrefix(fullURL, u.BeforeURL) {
|
||||
afterURL := u.AfterURL
|
||||
if u.KeepRequestURI {
|
||||
afterURL += this.RawReq.URL.RequestURI()
|
||||
}
|
||||
|
||||
// 前后是否一致
|
||||
if fullURL == afterURL {
|
||||
return false
|
||||
}
|
||||
|
||||
if u.Status <= 0 {
|
||||
u.Status = http.StatusTemporaryRedirect
|
||||
}
|
||||
this.processResponseHeaders(this.writer.Header(), u.Status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
|
||||
return true
|
||||
}
|
||||
} else if u.MatchRegexp { // 正则匹配
|
||||
var reg = u.BeforeURLRegexp()
|
||||
if reg == nil {
|
||||
continue
|
||||
}
|
||||
var matches = reg.FindStringSubmatch(fullURL)
|
||||
if len(matches) == 0 {
|
||||
continue
|
||||
}
|
||||
var afterURL = u.AfterURL
|
||||
for i, match := range matches {
|
||||
afterURL = strings.ReplaceAll(afterURL, "${"+strconv.Itoa(i)+"}", match)
|
||||
}
|
||||
|
||||
var subNames = reg.SubexpNames()
|
||||
if len(subNames) > 0 {
|
||||
for _, subName := range subNames {
|
||||
if len(subName) > 0 {
|
||||
index := reg.SubexpIndex(subName)
|
||||
if index > -1 {
|
||||
afterURL = strings.ReplaceAll(afterURL, "${"+subName+"}", matches[index])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 前后是否一致
|
||||
@@ -33,69 +76,6 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
if u.Status <= 0 {
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
|
||||
} else {
|
||||
this.processResponseHeaders(this.writer.Header(), u.Status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
|
||||
}
|
||||
return true
|
||||
}
|
||||
} else if u.MatchRegexp { // 正则匹配
|
||||
reg := u.BeforeURLRegexp()
|
||||
if reg == nil {
|
||||
continue
|
||||
}
|
||||
matches := reg.FindStringSubmatch(fullURL)
|
||||
if len(matches) == 0 {
|
||||
continue
|
||||
}
|
||||
afterURL := u.AfterURL
|
||||
for i, match := range matches {
|
||||
afterURL = strings.ReplaceAll(afterURL, "${"+strconv.Itoa(i)+"}", match)
|
||||
}
|
||||
|
||||
subNames := reg.SubexpNames()
|
||||
if len(subNames) > 0 {
|
||||
for _, subName := range subNames {
|
||||
if len(subName) > 0 {
|
||||
index := reg.SubexpIndex(subName)
|
||||
if index > -1 {
|
||||
afterURL = strings.ReplaceAll(afterURL, "${"+subName+"}", matches[index])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 前后是否一致
|
||||
if fullURL == afterURL {
|
||||
return false
|
||||
}
|
||||
|
||||
if u.KeepArgs {
|
||||
var qIndex = strings.Index(this.uri, "?")
|
||||
if qIndex >= 0 {
|
||||
afterURL += this.uri[qIndex:]
|
||||
}
|
||||
}
|
||||
|
||||
if u.Status <= 0 {
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
|
||||
} else {
|
||||
this.processResponseHeaders(this.writer.Header(), u.Status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
|
||||
}
|
||||
return true
|
||||
} else { // 精准匹配
|
||||
if fullURL == u.RealBeforeURL() {
|
||||
// 前后是否一致
|
||||
if fullURL == u.AfterURL {
|
||||
return false
|
||||
}
|
||||
|
||||
var afterURL = u.AfterURL
|
||||
if u.KeepArgs {
|
||||
var qIndex = strings.Index(this.uri, "?")
|
||||
if qIndex >= 0 {
|
||||
@@ -104,12 +84,110 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
|
||||
}
|
||||
|
||||
if u.Status <= 0 {
|
||||
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
|
||||
} else {
|
||||
u.Status = http.StatusTemporaryRedirect
|
||||
}
|
||||
this.processResponseHeaders(this.writer.Header(), u.Status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
|
||||
return true
|
||||
} else { // 精准匹配
|
||||
if fullURL == u.RealBeforeURL() {
|
||||
// 前后是否一致
|
||||
if fullURL == u.AfterURL {
|
||||
return false
|
||||
}
|
||||
|
||||
var afterURL = u.AfterURL
|
||||
if u.KeepArgs {
|
||||
var qIndex = strings.Index(this.uri, "?")
|
||||
if qIndex >= 0 {
|
||||
afterURL += this.uri[qIndex:]
|
||||
}
|
||||
}
|
||||
|
||||
if u.Status <= 0 {
|
||||
u.Status = http.StatusTemporaryRedirect
|
||||
}
|
||||
this.processResponseHeaders(this.writer.Header(), u.Status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
|
||||
return true
|
||||
}
|
||||
}
|
||||
} else if u.Type == serverconfigs.HTTPHostRedirectTypeDomain {
|
||||
if len(u.DomainAfter) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 如果跳转前后域名一致,则终止
|
||||
if u.DomainAfter == this.ReqHost {
|
||||
return false
|
||||
}
|
||||
|
||||
var scheme = u.DomainAfterScheme
|
||||
if len(scheme) == 0 {
|
||||
scheme = this.requestScheme()
|
||||
}
|
||||
if u.DomainsAll || configutils.MatchDomains(u.DomainsBefore, this.ReqHost) {
|
||||
var afterURL = scheme + "://" + u.DomainAfter + urlPath
|
||||
if fullURL == afterURL {
|
||||
// 终止匹配
|
||||
return false
|
||||
}
|
||||
if u.Status <= 0 {
|
||||
u.Status = http.StatusTemporaryRedirect
|
||||
}
|
||||
this.processResponseHeaders(this.writer.Header(), u.Status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
|
||||
return true
|
||||
}
|
||||
} else if u.Type == serverconfigs.HTTPHostRedirectTypePort {
|
||||
if u.PortAfter <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var scheme = u.PortAfterScheme
|
||||
if len(scheme) == 0 {
|
||||
scheme = this.requestScheme()
|
||||
}
|
||||
|
||||
reqHost, reqPort, _ := net.SplitHostPort(this.ReqHost)
|
||||
if len(reqHost) == 0 {
|
||||
reqHost = this.ReqHost
|
||||
}
|
||||
if len(reqPort) == 0 {
|
||||
switch this.requestScheme() {
|
||||
case "http":
|
||||
reqPort = "80"
|
||||
case "https":
|
||||
reqPort = "443"
|
||||
}
|
||||
}
|
||||
|
||||
// 如果跳转前后端口一致,则终止
|
||||
if reqPort == types.String(u.PortAfter) {
|
||||
return false
|
||||
}
|
||||
|
||||
var containsPort = false
|
||||
if u.PortsAll {
|
||||
containsPort = true
|
||||
} else {
|
||||
containsPort = u.ContainsPort(types.Int(reqPort))
|
||||
}
|
||||
if containsPort {
|
||||
var newReqHost = reqHost
|
||||
if !((scheme == "http" && u.PortAfter == 80) || (scheme == "https" && u.PortAfter == 443)) {
|
||||
newReqHost += ":" + types.String(u.PortAfter)
|
||||
}
|
||||
var afterURL = scheme + "://" + newReqHost + urlPath
|
||||
if fullURL == afterURL {
|
||||
// 终止匹配
|
||||
return false
|
||||
}
|
||||
if u.Status <= 0 {
|
||||
u.Status = http.StatusTemporaryRedirect
|
||||
}
|
||||
this.processResponseHeaders(this.writer.Header(), u.Status)
|
||||
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,27 +51,46 @@ func (this *HTTPRequest) log() {
|
||||
addr = addr[:index]
|
||||
}
|
||||
|
||||
var serverGlobalConfig = this.nodeConfig.GlobalServerConfig
|
||||
|
||||
// 请求Cookie
|
||||
var cookies = map[string]string{}
|
||||
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldCookie) {
|
||||
for _, cookie := range this.RawReq.Cookies() {
|
||||
cookies[cookie.Name] = cookie.Value
|
||||
var enableCookies = false
|
||||
if serverGlobalConfig == nil || serverGlobalConfig.HTTPAccessLog.EnableCookies {
|
||||
enableCookies = true
|
||||
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldCookie) {
|
||||
for _, cookie := range this.RawReq.Cookies() {
|
||||
cookies[cookie.Name] = cookie.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 请求Header
|
||||
var pbReqHeader = map[string]*pb.Strings{}
|
||||
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldHeader) {
|
||||
for k, v := range this.RawReq.Header {
|
||||
pbReqHeader[k] = &pb.Strings{Values: v}
|
||||
if serverGlobalConfig == nil || serverGlobalConfig.HTTPAccessLog.EnableRequestHeaders {
|
||||
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldHeader) {
|
||||
// 是否只记录通用Header
|
||||
var commonHeadersOnly = serverGlobalConfig != nil && serverGlobalConfig.HTTPAccessLog.CommonRequestHeadersOnly
|
||||
|
||||
for k, v := range this.RawReq.Header {
|
||||
if commonHeadersOnly && !serverconfigs.IsCommonRequestHeader(k) {
|
||||
continue
|
||||
}
|
||||
if !enableCookies && k == "Cookie" {
|
||||
continue
|
||||
}
|
||||
pbReqHeader[k] = &pb.Strings{Values: v}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 响应Header
|
||||
var pbResHeader = map[string]*pb.Strings{}
|
||||
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldSentHeader) {
|
||||
for k, v := range this.writer.Header() {
|
||||
pbResHeader[k] = &pb.Strings{Values: v}
|
||||
if serverGlobalConfig == nil || serverGlobalConfig.HTTPAccessLog.EnableResponseHeaders {
|
||||
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldSentHeader) {
|
||||
for k, v := range this.writer.Header() {
|
||||
pbResHeader[k] = &pb.Strings{Values: v}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package nodes
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"io"
|
||||
@@ -9,8 +10,36 @@ import (
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// WebsocketResponseReader Websocket响应Reader
|
||||
type WebsocketResponseReader struct {
|
||||
rawReader io.Reader
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func NewWebsocketResponseReader(rawReader io.Reader) *WebsocketResponseReader {
|
||||
return &WebsocketResponseReader{
|
||||
rawReader: rawReader,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *WebsocketResponseReader) Read(p []byte) (n int, err error) {
|
||||
n, err = this.rawReader.Read(p)
|
||||
if n > 0 {
|
||||
if len(this.buf) == 0 {
|
||||
this.buf = make([]byte, n)
|
||||
copy(this.buf, p[:n])
|
||||
} else {
|
||||
this.buf = append(this.buf, p[:n]...)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 处理Websocket请求
|
||||
func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shouldRetry bool) {
|
||||
// 设置不缓存
|
||||
this.web.Cache = nil
|
||||
|
||||
if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn {
|
||||
this.writer.WriteHeader(http.StatusForbidden)
|
||||
this.addError(errors.New("websocket have not been enabled yet"))
|
||||
@@ -84,14 +113,20 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
|
||||
|
||||
go func() {
|
||||
// 读取第一个响应
|
||||
resp, err := http.ReadResponse(bufio.NewReader(originConn), this.RawReq)
|
||||
if err != nil {
|
||||
var respReader = NewWebsocketResponseReader(originConn)
|
||||
resp, err := http.ReadResponse(bufio.NewReader(respReader), this.RawReq)
|
||||
if err != nil || resp == nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
_ = clientConn.Close()
|
||||
_ = originConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
this.processResponseHeaders(resp.Header, resp.StatusCode)
|
||||
this.writer.statusCode = resp.StatusCode
|
||||
|
||||
// 将响应写回客户端
|
||||
err = resp.Write(clientConn)
|
||||
@@ -105,6 +140,25 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
|
||||
return
|
||||
}
|
||||
|
||||
// 剩余已经从源站读取的内容
|
||||
var headerBytes = respReader.buf
|
||||
var headerIndex = bytes.Index(headerBytes, []byte{'\r', '\n', '\r', '\n'}) // CRLF
|
||||
if headerIndex > 0 {
|
||||
var leftBytes = headerBytes[headerIndex+4:]
|
||||
if len(leftBytes) > 0 {
|
||||
_, err = clientConn.Write(leftBytes)
|
||||
if err != nil {
|
||||
if resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
_ = clientConn.Close()
|
||||
_ = originConn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
@@ -160,7 +160,7 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf
|
||||
|
||||
// 严格查找域名
|
||||
func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *serverconfigs.ServerConfig, serverName string) {
|
||||
group := this.Group
|
||||
var group = this.Group
|
||||
if group == nil {
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ func (this *HTTPListener) Serve() error {
|
||||
Handler: this,
|
||||
ReadTimeout: 1 * time.Hour, // TODO 改成可以配置
|
||||
ReadHeaderTimeout: 3 * time.Second, // TODO 改成可以配置
|
||||
WriteTimeout: 1 * time.Hour, // TODO 改成可以配置
|
||||
WriteTimeout: 2 * time.Hour, // TODO 改成可以配置
|
||||
IdleTimeout: 75 * time.Second, // TODO 改成可以配置
|
||||
ConnState: func(conn net.Conn, state http.ConnState) {
|
||||
switch state {
|
||||
@@ -175,6 +175,15 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
|
||||
}
|
||||
}
|
||||
|
||||
// 检查用户
|
||||
if server != nil && server.UserId > 0 {
|
||||
if !SharedUserManager.CheckUserServersIsEnabled(server.UserId) {
|
||||
rawWriter.WriteHeader(http.StatusNotFound)
|
||||
_, _ = rawWriter.Write([]byte("The site owner is unavailable."))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 包装新请求对象
|
||||
var req = &HTTPRequest{
|
||||
RawReq: rawReq,
|
||||
|
||||
@@ -92,7 +92,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
}
|
||||
|
||||
// 是否已达到流量限制
|
||||
if this.reachedTrafficLimit() {
|
||||
if this.reachedTrafficLimit() || (server.UserId > 0 && !SharedUserManager.CheckUserServersIsEnabled(server.UserId)) {
|
||||
// 关闭连接
|
||||
tcpConn, ok := conn.(LingerConn)
|
||||
if ok {
|
||||
|
||||
@@ -170,6 +170,11 @@ func (this *UDPListener) servePacketListener(listener UDPPacketListener) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if firstServer.UserId > 0 && !SharedUserManager.CheckUserServersIsEnabled(firstServer.UserId) {
|
||||
return nil
|
||||
}
|
||||
|
||||
n, cm, clientAddr, err := listener.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
if this.isClosed {
|
||||
|
||||
@@ -205,9 +205,7 @@ func (this *Node) Start() {
|
||||
|
||||
// 统计
|
||||
goman.New(func() {
|
||||
stats.SharedTrafficStatManager.Start(func() *nodeconfigs.NodeConfig {
|
||||
return sharedNodeConfig
|
||||
})
|
||||
stats.SharedTrafficStatManager.Start()
|
||||
})
|
||||
goman.New(func() {
|
||||
stats.SharedHTTPRequestStatManager.Start()
|
||||
@@ -430,6 +428,22 @@ func (this *Node) execTask(rpcClient *rpc.RPCClient, nodeCtx context.Context, ta
|
||||
}
|
||||
}
|
||||
}
|
||||
case "userServersStateChanged":
|
||||
if task.UserId > 0 {
|
||||
resp, err := rpcClient.UserRPC.CheckUserServersState(nodeCtx, &pb.CheckUserServersStateRequest{UserId: task.UserId})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
SharedUserManager.UpdateUserServersIsEnabled(task.UserId, resp.IsEnabled)
|
||||
|
||||
if resp.IsEnabled {
|
||||
err = this.syncUserServersConfig(task.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
remotelogs.Error("NODE", "task '"+types.String(task.Id)+"', type '"+task.Type+"' has not been handled")
|
||||
}
|
||||
@@ -615,6 +629,36 @@ func (this *Node) syncServerConfig(serverId int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 同步某个用户下的所有服务配置
|
||||
func (this *Node) syncUserServersConfig(userId int64) error {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverConfigsResp, err := rpcClient.ServerRPC.ComposeAllUserServersConfig(rpcClient.Context(), &pb.ComposeAllUserServersConfigRequest{
|
||||
UserId: userId,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(serverConfigsResp.ServersConfigJSON) == 0 {
|
||||
return nil
|
||||
}
|
||||
var serverConfigs = []*serverconfigs.ServerConfig{}
|
||||
err = json.Unmarshal(serverConfigsResp.ServersConfigJSON, &serverConfigs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
for _, config := range serverConfigs {
|
||||
this.updatingServerMap[config.Id] = config
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 启动同步计时器
|
||||
func (this *Node) startSyncTimer() {
|
||||
// TODO 这个时间间隔可以自行设置
|
||||
|
||||
@@ -195,8 +195,8 @@ func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
|
||||
})
|
||||
|
||||
// 当前TeaWeb所在的fs
|
||||
rootFS := ""
|
||||
rootTotal := uint64(0)
|
||||
var rootFS = ""
|
||||
var rootTotal = uint64(0)
|
||||
if lists.ContainsString([]string{"darwin", "linux", "freebsd"}, runtime.GOOS) {
|
||||
for _, p := range partitions {
|
||||
if p.Mountpoint == "/" {
|
||||
@@ -210,9 +210,9 @@ func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
|
||||
}
|
||||
}
|
||||
|
||||
total := rootTotal
|
||||
totalUsage := uint64(0)
|
||||
maxUsage := float64(0)
|
||||
var total = rootTotal
|
||||
var totalUsage = uint64(0)
|
||||
var maxUsage = float64(0)
|
||||
for _, partition := range partitions {
|
||||
if runtime.GOOS != "windows" && !strings.Contains(partition.Device, "/") && !strings.Contains(partition.Device, "\\") {
|
||||
continue
|
||||
@@ -252,7 +252,7 @@ func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
|
||||
// 缓存空间
|
||||
func (this *NodeStatusExecutor) updateCacheSpace(status *nodeconfigs.NodeStatus) {
|
||||
var result = []maps.Map{}
|
||||
cachePaths := caches.SharedManager.FindAllCachePaths()
|
||||
var cachePaths = caches.SharedManager.FindAllCachePaths()
|
||||
for _, path := range cachePaths {
|
||||
var stat unix.Statfs_t
|
||||
err := unix.Statfs(path, &stat)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package nodes
|
||||
|
||||
|
||||
49
internal/nodes/user_manager.go
Normal file
49
internal/nodes/user_manager.go
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
var SharedUserManager = NewUserManager()
|
||||
|
||||
type User struct {
|
||||
ServersEnabled bool
|
||||
}
|
||||
|
||||
type UserManager struct {
|
||||
userMap map[int64]*User // id => *User
|
||||
|
||||
locker sync.RWMutex
|
||||
}
|
||||
|
||||
func NewUserManager() *UserManager {
|
||||
return &UserManager{
|
||||
userMap: map[int64]*User{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *UserManager) UpdateUserServersIsEnabled(userId int64, isEnabled bool) {
|
||||
this.locker.Lock()
|
||||
u, ok := this.userMap[userId]
|
||||
if ok {
|
||||
u.ServersEnabled = isEnabled
|
||||
} else {
|
||||
u = &User{ServersEnabled: isEnabled}
|
||||
this.userMap[userId] = u
|
||||
}
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
func (this *UserManager) CheckUserServersIsEnabled(userId int64) (isEnabled bool) {
|
||||
this.locker.RLock()
|
||||
u, ok := this.userMap[userId]
|
||||
if ok {
|
||||
isEnabled = u.ServersEnabled
|
||||
} else {
|
||||
isEnabled = true
|
||||
}
|
||||
this.locker.RUnlock()
|
||||
return
|
||||
}
|
||||
@@ -49,6 +49,7 @@ type RPCClient struct {
|
||||
FirewallRPC pb.FirewallServiceClient
|
||||
SSLCertRPC pb.SSLCertServiceClient
|
||||
ScriptRPC pb.ScriptServiceClient
|
||||
UserRPC pb.UserServiceClient
|
||||
}
|
||||
|
||||
func NewRPCClient(apiConfig *configs.APIConfig) (*RPCClient, error) {
|
||||
@@ -81,6 +82,7 @@ func NewRPCClient(apiConfig *configs.APIConfig) (*RPCClient, error) {
|
||||
client.FirewallRPC = pb.NewFirewallServiceClient(client)
|
||||
client.SSLCertRPC = pb.NewSSLCertServiceClient(client)
|
||||
client.ScriptRPC = pb.NewScriptServiceClient(client)
|
||||
client.UserRPC = pb.NewUserServiceClient(client)
|
||||
|
||||
err := client.init()
|
||||
if err != nil {
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
@@ -42,6 +43,8 @@ type BandwidthStat struct {
|
||||
type BandwidthStatManager struct {
|
||||
m map[string]*BandwidthStat // key => *BandwidthStat
|
||||
|
||||
pbStats []*pb.ServerBandwidthStat
|
||||
|
||||
lastTime string // 上一次执行的时间
|
||||
|
||||
ticker *time.Ticker
|
||||
@@ -65,6 +68,12 @@ func (this *BandwidthStatManager) Start() {
|
||||
}
|
||||
|
||||
func (this *BandwidthStatManager) Loop() error {
|
||||
var regionId int64
|
||||
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
|
||||
if nodeConfig != nil {
|
||||
regionId = nodeConfig.RegionId
|
||||
}
|
||||
|
||||
var now = time.Now()
|
||||
var day = timeutil.Format("Ymd", now)
|
||||
var currentTime = timeutil.FormatTime("Hi", now.Unix()/300*300)
|
||||
@@ -76,16 +85,29 @@ func (this *BandwidthStatManager) Loop() error {
|
||||
|
||||
var pbStats = []*pb.ServerBandwidthStat{}
|
||||
|
||||
// 历史未提交记录
|
||||
if len(this.pbStats) > 0 {
|
||||
var expiredTime = timeutil.FormatTime("Hi", time.Now().Unix()-1200) // 只保留20分钟
|
||||
|
||||
for _, stat := range this.pbStats {
|
||||
if stat.TimeAt > expiredTime {
|
||||
pbStats = append(pbStats, stat)
|
||||
}
|
||||
}
|
||||
this.pbStats = nil
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
for key, stat := range this.m {
|
||||
if stat.Day < day || stat.TimeAt < currentTime {
|
||||
pbStats = append(pbStats, &pb.ServerBandwidthStat{
|
||||
Id: 0,
|
||||
UserId: stat.UserId,
|
||||
ServerId: stat.ServerId,
|
||||
Day: stat.Day,
|
||||
TimeAt: stat.TimeAt,
|
||||
Bytes: stat.MaxBytes / bandwidthTimestampDelim,
|
||||
Id: 0,
|
||||
UserId: stat.UserId,
|
||||
ServerId: stat.ServerId,
|
||||
Day: stat.Day,
|
||||
TimeAt: stat.TimeAt,
|
||||
Bytes: stat.MaxBytes / bandwidthTimestampDelim,
|
||||
NodeRegionId: regionId,
|
||||
})
|
||||
delete(this.m, key)
|
||||
}
|
||||
@@ -100,6 +122,8 @@ func (this *BandwidthStatManager) Loop() error {
|
||||
}
|
||||
_, err = rpcClient.ServerBandwidthStatRPC.UploadServerBandwidthStats(rpcClient.Context(), &pb.UploadServerBandwidthStatsRequest{ServerBandwidthStats: pbStats})
|
||||
if err != nil {
|
||||
this.pbStats = pbStats
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,19 +31,33 @@ type TrafficItem struct {
|
||||
CheckingTrafficLimit bool
|
||||
}
|
||||
|
||||
func (this *TrafficItem) Add(anotherItem *TrafficItem) {
|
||||
this.Bytes += anotherItem.Bytes
|
||||
this.CachedBytes += anotherItem.CachedBytes
|
||||
this.CountRequests += anotherItem.CountRequests
|
||||
this.CountCachedRequests += anotherItem.CountCachedRequests
|
||||
this.CountAttackRequests += anotherItem.CountAttackRequests
|
||||
this.AttackBytes += anotherItem.AttackBytes
|
||||
}
|
||||
|
||||
const trafficStatsMaxLife = 1200 // 最大只保存20分钟内的数据
|
||||
|
||||
// TrafficStatManager 区域流量统计
|
||||
type TrafficStatManager struct {
|
||||
itemMap map[string]*TrafficItem // [timestamp serverId] => *TrafficItem
|
||||
domainsMap map[string]*TrafficItem // timestamp @ serverId @ domain => *TrafficItem
|
||||
locker sync.Mutex
|
||||
configFunc func() *nodeconfigs.NodeConfig
|
||||
|
||||
pbItems []*pb.ServerDailyStat
|
||||
pbDomainItems []*pb.UploadServerDailyStatsRequest_DomainStat
|
||||
|
||||
locker sync.Mutex
|
||||
|
||||
totalRequests int64
|
||||
}
|
||||
|
||||
// NewTrafficStatManager 获取新对象
|
||||
func NewTrafficStatManager() *TrafficStatManager {
|
||||
manager := &TrafficStatManager{
|
||||
var manager = &TrafficStatManager{
|
||||
itemMap: map[string]*TrafficItem{},
|
||||
domainsMap: map[string]*TrafficItem{},
|
||||
}
|
||||
@@ -52,9 +66,7 @@ func NewTrafficStatManager() *TrafficStatManager {
|
||||
}
|
||||
|
||||
// Start 启动自动任务
|
||||
func (this *TrafficStatManager) Start(configFunc func() *nodeconfigs.NodeConfig) {
|
||||
this.configFunc = configFunc
|
||||
|
||||
func (this *TrafficStatManager) Start() {
|
||||
// 上传请求总数
|
||||
var monitorTicker = time.NewTicker(1 * time.Minute)
|
||||
events.OnKey(events.EventQuit, this, func() {
|
||||
@@ -70,7 +82,7 @@ func (this *TrafficStatManager) Start(configFunc func() *nodeconfigs.NodeConfig)
|
||||
})
|
||||
|
||||
// 上传统计数据
|
||||
duration := 5 * time.Minute
|
||||
var duration = 5 * time.Minute
|
||||
if Tea.IsTesting() {
|
||||
// 测试环境缩短上传时间,方便我们调试
|
||||
duration = 30 * time.Second
|
||||
@@ -143,9 +155,10 @@ func (this *TrafficStatManager) Add(serverId int64, domain string, bytes int64,
|
||||
|
||||
// Upload 上传流量
|
||||
func (this *TrafficStatManager) Upload() error {
|
||||
var config = this.configFunc()
|
||||
if config == nil {
|
||||
return nil
|
||||
var regionId int64
|
||||
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
|
||||
if nodeConfig != nil {
|
||||
regionId = nodeConfig.RegionId
|
||||
}
|
||||
|
||||
client, err := rpc.SharedRPC()
|
||||
@@ -154,10 +167,14 @@ func (this *TrafficStatManager) Upload() error {
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
|
||||
var itemMap = this.itemMap
|
||||
var domainMap = this.domainsMap
|
||||
|
||||
// reset
|
||||
this.itemMap = map[string]*TrafficItem{}
|
||||
this.domainsMap = map[string]*TrafficItem{}
|
||||
|
||||
this.locker.Unlock()
|
||||
|
||||
// 服务统计
|
||||
@@ -174,7 +191,7 @@ func (this *TrafficStatManager) Upload() error {
|
||||
|
||||
pbServerStats = append(pbServerStats, &pb.ServerDailyStat{
|
||||
ServerId: serverId,
|
||||
RegionId: config.RegionId,
|
||||
NodeRegionId: regionId,
|
||||
Bytes: item.Bytes,
|
||||
CachedBytes: item.CachedBytes,
|
||||
CountRequests: item.CountRequests,
|
||||
@@ -186,9 +203,6 @@ func (this *TrafficStatManager) Upload() error {
|
||||
CreatedAt: timestamp,
|
||||
})
|
||||
}
|
||||
if len(pbServerStats) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 域名统计
|
||||
var pbDomainStats = []*pb.UploadServerDailyStatsRequest_DomainStat{}
|
||||
@@ -210,9 +224,40 @@ func (this *TrafficStatManager) Upload() error {
|
||||
})
|
||||
}
|
||||
|
||||
// 历史未提交记录
|
||||
if len(this.pbItems) > 0 || len(this.pbDomainItems) > 0 {
|
||||
var expiredAt = time.Now().Unix() - 1200 // 只保留20分钟
|
||||
|
||||
for _, item := range this.pbItems {
|
||||
if item.CreatedAt > expiredAt {
|
||||
pbServerStats = append(pbServerStats, item)
|
||||
}
|
||||
}
|
||||
this.pbItems = nil
|
||||
|
||||
for _, item := range this.pbDomainItems {
|
||||
if item.CreatedAt > expiredAt {
|
||||
pbDomainStats = append(pbDomainStats, item)
|
||||
}
|
||||
}
|
||||
this.pbDomainItems = nil
|
||||
}
|
||||
|
||||
if len(pbServerStats) == 0 && len(pbDomainStats) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = client.ServerDailyStatRPC.UploadServerDailyStats(client.Context(), &pb.UploadServerDailyStatsRequest{
|
||||
Stats: pbServerStats,
|
||||
DomainStats: pbDomainStats,
|
||||
})
|
||||
return err
|
||||
if err != nil {
|
||||
// 加回历史记录
|
||||
this.pbItems = pbServerStats
|
||||
this.pbDomainItems = pbDomainStats
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@ func init() {
|
||||
}
|
||||
|
||||
type ClockManager struct {
|
||||
lastFailAt int64
|
||||
}
|
||||
|
||||
func NewClockManager() *ClockManager {
|
||||
@@ -51,7 +52,13 @@ func (this *ClockManager) Start() {
|
||||
for range ticker.C {
|
||||
err := this.Sync()
|
||||
if err != nil {
|
||||
remotelogs.Warn("CLOCK", "sync clock failed: "+err.Error())
|
||||
var currentTimestamp = time.Now().Unix()
|
||||
|
||||
// 每天只提醒一次错误
|
||||
if currentTimestamp-this.lastFailAt > 86400 {
|
||||
remotelogs.Warn("CLOCK", "sync clock failed: "+err.Error())
|
||||
this.lastFailAt = currentTimestamp
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -118,7 +125,7 @@ func (this *ClockManager) syncNtpdate(ntpdate string, server string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 参考自:https://medium.com/learning-the-go-programming-language/lets-make-an-ntp-client-in-go-287c4b9a969f
|
||||
// ReadServer 参考自:https://medium.com/learning-the-go-programming-language/lets-make-an-ntp-client-in-go-287c4b9a969f
|
||||
func (this *ClockManager) ReadServer(server string) (time.Time, error) {
|
||||
conn, err := net.Dial("udp", server+":123")
|
||||
if err != nil {
|
||||
|
||||
28
internal/utils/readers/reader_print.go
Normal file
28
internal/utils/readers/reader_print.go
Normal file
@@ -0,0 +1,28 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package readers
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
)
|
||||
|
||||
type PrintReader struct {
|
||||
rawReader io.Reader
|
||||
tag string
|
||||
}
|
||||
|
||||
func NewPrintReader(rawReader io.Reader, tag string) io.Reader {
|
||||
return &PrintReader{
|
||||
rawReader: rawReader,
|
||||
tag: tag,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *PrintReader) Read(p []byte) (n int, err error) {
|
||||
n, err = this.rawReader.Read(p)
|
||||
if n > 0 {
|
||||
log.Println("[" + this.tag + "]" + string(p[:n]))
|
||||
}
|
||||
return
|
||||
}
|
||||
28
internal/utils/writers/writer_print.go
Normal file
28
internal/utils/writers/writer_print.go
Normal file
@@ -0,0 +1,28 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package writers
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
)
|
||||
|
||||
type PrintWriter struct {
|
||||
rawWriter io.Writer
|
||||
tag string
|
||||
}
|
||||
|
||||
func NewPrintWriter(rawWriter io.Writer, tag string) io.Writer {
|
||||
return &PrintWriter{
|
||||
rawWriter: rawWriter,
|
||||
tag: tag,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *PrintWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = this.rawWriter.Write(p)
|
||||
if n > 0 {
|
||||
log.Println("[" + this.tag + "]" + string(p[:n]))
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
)
|
||||
|
||||
type Checkpoint struct {
|
||||
priority int
|
||||
}
|
||||
|
||||
func (this *Checkpoint) Init() {
|
||||
@@ -36,6 +37,14 @@ func (this *Checkpoint) Stop() {
|
||||
|
||||
}
|
||||
|
||||
func (this *Checkpoint) SetPriority(priority int) {
|
||||
this.priority = priority
|
||||
}
|
||||
|
||||
func (this *Checkpoint) Priority() int {
|
||||
return this.priority
|
||||
}
|
||||
|
||||
func (this *Checkpoint) RequestBodyIsEmpty(req requests.Request) bool {
|
||||
if req.WAFRaw().ContentLength == 0 {
|
||||
return true
|
||||
|
||||
@@ -7,4 +7,5 @@ type CheckpointDefinition struct {
|
||||
Prefix string
|
||||
HasParams bool // has sub params
|
||||
Instance CheckpointInterface
|
||||
Priority int
|
||||
}
|
||||
|
||||
@@ -33,4 +33,10 @@ type CheckpointInterface interface {
|
||||
|
||||
// Stop stop
|
||||
Stop()
|
||||
|
||||
// SetPriority set priority
|
||||
SetPriority(priority int)
|
||||
|
||||
// get priority
|
||||
Priority() int
|
||||
}
|
||||
|
||||
@@ -41,19 +41,45 @@ func (this *RequestRefererBlockCheckpoint) RequestValue(req requests.Request, pa
|
||||
return
|
||||
}
|
||||
|
||||
var domains = options.GetSlice("allowDomains")
|
||||
var domainStrings = []string{}
|
||||
for _, domain := range domains {
|
||||
domainStrings = append(domainStrings, types.String(domain))
|
||||
// allow domains
|
||||
var allowDomains = options.GetSlice("allowDomains")
|
||||
var allowDomainStrings = []string{}
|
||||
for _, domain := range allowDomains {
|
||||
allowDomainStrings = append(allowDomainStrings, types.String(domain))
|
||||
}
|
||||
|
||||
if len(domainStrings) == 0 {
|
||||
// deny domains
|
||||
var denyDomains = options.GetSlice("denyDomains")
|
||||
var denyDomainStrings = []string{}
|
||||
for _, domain := range denyDomains {
|
||||
denyDomainStrings = append(denyDomainStrings, types.String(domain))
|
||||
}
|
||||
|
||||
if len(allowDomainStrings) == 0 {
|
||||
if len(denyDomainStrings) > 0 {
|
||||
if configutils.MatchDomains(denyDomainStrings, host) {
|
||||
value = 0
|
||||
} else {
|
||||
value = 1
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
value = 0
|
||||
return
|
||||
}
|
||||
|
||||
if configutils.MatchDomains(domainStrings, host) {
|
||||
if configutils.MatchDomains(allowDomainStrings, host) {
|
||||
if len(denyDomainStrings) > 0 {
|
||||
if configutils.MatchDomains(denyDomainStrings, host) {
|
||||
value = 0
|
||||
} else {
|
||||
value = 1
|
||||
}
|
||||
return
|
||||
}
|
||||
value = 1
|
||||
return
|
||||
} else {
|
||||
value = 0
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "通用Header比如Cache-Control、Accept之类的长度限制,防止缓冲区溢出攻击",
|
||||
HasParams: false,
|
||||
Instance: new(RequestGeneralHeaderLengthCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "客户端地址(IP)",
|
||||
@@ -15,6 +16,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "试图通过分析X-Forwarded-For等Header获取的客户端地址,比如192.168.1.100",
|
||||
HasParams: false,
|
||||
Instance: new(RequestRemoteAddrCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "客户端源地址(IP)",
|
||||
@@ -22,6 +24,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "直接连接的客户端地址,比如192.168.1.100",
|
||||
HasParams: false,
|
||||
Instance: new(RequestRawRemoteAddrCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "客户端端口",
|
||||
@@ -29,6 +32,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "直接连接的客户端地址端口",
|
||||
HasParams: false,
|
||||
Instance: new(RequestRemotePortCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "客户端用户名",
|
||||
@@ -36,6 +40,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "通过BasicAuth登录的客户端用户名",
|
||||
HasParams: false,
|
||||
Instance: new(RequestRemoteUserCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "请求URI",
|
||||
@@ -43,6 +48,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "包含URL参数的请求URI,类似于 /hello/world?lang=go",
|
||||
HasParams: false,
|
||||
Instance: new(RequestURICheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "请求路径",
|
||||
@@ -50,6 +56,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "不包含URL参数的请求路径,类似于 /hello/world",
|
||||
HasParams: false,
|
||||
Instance: new(RequestPathCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "请求URL",
|
||||
@@ -57,6 +64,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "完整的请求URL,包含协议、域名、请求路径、参数等,类似于 https://example.com/hello?name=lily",
|
||||
HasParams: false,
|
||||
Instance: new(RequestURLCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "请求内容长度",
|
||||
@@ -64,6 +72,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "请求Header中的Content-Length",
|
||||
HasParams: false,
|
||||
Instance: new(RequestLengthCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "请求体内容",
|
||||
@@ -71,6 +80,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "通常在POST或者PUT等操作时会附带请求体,最大限制32M",
|
||||
HasParams: false,
|
||||
Instance: new(RequestBodyCheckpoint),
|
||||
Priority: 5,
|
||||
},
|
||||
{
|
||||
Name: "请求URI和请求体组合",
|
||||
@@ -78,6 +88,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "${requestURI}和${requestBody}组合",
|
||||
HasParams: false,
|
||||
Instance: new(RequestAllCheckpoint),
|
||||
Priority: 5,
|
||||
},
|
||||
{
|
||||
Name: "请求表单参数",
|
||||
@@ -85,6 +96,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "获取POST或者其他方法发送的表单参数,最大请求体限制32M",
|
||||
HasParams: true,
|
||||
Instance: new(RequestFormArgCheckpoint),
|
||||
Priority: 5,
|
||||
},
|
||||
{
|
||||
Name: "上传文件",
|
||||
@@ -92,6 +104,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "获取POST上传的文件信息,最大请求体限制32M",
|
||||
HasParams: true,
|
||||
Instance: new(RequestUploadCheckpoint),
|
||||
Priority: 20,
|
||||
},
|
||||
{
|
||||
Name: "请求JSON参数",
|
||||
@@ -99,6 +112,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "获取POST或者其他方法发送的JSON,最大请求体限制32M,使用点(.)符号表示多级数据",
|
||||
HasParams: true,
|
||||
Instance: new(RequestJSONArgCheckpoint),
|
||||
Priority: 5,
|
||||
},
|
||||
{
|
||||
Name: "请求方法",
|
||||
@@ -106,6 +120,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "比如GET、POST",
|
||||
HasParams: false,
|
||||
Instance: new(RequestMethodCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "请求协议",
|
||||
@@ -113,6 +128,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "比如http或https",
|
||||
HasParams: false,
|
||||
Instance: new(RequestSchemeCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "HTTP协议版本",
|
||||
@@ -120,6 +136,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "比如HTTP/1.1",
|
||||
HasParams: false,
|
||||
Instance: new(RequestProtoCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "主机名",
|
||||
@@ -127,6 +144,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "比如teaos.cn",
|
||||
HasParams: false,
|
||||
Instance: new(RequestHostCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "请求来源URL",
|
||||
@@ -134,6 +152,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "请求Header中的Referer值",
|
||||
HasParams: false,
|
||||
Instance: new(RequestRefererCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "客户端信息",
|
||||
@@ -141,6 +160,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "比如Mozilla/5.0 AppleWebKit/537.36 (KHTML, like Gecko) Chrome/73.0.3683.103",
|
||||
HasParams: false,
|
||||
Instance: new(RequestUserAgentCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "内容类型",
|
||||
@@ -148,6 +168,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "请求Header的Content-Type",
|
||||
HasParams: false,
|
||||
Instance: new(RequestContentTypeCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "所有cookie组合字符串",
|
||||
@@ -155,6 +176,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "比如sid=IxZVPFhE&city=beijing&uid=18237",
|
||||
HasParams: false,
|
||||
Instance: new(RequestCookiesCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "单个cookie值",
|
||||
@@ -162,6 +184,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "单个cookie值",
|
||||
HasParams: true,
|
||||
Instance: new(RequestCookieCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "所有URL参数组合",
|
||||
@@ -169,6 +192,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "比如name=lu&age=20",
|
||||
HasParams: false,
|
||||
Instance: new(RequestArgsCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "单个URL参数值",
|
||||
@@ -176,6 +200,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "单个URL参数值",
|
||||
HasParams: true,
|
||||
Instance: new(RequestArgCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "所有Header信息",
|
||||
@@ -183,6 +208,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "使用\\n隔开的Header信息字符串",
|
||||
HasParams: false,
|
||||
Instance: new(RequestHeadersCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "单个Header值",
|
||||
@@ -190,6 +216,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "单个Header值",
|
||||
HasParams: true,
|
||||
Instance: new(RequestHeaderCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "国家/地区名称",
|
||||
@@ -197,6 +224,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "国家/地区名称",
|
||||
HasParams: false,
|
||||
Instance: new(RequestGeoCountryNameCheckpoint),
|
||||
Priority: 90,
|
||||
},
|
||||
{
|
||||
Name: "省份名称",
|
||||
@@ -204,6 +232,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "中国省份名称",
|
||||
HasParams: false,
|
||||
Instance: new(RequestGeoProvinceNameCheckpoint),
|
||||
Priority: 90,
|
||||
},
|
||||
{
|
||||
Name: "城市名称",
|
||||
@@ -211,6 +240,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "中国城市名称",
|
||||
HasParams: false,
|
||||
Instance: new(RequestGeoCityNameCheckpoint),
|
||||
Priority: 90,
|
||||
},
|
||||
{
|
||||
Name: "ISP名称",
|
||||
@@ -218,6 +248,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "ISP名称",
|
||||
HasParams: false,
|
||||
Instance: new(RequestISPNameCheckpoint),
|
||||
Priority: 90,
|
||||
},
|
||||
{
|
||||
Name: "CC统计(旧)",
|
||||
@@ -225,6 +256,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "统计某段时间段内的请求信息",
|
||||
HasParams: true,
|
||||
Instance: new(CCCheckpoint),
|
||||
Priority: 10,
|
||||
},
|
||||
{
|
||||
Name: "CC统计(新)",
|
||||
@@ -232,6 +264,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "统计某段时间段内的请求信息",
|
||||
HasParams: true,
|
||||
Instance: new(CC2Checkpoint),
|
||||
Priority: 10,
|
||||
},
|
||||
{
|
||||
Name: "防盗链",
|
||||
@@ -239,6 +272,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "阻止一些域名访问引用本站资源",
|
||||
HasParams: true,
|
||||
Instance: new(RequestRefererBlockCheckpoint),
|
||||
Priority: 20,
|
||||
},
|
||||
{
|
||||
Name: "通用响应Header长度限制",
|
||||
@@ -246,6 +280,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "通用Header比如Cache-Control、Accept之类的长度限制,防止缓冲区溢出攻击",
|
||||
HasParams: false,
|
||||
Instance: new(ResponseGeneralHeaderLengthCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "响应状态码",
|
||||
@@ -253,6 +288,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "响应状态码,比如200、404、500",
|
||||
HasParams: false,
|
||||
Instance: new(ResponseStatusCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "响应Header",
|
||||
@@ -260,6 +296,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "响应Header值",
|
||||
HasParams: true,
|
||||
Instance: new(ResponseHeaderCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
{
|
||||
Name: "响应内容",
|
||||
@@ -267,6 +304,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "响应内容字符串",
|
||||
HasParams: false,
|
||||
Instance: new(ResponseBodyCheckpoint),
|
||||
Priority: 5,
|
||||
},
|
||||
{
|
||||
Name: "响应内容长度",
|
||||
@@ -274,6 +312,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
Description: "响应内容长度,通过响应的Header Content-Length获取",
|
||||
HasParams: false,
|
||||
Instance: new(ResponseBytesSentCheckpoint),
|
||||
Priority: 100,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -281,6 +320,7 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
func FindCheckpoint(prefix string) CheckpointInterface {
|
||||
for _, def := range AllCheckpoints {
|
||||
if def.Prefix == prefix {
|
||||
def.Instance.SetPriority(def.Priority)
|
||||
return def.Instance
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ type Rule struct {
|
||||
Value string `yaml:"value" json:"value"` // compared value
|
||||
IsCaseInsensitive bool `yaml:"isCaseInsensitive" json:"isCaseInsensitive"`
|
||||
CheckpointOptions map[string]interface{} `yaml:"checkpointOptions" json:"checkpointOptions"`
|
||||
Priority int `yaml:"priority" json:"priority"`
|
||||
|
||||
checkpointFinder func(prefix string) checkpoints.CheckpointInterface
|
||||
|
||||
@@ -132,9 +133,9 @@ func (this *Rule) Init() error {
|
||||
}
|
||||
|
||||
if singleParamRegexp.MatchString(this.Param) {
|
||||
param := this.Param[2 : len(this.Param)-1]
|
||||
pieces := strings.SplitN(param, ".", 2)
|
||||
prefix := pieces[0]
|
||||
var param = this.Param[2 : len(this.Param)-1]
|
||||
var pieces = strings.SplitN(param, ".", 2)
|
||||
var prefix = pieces[0]
|
||||
if len(pieces) == 1 {
|
||||
this.singleParam = ""
|
||||
} else {
|
||||
@@ -142,18 +143,20 @@ func (this *Rule) Init() error {
|
||||
}
|
||||
|
||||
if this.checkpointFinder != nil {
|
||||
checkpoint := this.checkpointFinder(prefix)
|
||||
var checkpoint = this.checkpointFinder(prefix)
|
||||
if checkpoint == nil {
|
||||
return errors.New("no check point '" + prefix + "' found")
|
||||
}
|
||||
this.singleCheckpoint = checkpoint
|
||||
this.Priority = checkpoint.Priority()
|
||||
} else {
|
||||
checkpoint := checkpoints.FindCheckpoint(prefix)
|
||||
var checkpoint = checkpoints.FindCheckpoint(prefix)
|
||||
if checkpoint == nil {
|
||||
return errors.New("no check point '" + prefix + "' found")
|
||||
}
|
||||
checkpoint.Init()
|
||||
this.singleCheckpoint = checkpoint
|
||||
this.Priority = checkpoint.Priority()
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -162,22 +165,24 @@ func (this *Rule) Init() error {
|
||||
this.multipleCheckpoints = map[string]checkpoints.CheckpointInterface{}
|
||||
var err error = nil
|
||||
configutils.ParseVariables(this.Param, func(varName string) (value string) {
|
||||
pieces := strings.SplitN(varName, ".", 2)
|
||||
prefix := pieces[0]
|
||||
var pieces = strings.SplitN(varName, ".", 2)
|
||||
var prefix = pieces[0]
|
||||
if this.checkpointFinder != nil {
|
||||
checkpoint := this.checkpointFinder(prefix)
|
||||
var checkpoint = this.checkpointFinder(prefix)
|
||||
if checkpoint == nil {
|
||||
err = errors.New("no check point '" + prefix + "' found")
|
||||
} else {
|
||||
this.multipleCheckpoints[prefix] = checkpoint
|
||||
this.Priority = checkpoint.Priority()
|
||||
}
|
||||
} else {
|
||||
checkpoint := checkpoints.FindCheckpoint(prefix)
|
||||
var checkpoint = checkpoints.FindCheckpoint(prefix)
|
||||
if checkpoint == nil {
|
||||
err = errors.New("no check point '" + prefix + "' found")
|
||||
} else {
|
||||
checkpoint.Init()
|
||||
this.multipleCheckpoints[prefix] = checkpoint
|
||||
this.Priority = checkpoint.Priority()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
|
||||
@@ -52,6 +52,11 @@ func (this *RuleSet) Init(waf *WAF) error {
|
||||
return errors.New("init rule '" + rule.Param + " " + rule.Operator + " " + types.String(rule.Value) + "' failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// sort by priority
|
||||
sort.Slice(this.Rules, func(i, j int) bool {
|
||||
return this.Rules[i].Priority > this.Rules[j].Priority
|
||||
})
|
||||
}
|
||||
|
||||
// action codes
|
||||
|
||||
@@ -73,6 +73,7 @@ func (this *WAF) Init() (resultErrors []error) {
|
||||
for _, def := range checkpoints.AllCheckpoints {
|
||||
instance := reflect.New(reflect.Indirect(reflect.ValueOf(def.Instance)).Type()).Interface().(checkpoints.CheckpointInterface)
|
||||
instance.Init()
|
||||
instance.SetPriority(def.Priority)
|
||||
this.checkpointsMap[def.Prefix] = instance
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user