Compare commits

...

20 Commits

Author SHA1 Message Date
刘祥超
09dcf0d712 集群全局服务配置中增加多个访问日志相关选项 2022-10-26 17:51:16 +08:00
刘祥超
60aebd9306 URL跳转中增加域名跳转、端口跳转 2022-10-26 16:14:37 +08:00
刘祥超
04191d04d3 节点设置中增加“通过IP名单”选项 2022-10-26 10:42:16 +08:00
刘祥超
b80a5c525f 节点缓存目录所在磁盘空间不足时(<5G),暂停缓存写入,同时启动LFU清理 2022-10-25 15:14:28 +08:00
刘祥超
265c1e5312 WAF参数定义增加优先级,可以让“轻”任务优先执行 2022-10-24 17:57:07 +08:00
刘祥超
2723f705b6 修复在iptables中加入ipv6的错误 2022-10-24 16:37:54 +08:00
刘祥超
b4cddd6341 集群服务设置--访问日志中可以设置是否只记录通用Header 2022-10-24 14:39:18 +08:00
刘祥超
5636a81d48 防盗链功能增加禁止的来源域名 2022-10-24 10:21:23 +08:00
刘祥超
d8059960de 文件缓存索引表取消UNIQUE索引,尽可能避免 sqlite malformed 错误 2022-10-23 20:45:41 +08:00
刘祥超
17af4064af 带宽和流量提交失败时,将在一定时间内重试 2022-10-23 19:41:21 +08:00
刘祥超
15f37d2c93 优化用户服务整体启用和禁用 2022-10-23 16:21:11 +08:00
刘祥超
6dc3aa8cb7 单请求写入时间从1个小时增加到2个小时 2022-10-23 09:52:50 +08:00
刘祥超
900cccf2f1 修复源站Websocket源站读取失败导致的异常错误 2022-10-18 19:43:53 +08:00
刘祥超
1fec88dfc6 优化代码 2022-10-14 15:00:05 +08:00
刘祥超
7da9363336 上传带宽信息时附带区域ID信息 2022-10-11 18:57:35 +08:00
刘祥超
d82e633bba 时钟同步程序每天只提示一次警告信息 2022-10-11 11:31:00 +08:00
刘祥超
b363bbaafd 版本修改为0.5.6 2022-10-01 08:50:12 +08:00
刘祥超
92a20e3c9a 修复Websocket无法正常交互的问题 2022-09-30 16:34:21 +08:00
刘祥超
5742dfb263 修复Websocket响应可能被缓存的问题 2022-09-30 14:55:42 +08:00
刘祥超
0ae63511d5 版本调整为v0.5.5 2022-09-28 18:57:27 +08:00
32 changed files with 696 additions and 144 deletions

View File

@@ -473,6 +473,7 @@ func (this *FileListDB) initTables(times int) error {
{
// expiredAt - 过期时间,用来判断有无过期
// staleAt - 过时缓存最大时间,用来清理缓存
// 不对 hash 增加 unique 参数,是尽可能避免产生 malformed 错误
_, err := this.writeDB.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemsTableName + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"hash" varchar(32),
@@ -498,7 +499,7 @@ ON "` + this.itemsTableName + `" (
"staleAt" ASC
);
CREATE UNIQUE INDEX IF NOT EXISTS "hash"
CREATE INDEX IF NOT EXISTS "hash"
ON "` + this.itemsTableName + `" (
"hash" ASC
);

View File

@@ -21,6 +21,7 @@ import (
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
"golang.org/x/sys/unix"
"golang.org/x/text/language"
"golang.org/x/text/message"
"math"
@@ -58,6 +59,7 @@ const (
HotItemLifeSeconds int64 = 3600 // 热点数据生命周期
FileToMemoryMaxSize = 32 * sizes.M // 可以从文件写入到内存的最大文件尺寸
FileTmpSuffix = ".tmp"
MinDiskSpace = 5 << 30 // 当前磁盘最小剩余空间
)
var sharedWritingFileKeyMap = map[string]zero.Zero{} // key => bool
@@ -90,6 +92,8 @@ type FileStorage struct {
ignoreKeys *setutils.FixedSet
openFileCache *OpenFileCache
diskIsFull bool
}
func NewFileStorage(policy *serverconfigs.HTTPCachePolicy) *FileStorage {
@@ -287,6 +291,9 @@ func (this *FileStorage) Init() error {
// open file cache
this.initOpenFileCache()
// 检查磁盘空间
this.checkDiskSpace()
return nil
}
@@ -397,6 +404,11 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, siz
return nil, ErrWritingUnavailable
}
// 当前磁盘可用容量是否严重不足
if this.diskIsFull {
return nil, NewCapacityError("the disk is full")
}
// 是否已忽略
if this.ignoreKeys.Has(key) {
return nil, ErrEntityTooLarge
@@ -938,18 +950,25 @@ func (this *FileStorage) initList() error {
// 清理任务
func (this *FileStorage) purgeLoop() {
// 检查磁盘剩余空间
this.checkDiskSpace()
// 计算是否应该开启LFU清理
var capacityBytes = this.policy.CapacityBytes()
var startLFU = false
var usedPercent = float32(this.TotalDiskSize()*100) / float32(capacityBytes)
var lfuFreePercent = this.policy.PersistenceLFUFreePercent
if lfuFreePercent <= 0 {
lfuFreePercent = 5
}
if capacityBytes > 0 {
if lfuFreePercent < 100 {
if usedPercent >= 100-lfuFreePercent {
startLFU = true
if this.diskIsFull {
startLFU = true
} else {
var usedPercent = float32(this.TotalDiskSize()*100) / float32(capacityBytes)
if capacityBytes > 0 {
if lfuFreePercent < 100 {
if usedPercent >= 100-lfuFreePercent {
startLFU = true
}
}
}
}
@@ -1327,3 +1346,15 @@ func (this *FileStorage) runMemoryStorageSafety(f func(memoryStorage *MemoryStor
f(memoryStorage)
}
}
// 检查磁盘剩余空间
func (this *FileStorage) checkDiskSpace() {
if this.options != nil && len(this.options.Dir) > 0 {
var stat unix.Statfs_t
err := unix.Statfs(this.options.Dir, &stat)
if err == nil {
var availableBytes = stat.Bavail * uint64(stat.Bsize)
this.diskIsFull = availableBytes < MinDiskSpace
}
}
}

View File

@@ -1,7 +1,7 @@
package teaconst
const (
Version = "0.5.4"
Version = "0.5.6"
ProductName = "Edge Node"
ProcessName = "edge-node"

View File

@@ -4,6 +4,7 @@ import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"os/exec"
"runtime"
@@ -74,10 +75,16 @@ func (this *IPTablesAction) runAction(action string, listType IPListType, item *
}
func (this *IPTablesAction) runActionSingleIP(action string, listType IPListType, item *pb.IPItem) error {
// 暂时不支持ipv6
// TODO 将来支持ipv6
if utils.IsIPv6(item.IpFrom) {
return nil
}
if item.Type == "all" {
return nil
}
path := this.config.Path
var path = this.config.Path
var err error
if len(path) == 0 {
path, err = exec.LookPath("iptables")
@@ -88,6 +95,7 @@ func (this *IPTablesAction) runActionSingleIP(action string, listType IPListType
this.iptablesNotFound = true
return err
}
this.config.Path = path
}
iptablesAction := ""
switch action {

View File

@@ -1,6 +1,7 @@
package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/events"
@@ -141,6 +142,12 @@ func (this *IPListManager) init() {
}
func (this *IPListManager) loop() error {
// 是否同步IP名单
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig != nil && !nodeConfig.EnableIPLists {
return nil
}
for {
hasNext, err := this.fetch()
if err != nil {

View File

@@ -1603,9 +1603,25 @@ func (this *HTTPRequest) fixRequestHeader(header http.Header) {
header.Del(k)
k = strings.ReplaceAll(k, "-Websocket-", "-WebSocket-")
header[k] = v
} else if k == "Www-Authenticate" {
} else if strings.HasPrefix(k, "Sec-Ch") {
header.Del(k)
header["WWW-Authenticate"] = v
k = strings.ReplaceAll(k, "Sec-Ch-Ua", "Sec-CH-UA")
header[k] = v
} else {
switch k {
case "Www-Authenticate":
header.Del(k)
header["WWW-Authenticate"] = v
case "A-Im":
header.Del(k)
header["A-IM"] = v
case "Content-Md5":
header.Del(k)
header["Content-MD5"] = v
case "Sec-Gpc":
header.Del(k)
header["Content-GPC"] = v
}
}
}
}

View File

@@ -19,6 +19,11 @@ import (
// 读取缓存
func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
// 需要动态Upgrade的不缓存
if len(this.RawReq.Header.Get("Upgrade")) > 0 {
return
}
this.cacheCanTryStale = false
var cachePolicy = this.ReqServer.HTTPCachePolicy

View File

@@ -1,7 +1,11 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/types"
"net"
"net/http"
"strconv"
"strings"
@@ -13,7 +17,7 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
if this.web.MergeSlashes {
urlPath = utils.CleanPath(urlPath)
}
fullURL := this.requestScheme() + "://" + this.ReqHost + urlPath
var fullURL = this.requestScheme() + "://" + this.ReqHost + urlPath
for _, u := range this.web.HostRedirects {
if !u.IsOn {
continue
@@ -21,11 +25,50 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
if !u.MatchRequest(this.Format) {
continue
}
if u.MatchPrefix { // 匹配前缀
if strings.HasPrefix(fullURL, u.BeforeURL) {
afterURL := u.AfterURL
if u.KeepRequestURI {
afterURL += this.RawReq.URL.RequestURI()
if len(u.Type) == 0 || u.Type == serverconfigs.HTTPHostRedirectTypeURL {
if u.MatchPrefix { // 匹配前缀
if strings.HasPrefix(fullURL, u.BeforeURL) {
afterURL := u.AfterURL
if u.KeepRequestURI {
afterURL += this.RawReq.URL.RequestURI()
}
// 前后是否一致
if fullURL == afterURL {
return false
}
if u.Status <= 0 {
u.Status = http.StatusTemporaryRedirect
}
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
return true
}
} else if u.MatchRegexp { // 正则匹配
var reg = u.BeforeURLRegexp()
if reg == nil {
continue
}
var matches = reg.FindStringSubmatch(fullURL)
if len(matches) == 0 {
continue
}
var afterURL = u.AfterURL
for i, match := range matches {
afterURL = strings.ReplaceAll(afterURL, "${"+strconv.Itoa(i)+"}", match)
}
var subNames = reg.SubexpNames()
if len(subNames) > 0 {
for _, subName := range subNames {
if len(subName) > 0 {
index := reg.SubexpIndex(subName)
if index > -1 {
afterURL = strings.ReplaceAll(afterURL, "${"+subName+"}", matches[index])
}
}
}
}
// 前后是否一致
@@ -33,69 +76,6 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
return false
}
if u.Status <= 0 {
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
} else {
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
}
return true
}
} else if u.MatchRegexp { // 正则匹配
reg := u.BeforeURLRegexp()
if reg == nil {
continue
}
matches := reg.FindStringSubmatch(fullURL)
if len(matches) == 0 {
continue
}
afterURL := u.AfterURL
for i, match := range matches {
afterURL = strings.ReplaceAll(afterURL, "${"+strconv.Itoa(i)+"}", match)
}
subNames := reg.SubexpNames()
if len(subNames) > 0 {
for _, subName := range subNames {
if len(subName) > 0 {
index := reg.SubexpIndex(subName)
if index > -1 {
afterURL = strings.ReplaceAll(afterURL, "${"+subName+"}", matches[index])
}
}
}
}
// 前后是否一致
if fullURL == afterURL {
return false
}
if u.KeepArgs {
var qIndex = strings.Index(this.uri, "?")
if qIndex >= 0 {
afterURL += this.uri[qIndex:]
}
}
if u.Status <= 0 {
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
} else {
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
}
return true
} else { // 精准匹配
if fullURL == u.RealBeforeURL() {
// 前后是否一致
if fullURL == u.AfterURL {
return false
}
var afterURL = u.AfterURL
if u.KeepArgs {
var qIndex = strings.Index(this.uri, "?")
if qIndex >= 0 {
@@ -104,12 +84,110 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
}
if u.Status <= 0 {
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
} else {
u.Status = http.StatusTemporaryRedirect
}
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
return true
} else { // 精准匹配
if fullURL == u.RealBeforeURL() {
// 前后是否一致
if fullURL == u.AfterURL {
return false
}
var afterURL = u.AfterURL
if u.KeepArgs {
var qIndex = strings.Index(this.uri, "?")
if qIndex >= 0 {
afterURL += this.uri[qIndex:]
}
}
if u.Status <= 0 {
u.Status = http.StatusTemporaryRedirect
}
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
return true
}
}
} else if u.Type == serverconfigs.HTTPHostRedirectTypeDomain {
if len(u.DomainAfter) == 0 {
continue
}
// 如果跳转前后域名一致,则终止
if u.DomainAfter == this.ReqHost {
return false
}
var scheme = u.DomainAfterScheme
if len(scheme) == 0 {
scheme = this.requestScheme()
}
if u.DomainsAll || configutils.MatchDomains(u.DomainsBefore, this.ReqHost) {
var afterURL = scheme + "://" + u.DomainAfter + urlPath
if fullURL == afterURL {
// 终止匹配
return false
}
if u.Status <= 0 {
u.Status = http.StatusTemporaryRedirect
}
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
return true
}
} else if u.Type == serverconfigs.HTTPHostRedirectTypePort {
if u.PortAfter <= 0 {
continue
}
var scheme = u.PortAfterScheme
if len(scheme) == 0 {
scheme = this.requestScheme()
}
reqHost, reqPort, _ := net.SplitHostPort(this.ReqHost)
if len(reqHost) == 0 {
reqHost = this.ReqHost
}
if len(reqPort) == 0 {
switch this.requestScheme() {
case "http":
reqPort = "80"
case "https":
reqPort = "443"
}
}
// 如果跳转前后端口一致,则终止
if reqPort == types.String(u.PortAfter) {
return false
}
var containsPort = false
if u.PortsAll {
containsPort = true
} else {
containsPort = u.ContainsPort(types.Int(reqPort))
}
if containsPort {
var newReqHost = reqHost
if !((scheme == "http" && u.PortAfter == 80) || (scheme == "https" && u.PortAfter == 443)) {
newReqHost += ":" + types.String(u.PortAfter)
}
var afterURL = scheme + "://" + newReqHost + urlPath
if fullURL == afterURL {
// 终止匹配
return false
}
if u.Status <= 0 {
u.Status = http.StatusTemporaryRedirect
}
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
return true
}
}

View File

@@ -51,27 +51,46 @@ func (this *HTTPRequest) log() {
addr = addr[:index]
}
var serverGlobalConfig = this.nodeConfig.GlobalServerConfig
// 请求Cookie
var cookies = map[string]string{}
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldCookie) {
for _, cookie := range this.RawReq.Cookies() {
cookies[cookie.Name] = cookie.Value
var enableCookies = false
if serverGlobalConfig == nil || serverGlobalConfig.HTTPAccessLog.EnableCookies {
enableCookies = true
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldCookie) {
for _, cookie := range this.RawReq.Cookies() {
cookies[cookie.Name] = cookie.Value
}
}
}
// 请求Header
var pbReqHeader = map[string]*pb.Strings{}
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldHeader) {
for k, v := range this.RawReq.Header {
pbReqHeader[k] = &pb.Strings{Values: v}
if serverGlobalConfig == nil || serverGlobalConfig.HTTPAccessLog.EnableRequestHeaders {
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldHeader) {
// 是否只记录通用Header
var commonHeadersOnly = serverGlobalConfig != nil && serverGlobalConfig.HTTPAccessLog.CommonRequestHeadersOnly
for k, v := range this.RawReq.Header {
if commonHeadersOnly && !serverconfigs.IsCommonRequestHeader(k) {
continue
}
if !enableCookies && k == "Cookie" {
continue
}
pbReqHeader[k] = &pb.Strings{Values: v}
}
}
}
// 响应Header
var pbResHeader = map[string]*pb.Strings{}
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldSentHeader) {
for k, v := range this.writer.Header() {
pbResHeader[k] = &pb.Strings{Values: v}
if serverGlobalConfig == nil || serverGlobalConfig.HTTPAccessLog.EnableResponseHeaders {
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldSentHeader) {
for k, v := range this.writer.Header() {
pbResHeader[k] = &pb.Strings{Values: v}
}
}
}

View File

@@ -2,6 +2,7 @@ package nodes
import (
"bufio"
"bytes"
"errors"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"io"
@@ -9,8 +10,36 @@ import (
"net/url"
)
// WebsocketResponseReader Websocket响应Reader
type WebsocketResponseReader struct {
rawReader io.Reader
buf []byte
}
func NewWebsocketResponseReader(rawReader io.Reader) *WebsocketResponseReader {
return &WebsocketResponseReader{
rawReader: rawReader,
}
}
func (this *WebsocketResponseReader) Read(p []byte) (n int, err error) {
n, err = this.rawReader.Read(p)
if n > 0 {
if len(this.buf) == 0 {
this.buf = make([]byte, n)
copy(this.buf, p[:n])
} else {
this.buf = append(this.buf, p[:n]...)
}
}
return
}
// 处理Websocket请求
func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shouldRetry bool) {
// 设置不缓存
this.web.Cache = nil
if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn {
this.writer.WriteHeader(http.StatusForbidden)
this.addError(errors.New("websocket have not been enabled yet"))
@@ -84,14 +113,20 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
go func() {
// 读取第一个响应
resp, err := http.ReadResponse(bufio.NewReader(originConn), this.RawReq)
if err != nil {
var respReader = NewWebsocketResponseReader(originConn)
resp, err := http.ReadResponse(bufio.NewReader(respReader), this.RawReq)
if err != nil || resp == nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
_ = clientConn.Close()
_ = originConn.Close()
return
}
this.processResponseHeaders(resp.Header, resp.StatusCode)
this.writer.statusCode = resp.StatusCode
// 将响应写回客户端
err = resp.Write(clientConn)
@@ -105,6 +140,25 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
return
}
// 剩余已经从源站读取的内容
var headerBytes = respReader.buf
var headerIndex = bytes.Index(headerBytes, []byte{'\r', '\n', '\r', '\n'}) // CRLF
if headerIndex > 0 {
var leftBytes = headerBytes[headerIndex+4:]
if len(leftBytes) > 0 {
_, err = clientConn.Write(leftBytes)
if err != nil {
if resp.Body != nil {
_ = resp.Body.Close()
}
_ = clientConn.Close()
_ = originConn.Close()
return
}
}
}
if resp.Body != nil {
_ = resp.Body.Close()
}

View File

@@ -160,7 +160,7 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf
// 严格查找域名
func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *serverconfigs.ServerConfig, serverName string) {
group := this.Group
var group = this.Group
if group == nil {
return nil, ""
}

View File

@@ -45,7 +45,7 @@ func (this *HTTPListener) Serve() error {
Handler: this,
ReadTimeout: 1 * time.Hour, // TODO 改成可以配置
ReadHeaderTimeout: 3 * time.Second, // TODO 改成可以配置
WriteTimeout: 1 * time.Hour, // TODO 改成可以配置
WriteTimeout: 2 * time.Hour, // TODO 改成可以配置
IdleTimeout: 75 * time.Second, // TODO 改成可以配置
ConnState: func(conn net.Conn, state http.ConnState) {
switch state {
@@ -175,6 +175,15 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
}
}
// 检查用户
if server != nil && server.UserId > 0 {
if !SharedUserManager.CheckUserServersIsEnabled(server.UserId) {
rawWriter.WriteHeader(http.StatusNotFound)
_, _ = rawWriter.Write([]byte("The site owner is unavailable."))
return
}
}
// 包装新请求对象
var req = &HTTPRequest{
RawReq: rawReq,

View File

@@ -92,7 +92,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
}
// 是否已达到流量限制
if this.reachedTrafficLimit() {
if this.reachedTrafficLimit() || (server.UserId > 0 && !SharedUserManager.CheckUserServersIsEnabled(server.UserId)) {
// 关闭连接
tcpConn, ok := conn.(LingerConn)
if ok {

View File

@@ -170,6 +170,11 @@ func (this *UDPListener) servePacketListener(listener UDPPacketListener) error {
return nil
}
// 检查用户状态
if firstServer.UserId > 0 && !SharedUserManager.CheckUserServersIsEnabled(firstServer.UserId) {
return nil
}
n, cm, clientAddr, err := listener.ReadFrom(buffer)
if err != nil {
if this.isClosed {

View File

@@ -205,9 +205,7 @@ func (this *Node) Start() {
// 统计
goman.New(func() {
stats.SharedTrafficStatManager.Start(func() *nodeconfigs.NodeConfig {
return sharedNodeConfig
})
stats.SharedTrafficStatManager.Start()
})
goman.New(func() {
stats.SharedHTTPRequestStatManager.Start()
@@ -430,6 +428,22 @@ func (this *Node) execTask(rpcClient *rpc.RPCClient, nodeCtx context.Context, ta
}
}
}
case "userServersStateChanged":
if task.UserId > 0 {
resp, err := rpcClient.UserRPC.CheckUserServersState(nodeCtx, &pb.CheckUserServersStateRequest{UserId: task.UserId})
if err != nil {
return err
}
SharedUserManager.UpdateUserServersIsEnabled(task.UserId, resp.IsEnabled)
if resp.IsEnabled {
err = this.syncUserServersConfig(task.UserId)
if err != nil {
return err
}
}
}
default:
remotelogs.Error("NODE", "task '"+types.String(task.Id)+"', type '"+task.Type+"' has not been handled")
}
@@ -615,6 +629,36 @@ func (this *Node) syncServerConfig(serverId int64) error {
return nil
}
// 同步某个用户下的所有服务配置
func (this *Node) syncUserServersConfig(userId int64) error {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
serverConfigsResp, err := rpcClient.ServerRPC.ComposeAllUserServersConfig(rpcClient.Context(), &pb.ComposeAllUserServersConfigRequest{
UserId: userId,
})
if err != nil {
return err
}
if len(serverConfigsResp.ServersConfigJSON) == 0 {
return nil
}
var serverConfigs = []*serverconfigs.ServerConfig{}
err = json.Unmarshal(serverConfigsResp.ServersConfigJSON, &serverConfigs)
if err != nil {
return err
}
this.locker.Lock()
defer this.locker.Unlock()
for _, config := range serverConfigs {
this.updatingServerMap[config.Id] = config
}
return nil
}
// 启动同步计时器
func (this *Node) startSyncTimer() {
// TODO 这个时间间隔可以自行设置

View File

@@ -195,8 +195,8 @@ func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
})
// 当前TeaWeb所在的fs
rootFS := ""
rootTotal := uint64(0)
var rootFS = ""
var rootTotal = uint64(0)
if lists.ContainsString([]string{"darwin", "linux", "freebsd"}, runtime.GOOS) {
for _, p := range partitions {
if p.Mountpoint == "/" {
@@ -210,9 +210,9 @@ func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
}
}
total := rootTotal
totalUsage := uint64(0)
maxUsage := float64(0)
var total = rootTotal
var totalUsage = uint64(0)
var maxUsage = float64(0)
for _, partition := range partitions {
if runtime.GOOS != "windows" && !strings.Contains(partition.Device, "/") && !strings.Contains(partition.Device, "\\") {
continue
@@ -252,7 +252,7 @@ func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
// 缓存空间
func (this *NodeStatusExecutor) updateCacheSpace(status *nodeconfigs.NodeStatus) {
var result = []maps.Map{}
cachePaths := caches.SharedManager.FindAllCachePaths()
var cachePaths = caches.SharedManager.FindAllCachePaths()
for _, path := range cachePaths {
var stat unix.Statfs_t
err := unix.Statfs(path, &stat)

View File

@@ -1,5 +1,4 @@
//go:build !windows
// +build !windows
package nodes

View File

@@ -0,0 +1,49 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"sync"
)
var SharedUserManager = NewUserManager()
type User struct {
ServersEnabled bool
}
type UserManager struct {
userMap map[int64]*User // id => *User
locker sync.RWMutex
}
func NewUserManager() *UserManager {
return &UserManager{
userMap: map[int64]*User{},
}
}
func (this *UserManager) UpdateUserServersIsEnabled(userId int64, isEnabled bool) {
this.locker.Lock()
u, ok := this.userMap[userId]
if ok {
u.ServersEnabled = isEnabled
} else {
u = &User{ServersEnabled: isEnabled}
this.userMap[userId] = u
}
this.locker.Unlock()
}
func (this *UserManager) CheckUserServersIsEnabled(userId int64) (isEnabled bool) {
this.locker.RLock()
u, ok := this.userMap[userId]
if ok {
isEnabled = u.ServersEnabled
} else {
isEnabled = true
}
this.locker.RUnlock()
return
}

View File

@@ -49,6 +49,7 @@ type RPCClient struct {
FirewallRPC pb.FirewallServiceClient
SSLCertRPC pb.SSLCertServiceClient
ScriptRPC pb.ScriptServiceClient
UserRPC pb.UserServiceClient
}
func NewRPCClient(apiConfig *configs.APIConfig) (*RPCClient, error) {
@@ -81,6 +82,7 @@ func NewRPCClient(apiConfig *configs.APIConfig) (*RPCClient, error) {
client.FirewallRPC = pb.NewFirewallServiceClient(client)
client.SSLCertRPC = pb.NewSSLCertServiceClient(client)
client.ScriptRPC = pb.NewScriptServiceClient(client)
client.UserRPC = pb.NewUserServiceClient(client)
err := client.init()
if err != nil {

View File

@@ -3,6 +3,7 @@
package stats
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
@@ -42,6 +43,8 @@ type BandwidthStat struct {
type BandwidthStatManager struct {
m map[string]*BandwidthStat // key => *BandwidthStat
pbStats []*pb.ServerBandwidthStat
lastTime string // 上一次执行的时间
ticker *time.Ticker
@@ -65,6 +68,12 @@ func (this *BandwidthStatManager) Start() {
}
func (this *BandwidthStatManager) Loop() error {
var regionId int64
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig != nil {
regionId = nodeConfig.RegionId
}
var now = time.Now()
var day = timeutil.Format("Ymd", now)
var currentTime = timeutil.FormatTime("Hi", now.Unix()/300*300)
@@ -76,16 +85,29 @@ func (this *BandwidthStatManager) Loop() error {
var pbStats = []*pb.ServerBandwidthStat{}
// 历史未提交记录
if len(this.pbStats) > 0 {
var expiredTime = timeutil.FormatTime("Hi", time.Now().Unix()-1200) // 只保留20分钟
for _, stat := range this.pbStats {
if stat.TimeAt > expiredTime {
pbStats = append(pbStats, stat)
}
}
this.pbStats = nil
}
this.locker.Lock()
for key, stat := range this.m {
if stat.Day < day || stat.TimeAt < currentTime {
pbStats = append(pbStats, &pb.ServerBandwidthStat{
Id: 0,
UserId: stat.UserId,
ServerId: stat.ServerId,
Day: stat.Day,
TimeAt: stat.TimeAt,
Bytes: stat.MaxBytes / bandwidthTimestampDelim,
Id: 0,
UserId: stat.UserId,
ServerId: stat.ServerId,
Day: stat.Day,
TimeAt: stat.TimeAt,
Bytes: stat.MaxBytes / bandwidthTimestampDelim,
NodeRegionId: regionId,
})
delete(this.m, key)
}
@@ -100,6 +122,8 @@ func (this *BandwidthStatManager) Loop() error {
}
_, err = rpcClient.ServerBandwidthStatRPC.UploadServerBandwidthStats(rpcClient.Context(), &pb.UploadServerBandwidthStatsRequest{ServerBandwidthStats: pbStats})
if err != nil {
this.pbStats = pbStats
return err
}
}

View File

@@ -31,19 +31,33 @@ type TrafficItem struct {
CheckingTrafficLimit bool
}
func (this *TrafficItem) Add(anotherItem *TrafficItem) {
this.Bytes += anotherItem.Bytes
this.CachedBytes += anotherItem.CachedBytes
this.CountRequests += anotherItem.CountRequests
this.CountCachedRequests += anotherItem.CountCachedRequests
this.CountAttackRequests += anotherItem.CountAttackRequests
this.AttackBytes += anotherItem.AttackBytes
}
const trafficStatsMaxLife = 1200 // 最大只保存20分钟内的数据
// TrafficStatManager 区域流量统计
type TrafficStatManager struct {
itemMap map[string]*TrafficItem // [timestamp serverId] => *TrafficItem
domainsMap map[string]*TrafficItem // timestamp @ serverId @ domain => *TrafficItem
locker sync.Mutex
configFunc func() *nodeconfigs.NodeConfig
pbItems []*pb.ServerDailyStat
pbDomainItems []*pb.UploadServerDailyStatsRequest_DomainStat
locker sync.Mutex
totalRequests int64
}
// NewTrafficStatManager 获取新对象
func NewTrafficStatManager() *TrafficStatManager {
manager := &TrafficStatManager{
var manager = &TrafficStatManager{
itemMap: map[string]*TrafficItem{},
domainsMap: map[string]*TrafficItem{},
}
@@ -52,9 +66,7 @@ func NewTrafficStatManager() *TrafficStatManager {
}
// Start 启动自动任务
func (this *TrafficStatManager) Start(configFunc func() *nodeconfigs.NodeConfig) {
this.configFunc = configFunc
func (this *TrafficStatManager) Start() {
// 上传请求总数
var monitorTicker = time.NewTicker(1 * time.Minute)
events.OnKey(events.EventQuit, this, func() {
@@ -70,7 +82,7 @@ func (this *TrafficStatManager) Start(configFunc func() *nodeconfigs.NodeConfig)
})
// 上传统计数据
duration := 5 * time.Minute
var duration = 5 * time.Minute
if Tea.IsTesting() {
// 测试环境缩短上传时间,方便我们调试
duration = 30 * time.Second
@@ -143,9 +155,10 @@ func (this *TrafficStatManager) Add(serverId int64, domain string, bytes int64,
// Upload 上传流量
func (this *TrafficStatManager) Upload() error {
var config = this.configFunc()
if config == nil {
return nil
var regionId int64
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig != nil {
regionId = nodeConfig.RegionId
}
client, err := rpc.SharedRPC()
@@ -154,10 +167,14 @@ func (this *TrafficStatManager) Upload() error {
}
this.locker.Lock()
var itemMap = this.itemMap
var domainMap = this.domainsMap
// reset
this.itemMap = map[string]*TrafficItem{}
this.domainsMap = map[string]*TrafficItem{}
this.locker.Unlock()
// 服务统计
@@ -174,7 +191,7 @@ func (this *TrafficStatManager) Upload() error {
pbServerStats = append(pbServerStats, &pb.ServerDailyStat{
ServerId: serverId,
RegionId: config.RegionId,
NodeRegionId: regionId,
Bytes: item.Bytes,
CachedBytes: item.CachedBytes,
CountRequests: item.CountRequests,
@@ -186,9 +203,6 @@ func (this *TrafficStatManager) Upload() error {
CreatedAt: timestamp,
})
}
if len(pbServerStats) == 0 {
return nil
}
// 域名统计
var pbDomainStats = []*pb.UploadServerDailyStatsRequest_DomainStat{}
@@ -210,9 +224,40 @@ func (this *TrafficStatManager) Upload() error {
})
}
// 历史未提交记录
if len(this.pbItems) > 0 || len(this.pbDomainItems) > 0 {
var expiredAt = time.Now().Unix() - 1200 // 只保留20分钟
for _, item := range this.pbItems {
if item.CreatedAt > expiredAt {
pbServerStats = append(pbServerStats, item)
}
}
this.pbItems = nil
for _, item := range this.pbDomainItems {
if item.CreatedAt > expiredAt {
pbDomainStats = append(pbDomainStats, item)
}
}
this.pbDomainItems = nil
}
if len(pbServerStats) == 0 && len(pbDomainStats) == 0 {
return nil
}
_, err = client.ServerDailyStatRPC.UploadServerDailyStats(client.Context(), &pb.UploadServerDailyStatsRequest{
Stats: pbServerStats,
DomainStats: pbDomainStats,
})
return err
if err != nil {
// 加回历史记录
this.pbItems = pbServerStats
this.pbDomainItems = pbDomainStats
return err
}
return nil
}

View File

@@ -39,6 +39,7 @@ func init() {
}
type ClockManager struct {
lastFailAt int64
}
func NewClockManager() *ClockManager {
@@ -51,7 +52,13 @@ func (this *ClockManager) Start() {
for range ticker.C {
err := this.Sync()
if err != nil {
remotelogs.Warn("CLOCK", "sync clock failed: "+err.Error())
var currentTimestamp = time.Now().Unix()
// 每天只提醒一次错误
if currentTimestamp-this.lastFailAt > 86400 {
remotelogs.Warn("CLOCK", "sync clock failed: "+err.Error())
this.lastFailAt = currentTimestamp
}
}
}
}
@@ -118,7 +125,7 @@ func (this *ClockManager) syncNtpdate(ntpdate string, server string) error {
return nil
}
// 参考自https://medium.com/learning-the-go-programming-language/lets-make-an-ntp-client-in-go-287c4b9a969f
// ReadServer 参考自https://medium.com/learning-the-go-programming-language/lets-make-an-ntp-client-in-go-287c4b9a969f
func (this *ClockManager) ReadServer(server string) (time.Time, error) {
conn, err := net.Dial("udp", server+":123")
if err != nil {

View File

@@ -0,0 +1,28 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package readers
import (
"io"
"log"
)
type PrintReader struct {
rawReader io.Reader
tag string
}
func NewPrintReader(rawReader io.Reader, tag string) io.Reader {
return &PrintReader{
rawReader: rawReader,
tag: tag,
}
}
func (this *PrintReader) Read(p []byte) (n int, err error) {
n, err = this.rawReader.Read(p)
if n > 0 {
log.Println("[" + this.tag + "]" + string(p[:n]))
}
return
}

View File

@@ -0,0 +1,28 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package writers
import (
"io"
"log"
)
type PrintWriter struct {
rawWriter io.Writer
tag string
}
func NewPrintWriter(rawWriter io.Writer, tag string) io.Writer {
return &PrintWriter{
rawWriter: rawWriter,
tag: tag,
}
}
func (this *PrintWriter) Write(p []byte) (n int, err error) {
n, err = this.rawWriter.Write(p)
if n > 0 {
log.Println("[" + this.tag + "]" + string(p[:n]))
}
return
}

View File

@@ -6,6 +6,7 @@ import (
)
type Checkpoint struct {
priority int
}
func (this *Checkpoint) Init() {
@@ -36,6 +37,14 @@ func (this *Checkpoint) Stop() {
}
func (this *Checkpoint) SetPriority(priority int) {
this.priority = priority
}
func (this *Checkpoint) Priority() int {
return this.priority
}
func (this *Checkpoint) RequestBodyIsEmpty(req requests.Request) bool {
if req.WAFRaw().ContentLength == 0 {
return true

View File

@@ -7,4 +7,5 @@ type CheckpointDefinition struct {
Prefix string
HasParams bool // has sub params
Instance CheckpointInterface
Priority int
}

View File

@@ -33,4 +33,10 @@ type CheckpointInterface interface {
// Stop stop
Stop()
// SetPriority set priority
SetPriority(priority int)
// get priority
Priority() int
}

View File

@@ -41,19 +41,45 @@ func (this *RequestRefererBlockCheckpoint) RequestValue(req requests.Request, pa
return
}
var domains = options.GetSlice("allowDomains")
var domainStrings = []string{}
for _, domain := range domains {
domainStrings = append(domainStrings, types.String(domain))
// allow domains
var allowDomains = options.GetSlice("allowDomains")
var allowDomainStrings = []string{}
for _, domain := range allowDomains {
allowDomainStrings = append(allowDomainStrings, types.String(domain))
}
if len(domainStrings) == 0 {
// deny domains
var denyDomains = options.GetSlice("denyDomains")
var denyDomainStrings = []string{}
for _, domain := range denyDomains {
denyDomainStrings = append(denyDomainStrings, types.String(domain))
}
if len(allowDomainStrings) == 0 {
if len(denyDomainStrings) > 0 {
if configutils.MatchDomains(denyDomainStrings, host) {
value = 0
} else {
value = 1
}
return
}
value = 0
return
}
if configutils.MatchDomains(domainStrings, host) {
if configutils.MatchDomains(allowDomainStrings, host) {
if len(denyDomainStrings) > 0 {
if configutils.MatchDomains(denyDomainStrings, host) {
value = 0
} else {
value = 1
}
return
}
value = 1
return
} else {
value = 0
}

View File

@@ -8,6 +8,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "通用Header比如Cache-Control、Accept之类的长度限制防止缓冲区溢出攻击",
HasParams: false,
Instance: new(RequestGeneralHeaderLengthCheckpoint),
Priority: 100,
},
{
Name: "客户端地址IP",
@@ -15,6 +16,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "试图通过分析X-Forwarded-For等Header获取的客户端地址比如192.168.1.100",
HasParams: false,
Instance: new(RequestRemoteAddrCheckpoint),
Priority: 100,
},
{
Name: "客户端源地址IP",
@@ -22,6 +24,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "直接连接的客户端地址比如192.168.1.100",
HasParams: false,
Instance: new(RequestRawRemoteAddrCheckpoint),
Priority: 100,
},
{
Name: "客户端端口",
@@ -29,6 +32,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "直接连接的客户端地址端口",
HasParams: false,
Instance: new(RequestRemotePortCheckpoint),
Priority: 100,
},
{
Name: "客户端用户名",
@@ -36,6 +40,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "通过BasicAuth登录的客户端用户名",
HasParams: false,
Instance: new(RequestRemoteUserCheckpoint),
Priority: 100,
},
{
Name: "请求URI",
@@ -43,6 +48,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "包含URL参数的请求URI类似于 /hello/world?lang=go",
HasParams: false,
Instance: new(RequestURICheckpoint),
Priority: 100,
},
{
Name: "请求路径",
@@ -50,6 +56,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "不包含URL参数的请求路径类似于 /hello/world",
HasParams: false,
Instance: new(RequestPathCheckpoint),
Priority: 100,
},
{
Name: "请求URL",
@@ -57,6 +64,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "完整的请求URL包含协议、域名、请求路径、参数等类似于 https://example.com/hello?name=lily",
HasParams: false,
Instance: new(RequestURLCheckpoint),
Priority: 100,
},
{
Name: "请求内容长度",
@@ -64,6 +72,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "请求Header中的Content-Length",
HasParams: false,
Instance: new(RequestLengthCheckpoint),
Priority: 100,
},
{
Name: "请求体内容",
@@ -71,6 +80,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "通常在POST或者PUT等操作时会附带请求体最大限制32M",
HasParams: false,
Instance: new(RequestBodyCheckpoint),
Priority: 5,
},
{
Name: "请求URI和请求体组合",
@@ -78,6 +88,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "${requestURI}和${requestBody}组合",
HasParams: false,
Instance: new(RequestAllCheckpoint),
Priority: 5,
},
{
Name: "请求表单参数",
@@ -85,6 +96,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "获取POST或者其他方法发送的表单参数最大请求体限制32M",
HasParams: true,
Instance: new(RequestFormArgCheckpoint),
Priority: 5,
},
{
Name: "上传文件",
@@ -92,6 +104,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "获取POST上传的文件信息最大请求体限制32M",
HasParams: true,
Instance: new(RequestUploadCheckpoint),
Priority: 20,
},
{
Name: "请求JSON参数",
@@ -99,6 +112,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "获取POST或者其他方法发送的JSON最大请求体限制32M使用点.)符号表示多级数据",
HasParams: true,
Instance: new(RequestJSONArgCheckpoint),
Priority: 5,
},
{
Name: "请求方法",
@@ -106,6 +120,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "比如GET、POST",
HasParams: false,
Instance: new(RequestMethodCheckpoint),
Priority: 100,
},
{
Name: "请求协议",
@@ -113,6 +128,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "比如http或https",
HasParams: false,
Instance: new(RequestSchemeCheckpoint),
Priority: 100,
},
{
Name: "HTTP协议版本",
@@ -120,6 +136,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "比如HTTP/1.1",
HasParams: false,
Instance: new(RequestProtoCheckpoint),
Priority: 100,
},
{
Name: "主机名",
@@ -127,6 +144,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "比如teaos.cn",
HasParams: false,
Instance: new(RequestHostCheckpoint),
Priority: 100,
},
{
Name: "请求来源URL",
@@ -134,6 +152,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "请求Header中的Referer值",
HasParams: false,
Instance: new(RequestRefererCheckpoint),
Priority: 100,
},
{
Name: "客户端信息",
@@ -141,6 +160,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "比如Mozilla/5.0 AppleWebKit/537.36 (KHTML, like Gecko) Chrome/73.0.3683.103",
HasParams: false,
Instance: new(RequestUserAgentCheckpoint),
Priority: 100,
},
{
Name: "内容类型",
@@ -148,6 +168,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "请求Header的Content-Type",
HasParams: false,
Instance: new(RequestContentTypeCheckpoint),
Priority: 100,
},
{
Name: "所有cookie组合字符串",
@@ -155,6 +176,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "比如sid=IxZVPFhE&city=beijing&uid=18237",
HasParams: false,
Instance: new(RequestCookiesCheckpoint),
Priority: 100,
},
{
Name: "单个cookie值",
@@ -162,6 +184,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "单个cookie值",
HasParams: true,
Instance: new(RequestCookieCheckpoint),
Priority: 100,
},
{
Name: "所有URL参数组合",
@@ -169,6 +192,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "比如name=lu&age=20",
HasParams: false,
Instance: new(RequestArgsCheckpoint),
Priority: 100,
},
{
Name: "单个URL参数值",
@@ -176,6 +200,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "单个URL参数值",
HasParams: true,
Instance: new(RequestArgCheckpoint),
Priority: 100,
},
{
Name: "所有Header信息",
@@ -183,6 +208,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "使用\\n隔开的Header信息字符串",
HasParams: false,
Instance: new(RequestHeadersCheckpoint),
Priority: 100,
},
{
Name: "单个Header值",
@@ -190,6 +216,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "单个Header值",
HasParams: true,
Instance: new(RequestHeaderCheckpoint),
Priority: 100,
},
{
Name: "国家/地区名称",
@@ -197,6 +224,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "国家/地区名称",
HasParams: false,
Instance: new(RequestGeoCountryNameCheckpoint),
Priority: 90,
},
{
Name: "省份名称",
@@ -204,6 +232,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "中国省份名称",
HasParams: false,
Instance: new(RequestGeoProvinceNameCheckpoint),
Priority: 90,
},
{
Name: "城市名称",
@@ -211,6 +240,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "中国城市名称",
HasParams: false,
Instance: new(RequestGeoCityNameCheckpoint),
Priority: 90,
},
{
Name: "ISP名称",
@@ -218,6 +248,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "ISP名称",
HasParams: false,
Instance: new(RequestISPNameCheckpoint),
Priority: 90,
},
{
Name: "CC统计",
@@ -225,6 +256,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "统计某段时间段内的请求信息",
HasParams: true,
Instance: new(CCCheckpoint),
Priority: 10,
},
{
Name: "CC统计",
@@ -232,6 +264,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "统计某段时间段内的请求信息",
HasParams: true,
Instance: new(CC2Checkpoint),
Priority: 10,
},
{
Name: "防盗链",
@@ -239,6 +272,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "阻止一些域名访问引用本站资源",
HasParams: true,
Instance: new(RequestRefererBlockCheckpoint),
Priority: 20,
},
{
Name: "通用响应Header长度限制",
@@ -246,6 +280,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "通用Header比如Cache-Control、Accept之类的长度限制防止缓冲区溢出攻击",
HasParams: false,
Instance: new(ResponseGeneralHeaderLengthCheckpoint),
Priority: 100,
},
{
Name: "响应状态码",
@@ -253,6 +288,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "响应状态码比如200、404、500",
HasParams: false,
Instance: new(ResponseStatusCheckpoint),
Priority: 100,
},
{
Name: "响应Header",
@@ -260,6 +296,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "响应Header值",
HasParams: true,
Instance: new(ResponseHeaderCheckpoint),
Priority: 100,
},
{
Name: "响应内容",
@@ -267,6 +304,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "响应内容字符串",
HasParams: false,
Instance: new(ResponseBodyCheckpoint),
Priority: 5,
},
{
Name: "响应内容长度",
@@ -274,6 +312,7 @@ var AllCheckpoints = []*CheckpointDefinition{
Description: "响应内容长度通过响应的Header Content-Length获取",
HasParams: false,
Instance: new(ResponseBytesSentCheckpoint),
Priority: 100,
},
}
@@ -281,6 +320,7 @@ var AllCheckpoints = []*CheckpointDefinition{
func FindCheckpoint(prefix string) CheckpointInterface {
for _, def := range AllCheckpoints {
if def.Prefix == prefix {
def.Instance.SetPriority(def.Priority)
return def.Instance
}
}

View File

@@ -35,6 +35,7 @@ type Rule struct {
Value string `yaml:"value" json:"value"` // compared value
IsCaseInsensitive bool `yaml:"isCaseInsensitive" json:"isCaseInsensitive"`
CheckpointOptions map[string]interface{} `yaml:"checkpointOptions" json:"checkpointOptions"`
Priority int `yaml:"priority" json:"priority"`
checkpointFinder func(prefix string) checkpoints.CheckpointInterface
@@ -132,9 +133,9 @@ func (this *Rule) Init() error {
}
if singleParamRegexp.MatchString(this.Param) {
param := this.Param[2 : len(this.Param)-1]
pieces := strings.SplitN(param, ".", 2)
prefix := pieces[0]
var param = this.Param[2 : len(this.Param)-1]
var pieces = strings.SplitN(param, ".", 2)
var prefix = pieces[0]
if len(pieces) == 1 {
this.singleParam = ""
} else {
@@ -142,18 +143,20 @@ func (this *Rule) Init() error {
}
if this.checkpointFinder != nil {
checkpoint := this.checkpointFinder(prefix)
var checkpoint = this.checkpointFinder(prefix)
if checkpoint == nil {
return errors.New("no check point '" + prefix + "' found")
}
this.singleCheckpoint = checkpoint
this.Priority = checkpoint.Priority()
} else {
checkpoint := checkpoints.FindCheckpoint(prefix)
var checkpoint = checkpoints.FindCheckpoint(prefix)
if checkpoint == nil {
return errors.New("no check point '" + prefix + "' found")
}
checkpoint.Init()
this.singleCheckpoint = checkpoint
this.Priority = checkpoint.Priority()
}
return nil
@@ -162,22 +165,24 @@ func (this *Rule) Init() error {
this.multipleCheckpoints = map[string]checkpoints.CheckpointInterface{}
var err error = nil
configutils.ParseVariables(this.Param, func(varName string) (value string) {
pieces := strings.SplitN(varName, ".", 2)
prefix := pieces[0]
var pieces = strings.SplitN(varName, ".", 2)
var prefix = pieces[0]
if this.checkpointFinder != nil {
checkpoint := this.checkpointFinder(prefix)
var checkpoint = this.checkpointFinder(prefix)
if checkpoint == nil {
err = errors.New("no check point '" + prefix + "' found")
} else {
this.multipleCheckpoints[prefix] = checkpoint
this.Priority = checkpoint.Priority()
}
} else {
checkpoint := checkpoints.FindCheckpoint(prefix)
var checkpoint = checkpoints.FindCheckpoint(prefix)
if checkpoint == nil {
err = errors.New("no check point '" + prefix + "' found")
} else {
checkpoint.Init()
this.multipleCheckpoints[prefix] = checkpoint
this.Priority = checkpoint.Priority()
}
}
return ""

View File

@@ -52,6 +52,11 @@ func (this *RuleSet) Init(waf *WAF) error {
return errors.New("init rule '" + rule.Param + " " + rule.Operator + " " + types.String(rule.Value) + "' failed: " + err.Error())
}
}
// sort by priority
sort.Slice(this.Rules, func(i, j int) bool {
return this.Rules[i].Priority > this.Rules[j].Priority
})
}
// action codes

View File

@@ -73,6 +73,7 @@ func (this *WAF) Init() (resultErrors []error) {
for _, def := range checkpoints.AllCheckpoints {
instance := reflect.New(reflect.Indirect(reflect.ValueOf(def.Instance)).Type()).Interface().(checkpoints.CheckpointInterface)
instance.Init()
instance.SetPriority(def.Priority)
this.checkpointsMap[def.Prefix] = instance
}