Compare commits

...

45 Commits

Author SHA1 Message Date
刘祥超
77bb1cf14e 版本更改为0.4.8.1 2022-06-20 09:34:06 +08:00
刘祥超
274284dbe1 静态文件分发也支持压缩、WebP转换 2022-06-19 11:39:21 +08:00
刘祥超
cd6d7221e8 不限制206 Partial Content两次写入文件的时间差 2022-06-18 20:05:09 +08:00
刘祥超
b1d0c8852e 如果缓存条件支持206 Partial Content,则第一次加载时自动尝试从分片缓存中读取内容 2022-06-18 19:31:10 +08:00
刘祥超
4c4033bb56 TCP负载均衡实现流量限制,达到限制后,关闭连接 2022-06-17 21:49:15 +08:00
刘祥超
6d642b75f6 延长预热超时时间 2022-06-16 20:32:11 +08:00
刘祥超
eb47e3a08c 删除缓存的时同时删除相关的缓存(压缩格式、WebP格式、http和https互换URL) 2022-06-15 12:54:56 +08:00
刘祥超
d82d16e28d 修复内容为空时无法缓存的Bug 2022-06-09 20:26:36 +08:00
刘祥超
b2fc785543 WAF规则中增加${requestURL}参数 2022-06-09 19:44:11 +08:00
刘祥超
189e3342ce 将缓存maxOpenFiles最小值从2改为4 2022-06-09 19:12:29 +08:00
刘祥超
885defbf31 优化nftables相关代码 2022-06-09 19:12:10 +08:00
刘祥超
74f1bf330d DNS解析库默认使用Go原生库 2022-06-07 11:49:38 +08:00
刘祥超
ad843d9d10 修复源站从http跳转到https导致无限循环的问题 2022-06-07 11:25:09 +08:00
刘祥超
13e718742d 节点状态上报时增加时间戳字段 2022-06-07 11:23:40 +08:00
刘祥超
771eff8fb1 增加刷新、预热缓存任务管理 2022-06-05 17:15:02 +08:00
刘祥超
20d7e0b1bf ACME申请证书时如果找不到Token,则直接跳过执行后面请求 2022-06-02 15:34:14 +08:00
刘祥超
e6c7bbec06 修复一个源站主备切换不灵敏的问题/WebSocket也支持源站主备自动切换 2022-05-23 20:01:26 +08:00
刘祥超
be61ef89fe fix typo 2022-05-23 16:15:02 +08:00
刘祥超
3d7d8f1e63 优化代码 2022-05-23 11:34:58 +08:00
刘祥超
a4fb465a19 在严格匹配域名模式下仍然可以通过节点IP进行健康检查 2022-05-23 11:17:53 +08:00
刘祥超
96c725c13d 增加LICENSE和README 2022-05-22 11:36:43 +08:00
刘祥超
7635def2fa 修正自动使用本地防火墙延长封禁时间逻辑 2022-05-21 22:15:11 +08:00
刘祥超
b704a73338 修复一个日志typo 2022-05-21 22:04:23 +08:00
刘祥超
123b5f5969 自动将同集群节点IP加入白名单/尝试使用本地防火墙提升黑名单连接封锁效率 2022-05-21 21:32:10 +08:00
刘祥超
eea2037444 优化验证码失败次数统计 2022-05-21 20:02:35 +08:00
刘祥超
4e6d2fa5ea WAF策略中增加验证码相关定制设置 2022-05-21 11:17:53 +08:00
刘祥超
14bb131e8d WAF CAPTCHA:刷新验证码页面也算入校验失败次数 2022-05-20 11:56:06 +08:00
刘祥超
31814bb54c 忽略301, 302, 303, 307, 308响应中没有Location的错误提示 2022-05-19 20:16:40 +08:00
刘祥超
49b8fd6e97 健康检查增加是否记录访问日志选项 2022-05-19 17:13:20 +08:00
刘祥超
a9d31a2e35 增加edge-node accesslog命令,用来在本地查看访问日志 2022-05-18 23:14:57 +08:00
刘祥超
298cef7f05 缩短指标统计队列长度 2022-05-18 21:41:34 +08:00
刘祥超
9bdd9a433c 实现基础的DDoS防护 2022-05-18 21:03:51 +08:00
刘祥超
45620dcdb7 计算CC的时候不再跨时间范围累积 2022-05-12 21:48:33 +08:00
刘祥超
84a5d38b0b 在启动时检查节点时间戳是否和API节点一致,如果不一致则上报 2022-05-12 21:07:45 +08:00
刘祥超
e812b3fcf6 X-Forwarded-For中包含当前客户端的IP 2022-05-08 16:56:10 +08:00
刘祥超
1bd16fa1d3 往硬盘刷数据时不统计maxOpenFiles 2022-05-07 22:02:41 +08:00
刘祥超
f3ea4957be fix typo 2022-05-05 11:01:03 +08:00
刘祥超
04da107c94 路由规则可以单独设置UAM(仅企业版可用) 2022-05-04 20:32:25 +08:00
刘祥超
e77de69a15 节点增加DNS解析库类型设置 2022-05-04 16:40:25 +08:00
刘祥超
e88eda56f5 自动替换Location中的地址时检查域名是否为当前域名 2022-05-04 10:48:14 +08:00
刘祥超
d6ceccc52e 增加基准测试 2022-05-01 10:40:19 +08:00
刘祥超
cd948ac68c 实现新的gzip库提升gzip性能 2022-04-30 22:22:30 +08:00
刘祥超
9eac8afa3d 停止节点时systemctl命令不阻塞当前进程 2022-04-26 12:22:00 +08:00
刘祥超
fb3610966a 白名单中的IP不受请求限制的影响 2022-04-25 11:11:25 +08:00
刘祥超
a6673449db 修改版本为0.4.8 2022-04-25 11:11:03 +08:00
102 changed files with 4084 additions and 679 deletions

29
LICENSE Normal file
View File

@@ -0,0 +1,29 @@
BSD 3-Clause License
Copyright (c) 2020, LiuXiangChao
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

1
README.md Normal file
View File

@@ -0,0 +1 @@
GoEdge边缘节点源码

View File

@@ -22,7 +22,7 @@ func main() {
app := apps.NewAppCmd().
Version(teaconst.Version).
Product(teaconst.ProductName).
Usage(teaconst.ProcessName + " [-v|start|stop|restart|status|quit|test|reload|service|daemon|pprof]").
Usage(teaconst.ProcessName + " [-v|start|stop|restart|status|quit|test|reload|service|daemon|pprof|accesslog]").
Usage(teaconst.ProcessName + " [trackers|goman|conns|gc]").
Usage(teaconst.ProcessName + " [ip.drop|ip.reject|ip.remove] IP")
@@ -258,6 +258,53 @@ func main() {
}
}
})
app.On("accesslog", func() {
// local sock
var tmpDir = os.TempDir()
var sockFile = tmpDir + "/" + teaconst.AccessLogSockName
_, err := os.Stat(sockFile)
if err != nil {
if !os.IsNotExist(err) {
fmt.Println("[ERROR]" + err.Error())
return
}
}
var processSock = gosock.NewTmpSock(teaconst.ProcessName)
reply, err := processSock.Send(&gosock.Command{
Code: "accesslog",
})
if err != nil {
fmt.Println("[ERROR]" + err.Error())
return
}
if reply.Code == "error" {
var errString = maps.NewMap(reply.Params).GetString("error")
if len(errString) > 0 {
fmt.Println("[ERROR]" + errString)
return
}
}
conn, err := net.Dial("unix", sockFile)
if err != nil {
fmt.Println("[ERROR]start reading access log failed: " + err.Error())
return
}
defer func() {
_ = conn.Close()
}()
var buf = make([]byte, 1024)
for {
n, err := conn.Read(buf)
if n > 0 {
fmt.Print(string(buf[:n]))
}
if err != nil {
break
}
}
})
app.Run(func() {
node := nodes.NewNode()
node.Start()

5
go.mod
View File

@@ -2,7 +2,9 @@ module github.com/TeaOSLab/EdgeNode
go 1.15
replace github.com/TeaOSLab/EdgeCommon => ../EdgeCommon
replace (
github.com/TeaOSLab/EdgeCommon => ../EdgeCommon
)
require (
github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000
@@ -20,6 +22,7 @@ require (
github.com/iwind/gosock v0.0.0-20211103081026-ee4652210ca4
github.com/iwind/gowebp v0.0.0-20211029040624-7331ecc78ed8
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
github.com/klauspost/compress v1.15.2 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-sqlite3 v1.14.9
github.com/miekg/dns v1.1.43

2
go.sum
View File

@@ -126,6 +126,8 @@ github.com/jsimonetti/rtnetlink v0.0.0-20211022192332-93da33804786/go.mod h1:v4h
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e h1:LvL4XsI70QxOGHed6yhQtAU34Kx3Qq2wwBzGFKY8zKk=
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
github.com/klauspost/compress v1.15.2 h1:3WH+AG7s2+T8o3nrM/8u2rdqUEcQhmga7smjrT41nAw=
github.com/klauspost/compress v1.15.2/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=

View File

@@ -217,7 +217,10 @@ func (this *AppCmd) runStop() {
if runtime.GOOS == "linux" {
systemctl, _ := exec.LookPath("systemctl")
if len(systemctl) > 0 {
_ = exec.Command(systemctl, "stop", teaconst.SystemdServiceName).Run()
go func() {
// 有可能会长时间执行,这里不阻塞进程
_ = exec.Command(systemctl, "stop", teaconst.SystemdServiceName).Run()
}()
}
}

View File

@@ -214,3 +214,15 @@ func (this *Manager) FindAllCachePaths() []string {
}
return result
}
// FindAllStorages 读取所有缓存存储
func (this *Manager) FindAllStorages() []StorageInterface {
this.locker.Lock()
defer this.locker.Unlock()
var storages = []StorageInterface{}
for _, storage := range this.storageMap {
storages = append(storages, storage)
}
return storages
}

View File

@@ -10,7 +10,7 @@ import (
)
const (
minOpenFilesValue int32 = 2
minOpenFilesValue int32 = 4
maxOpenFilesValue int32 = 65535
modeSlow int32 = 1
@@ -35,7 +35,7 @@ func NewMaxOpenFiles(step int32) *MaxOpenFiles {
}
var f = &MaxOpenFiles{
step: step,
maxOpenFiles: 2,
maxOpenFiles: minOpenFilesValue,
}
if teaconst.DiskIsFast {
f.maxOpenFiles = 32
@@ -68,7 +68,7 @@ func (this *MaxOpenFiles) init() {
atomic.StoreInt32(&this.currentOpens, 0)
}
// reset mod
// reset mode
atomic.StoreInt32(&this.mode, 0)
}
})

View File

@@ -428,7 +428,7 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, siz
return nil, ErrFileIsWriting
}
if len(sharedWritingFileKeyMap) >= int(maxOpenFiles.Max()) {
if !isFlushing && len(sharedWritingFileKeyMap) >= int(maxOpenFiles.Max()) {
sharedWritingFileKeyLocker.Unlock()
return nil, ErrTooManyOpenFiles
}
@@ -481,11 +481,16 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, siz
openFileCache.Close(cachePath)
}
// 查询当前已有缓存文件
stat, err := os.Stat(cachePath)
if err == nil && time.Now().Sub(stat.ModTime()) <= 1*time.Second {
// 检查两次写入缓存的时间是否过于相近,分片内容不受此限制
if err == nil && !isPartial && time.Now().Sub(stat.ModTime()) <= 1*time.Second {
// 防止并发连续写入
return nil, ErrFileIsWriting
}
// 构造文件名
var tmpPath = cachePath
var existsFile = false
if stat != nil {
@@ -534,10 +539,12 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, siz
if err != nil {
return nil, err
}
if time.Since(before) >= maxOpenFilesSlowCost {
maxOpenFiles.Slow()
} else {
maxOpenFiles.Fast()
if !isFlushing {
if time.Since(before) >= maxOpenFilesSlowCost {
maxOpenFiles.Slow()
} else {
maxOpenFiles.Fast()
}
}
var removeOnFailure = true
@@ -767,9 +774,10 @@ func (this *FileStorage) Purge(keys []string, urlType string) error {
return err
}
}
return nil
}
// 文件
// URL
for _, key := range keys {
hash, path := this.keyPath(key)
err := this.removeCacheFile(path)

View File

@@ -35,6 +35,7 @@ type StorageInterface interface {
CleanAll() error
// Purge 批量删除缓存
// urlType 值为file|dir
Purge(keys []string, urlType string) error
// Stop 停止缓存策略

View File

@@ -267,8 +267,10 @@ func (this *MemoryStorage) Purge(keys []string, urlType string) error {
return err
}
}
return nil
}
// URL
for _, key := range keys {
err := this.Delete(key)
if err != nil {

View File

@@ -3,7 +3,7 @@
package compressions
import (
"compress/gzip"
"github.com/klauspost/compress/gzip"
"io"
)

View File

@@ -3,7 +3,7 @@
package compressions
import (
"compress/gzip"
"github.com/klauspost/compress/gzip"
"io"
)

View File

@@ -34,3 +34,31 @@ func BenchmarkGzipWriter_Write(b *testing.B) {
_ = writer.Close()
}
}
func BenchmarkGzipWriter_Write_Parallel(b *testing.B) {
var data = []byte(strings.Repeat("A", 1024))
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
var buf = &bytes.Buffer{}
writer, err := compressions.NewGzipWriter(buf, 5)
if err != nil {
b.Fatal(err)
}
for j := 0; j < 100; j++ {
_, err = writer.Write(data)
if err != nil {
b.Fatal(err)
}
/**err = writer.Flush()
if err != nil {
b.Fatal(err)
}**/
}
_ = writer.Close()
}
})
}

View File

@@ -1,6 +1,6 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build community
// +build community
//go:build !plus
// +build !plus
package teaconst

View File

@@ -1,7 +1,7 @@
package teaconst
const (
Version = "0.4.7"
Version = "0.4.8.1"
ProductName = "Edge Node"
ProcessName = "edge-node"
@@ -13,4 +13,6 @@ const (
// SystemdServiceName systemd
SystemdServiceName = "edge-node"
AccessLogSockName = "edge-node.accesslog.sock"
)

View File

@@ -3,9 +3,10 @@ package events
type Event = string
const (
EventStart Event = "start" // start loading
EventLoaded Event = "loaded" // first load
EventQuit Event = "quit" // quit node gracefully
EventReload Event = "reload" // reload config
EventTerminated Event = "terminated" // process terminated
EventStart Event = "start" // start loading
EventLoaded Event = "loaded" // first load
EventQuit Event = "quit" // quit node gracefully
EventReload Event = "reload" // reload config
EventTerminated Event = "terminated" // process terminated
EventNFTablesReady Event = "nftablesReady" // nftables ready
)

View File

@@ -1 +0,0 @@
firewall_nftables_test.go

View File

@@ -0,0 +1,502 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package firewalls
import (
"bytes"
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/zero"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
"net"
"os/exec"
"strings"
)
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
func init() {
events.On(events.EventReload, func() {
if nftablesInstance == nil {
return
}
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig != nil {
err := SharedDDoSProtectionManager.Apply(nodeConfig.DDOSProtection)
if err != nil {
remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error())
}
}
})
events.On(events.EventNFTablesReady, func() {
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig != nil {
err := SharedDDoSProtectionManager.Apply(nodeConfig.DDOSProtection)
if err != nil {
remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error())
}
}
})
}
// DDoSProtectionManager DDoS防护
type DDoSProtectionManager struct {
nftPath string
lastAllowIPList []string
lastConfig []byte
}
// NewDDoSProtectionManager 获取新对象
func NewDDoSProtectionManager() *DDoSProtectionManager {
nftPath, _ := exec.LookPath("nft")
return &DDoSProtectionManager{
nftPath: nftPath,
}
}
// Apply 应用配置
func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error {
// 同集群节点IP白名单
var allowIPListChanged = false
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig != nil {
var allowIPList = nodeConfig.AllowedIPs
if !utils.ContainsSameStrings(allowIPList, this.lastAllowIPList) {
allowIPListChanged = true
this.lastAllowIPList = allowIPList
}
}
// 对比配置
configJSON, err := json.Marshal(config)
if err != nil {
return errors.New("encode config to json failed: " + err.Error())
}
if !allowIPListChanged && bytes.Equal(this.lastConfig, configJSON) {
return nil
}
remotelogs.Println("FIREWALL", "change DDoS protection config")
if len(this.nftPath) == 0 {
return errors.New("can not find nft command")
}
if nftablesInstance == nil {
return errors.New("nftables instance should not be nil")
}
if config == nil {
// TCP
err := this.removeTCPRules()
if err != nil {
return err
}
// TODO other protocols
return nil
}
// TCP
if config.TCP == nil {
err := this.removeTCPRules()
if err != nil {
return err
}
} else {
// allow ip list
var allowIPList = []string{}
for _, ipConfig := range config.TCP.AllowIPList {
allowIPList = append(allowIPList, ipConfig.IP)
}
for _, ip := range this.lastAllowIPList {
if !lists.ContainsString(allowIPList, ip) {
allowIPList = append(allowIPList, ip)
}
}
err = this.updateAllowIPList(allowIPList)
if err != nil {
return err
}
// tcp
if config.TCP.IsOn {
err := this.addTCPRules(config.TCP)
if err != nil {
return err
}
} else {
err := this.removeTCPRules()
if err != nil {
return err
}
}
}
this.lastConfig = configJSON
return nil
}
// 添加TCP规则
func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig) error {
// 检查nft版本不能小于0.9
if len(nftablesInstance.version) > 0 && stringutil.VersionCompare("0.9", nftablesInstance.version) > 0 {
return nil
}
var ports = []int32{}
for _, portConfig := range tcpConfig.Ports {
if !lists.ContainsInt32(ports, portConfig.Port) {
ports = append(ports, portConfig.Port)
}
}
if len(ports) == 0 {
ports = []int32{80, 443}
}
for _, filter := range nftablesFilters {
chain, oldRules, err := this.getRules(filter)
if err != nil {
return errors.New("get old rules failed: " + err.Error())
}
var protocol = filter.protocol()
// max connections
var maxConnections = tcpConfig.MaxConnections
if maxConnections <= 0 {
maxConnections = nodeconfigs.DefaultTCPMaxConnections
if maxConnections <= 0 {
maxConnections = 100000
}
}
// max connections per ip
var maxConnectionsPerIP = tcpConfig.MaxConnectionsPerIP
if maxConnectionsPerIP <= 0 {
maxConnectionsPerIP = nodeconfigs.DefaultTCPMaxConnectionsPerIP
if maxConnectionsPerIP <= 0 {
maxConnectionsPerIP = 100000
}
}
// new connections rate
var newConnectionsRate = tcpConfig.NewConnectionsRate
if newConnectionsRate <= 0 {
newConnectionsRate = nodeconfigs.DefaultTCPNewConnectionsRate
if newConnectionsRate <= 0 {
newConnectionsRate = 100000
}
}
// 检查是否有变化
var hasChanges = false
for _, port := range ports {
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}) {
hasChanges = true
break
}
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}) {
hasChanges = true
break
}
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsRate)}) {
hasChanges = true
break
}
}
if !hasChanges {
// 检查是否有多余的端口
var oldPorts = this.getTCPPorts(oldRules)
if !this.eqPorts(ports, oldPorts) {
hasChanges = true
}
}
if !hasChanges {
return nil
}
// 先清空所有相关规则
err = this.removeOldTCPRules(chain, oldRules)
if err != nil {
return errors.New("delete old rules failed: " + err.Error())
}
// 添加新规则
for _, port := range ports {
if maxConnections > 0 {
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "count", "over", types.String(maxConnections), "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}))
var stderr = &bytes.Buffer{}
cmd.Stderr = stderr
err := cmd.Run()
if err != nil {
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
}
}
if maxConnectionsPerIP > 0 {
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "meter", "meter-"+protocol+"-"+types.String(port)+"-max-connections", "{ "+protocol+" saddr ct count over "+types.String(maxConnectionsPerIP)+" }", "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}))
var stderr = &bytes.Buffer{}
cmd.Stderr = stderr
err := cmd.Run()
if err != nil {
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
}
}
if newConnectionsRate > 0 {
// TODO 思考是否有惩罚机制
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsRate)+"/minute burst "+types.String(newConnectionsRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsRate)}))
var stderr = &bytes.Buffer{}
cmd.Stderr = stderr
err := cmd.Run()
if err != nil {
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
}
}
}
}
return nil
}
// 删除TCP规则
func (this *DDoSProtectionManager) removeTCPRules() error {
for _, filter := range nftablesFilters {
chain, rules, err := this.getRules(filter)
// TCP
err = this.removeOldTCPRules(chain, rules)
if err != nil {
return err
}
}
return nil
}
// 组合user data
// 数据中不能包含字母、数字、下划线以外的数据
func (this *DDoSProtectionManager) encodeUserData(attrs []string) string {
if attrs == nil {
return ""
}
return "ZZ" + strings.Join(attrs, "_") + "ZZ"
}
// 解码user data
func (this *DDoSProtectionManager) decodeUserData(data []byte) []string {
if len(data) == 0 {
return nil
}
var dataCopy = make([]byte, len(data))
copy(dataCopy, data)
var separatorLen = 2
var index1 = bytes.Index(dataCopy, []byte{'Z', 'Z'})
if index1 < 0 {
return nil
}
dataCopy = dataCopy[index1+separatorLen:]
var index2 = bytes.LastIndex(dataCopy, []byte{'Z', 'Z'})
if index2 < 0 {
return nil
}
var s = string(dataCopy[:index2])
var pieces = strings.Split(s, "_")
for index, piece := range pieces {
pieces[index] = strings.TrimSpace(piece)
}
return pieces
}
// 清除规则
func (this *DDoSProtectionManager) removeOldTCPRules(chain *nftables.Chain, rules []*nftables.Rule) error {
for _, rule := range rules {
var pieces = this.decodeUserData(rule.UserData())
if len(pieces) != 4 {
continue
}
if pieces[0] != "tcp" {
continue
}
switch pieces[2] {
case "maxConnections", "maxConnectionsPerIP", "newConnectionsRate":
err := chain.DeleteRule(rule)
if err != nil {
return err
}
}
}
return nil
}
// 根据参数检查规则是否存在
func (this *DDoSProtectionManager) existsRule(rules []*nftables.Rule, attrs []string) (exists bool) {
if len(attrs) == 0 {
return false
}
for _, oldRule := range rules {
var pieces = this.decodeUserData(oldRule.UserData())
if len(attrs) != len(pieces) {
continue
}
var isSame = true
for index, piece := range pieces {
if strings.TrimSpace(piece) != attrs[index] {
isSame = false
break
}
}
if isSame {
return true
}
}
return false
}
// 获取规则中的端口号
func (this *DDoSProtectionManager) getTCPPorts(rules []*nftables.Rule) []int32 {
var ports = []int32{}
for _, rule := range rules {
var pieces = this.decodeUserData(rule.UserData())
if len(pieces) != 4 {
continue
}
if pieces[0] != "tcp" {
continue
}
var port = types.Int32(pieces[1])
if port > 0 && !lists.ContainsInt32(ports, port) {
ports = append(ports, port)
}
}
return ports
}
// 检查端口是否一样
func (this *DDoSProtectionManager) eqPorts(ports1 []int32, ports2 []int32) bool {
if len(ports1) != len(ports2) {
return false
}
var portMap = map[int32]bool{}
for _, port := range ports2 {
portMap[port] = true
}
for _, port := range ports1 {
_, ok := portMap[port]
if !ok {
return false
}
}
return true
}
// 查找Table
func (this *DDoSProtectionManager) getTable(filter *nftablesTableDefinition) (*nftables.Table, error) {
var family nftables.TableFamily
if filter.IsIPv4 {
family = nftables.TableFamilyIPv4
} else if filter.IsIPv6 {
family = nftables.TableFamilyIPv6
} else {
return nil, errors.New("table '" + filter.Name + "' should be IPv4 or IPv6")
}
return nftablesInstance.conn.GetTable(filter.Name, family)
}
// 查找所有规则
func (this *DDoSProtectionManager) getRules(filter *nftablesTableDefinition) (*nftables.Chain, []*nftables.Rule, error) {
table, err := this.getTable(filter)
if err != nil {
return nil, nil, errors.New("get table failed: " + err.Error())
}
chain, err := table.GetChain(nftablesChainName)
if err != nil {
return nil, nil, errors.New("get chain failed: " + err.Error())
}
rules, err := chain.GetRules()
return chain, rules, err
}
// 更新白名单
func (this *DDoSProtectionManager) updateAllowIPList(allIPList []string) error {
if nftablesInstance == nil {
return nil
}
var allMap = map[string]zero.Zero{}
for _, ip := range allIPList {
allMap[ip] = zero.New()
}
for _, set := range []*nftables.Set{nftablesInstance.allowIPv4Set, nftablesInstance.allowIPv6Set} {
var isIPv4 = set == nftablesInstance.allowIPv4Set
var isIPv6 = !isIPv4
// 现有的
oldList, err := set.GetIPElements()
if err != nil {
return err
}
var oldMap = map[string]zero.Zero{} // ip=> zero
for _, ip := range oldList {
oldMap[ip] = zero.New()
if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) {
_, ok := allMap[ip]
if !ok {
// 不存在则删除
err = set.DeleteIPElement(ip)
if err != nil {
return errors.New("delete ip element '" + ip + "' failed: " + err.Error())
}
}
}
}
// 新增的
for _, ip := range allIPList {
var ipObj = net.ParseIP(ip)
if ipObj == nil {
continue
}
if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) {
_, ok := oldMap[ip]
if !ok {
// 不存在则添加
err = set.AddIPElement(ip, nil)
if err != nil {
return errors.New("add ip '" + ip + "' failed: " + err.Error())
}
}
}
}
}
return nil
}

View File

@@ -0,0 +1,23 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !linux
// +build !linux
package firewalls
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
)
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
type DDoSProtectionManager struct {
nftPath string
}
func NewDDoSProtectionManager() *DDoSProtectionManager {
return &DDoSProtectionManager{}
}
func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error {
return nil
}

View File

@@ -1,6 +1,4 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !plus
// +build !plus
package firewalls
@@ -8,9 +6,11 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"runtime"
"sync"
)
var currentFirewall FirewallInterface
var firewallLocker = &sync.Mutex{}
// 初始化
func init() {
@@ -24,10 +24,28 @@ func init() {
// Firewall 查找当前系统中最适合的防火墙
func Firewall() FirewallInterface {
firewallLocker.Lock()
defer firewallLocker.Unlock()
if currentFirewall != nil {
return currentFirewall
}
// nftables
if runtime.GOOS == "linux" {
nftables, err := NewNFTablesFirewall()
if err != nil {
remotelogs.Warn("FIREWALL", "'nftables' should be installed on the system to enhance security (init failed: "+err.Error()+")")
} else {
if nftables.IsReady() {
currentFirewall = nftables
events.Notify(events.EventNFTablesReady)
return nftables
} else {
remotelogs.Warn("FIREWALL", "'nftables' should be enabled on the system to enhance security")
}
}
}
// firewalld
if runtime.GOOS == "linux" {
var firewalld = NewFirewalld()

View File

@@ -23,12 +23,13 @@ func NewFirewalld() *Firewalld {
path, err := exec.LookPath("firewall-cmd")
if err == nil && len(path) > 0 {
var cmd = exec.Command(path, "-V")
var cmd = exec.Command(path, "--state")
err := cmd.Run()
if err == nil {
firewalld.exe = path
// TODO check firewalld status with 'firewall-cmd --state' (running or not running),
// but we should recover the state when firewalld state changes, maybe check it every minutes
firewalld.isReady = true
firewalld.init()
}

View File

@@ -0,0 +1,395 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package firewalls
import (
"bytes"
"errors"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/types"
"net"
"os/exec"
"regexp"
"runtime"
"strings"
"time"
)
// check nft status, if being enabled we load it automatically
func init() {
if runtime.GOOS == "linux" {
var ticker = time.NewTicker(3 * time.Minute)
go func() {
for range ticker.C {
// if already ready, we break
if nftablesIsReady {
ticker.Stop()
break
}
_, err := exec.LookPath("nft")
if err == nil {
nftablesFirewall, err := NewNFTablesFirewall()
if err != nil {
continue
}
currentFirewall = nftablesFirewall
remotelogs.Println("FIREWALL", "nftables is ready")
// fire event
if nftablesFirewall.IsReady() {
events.Notify(events.EventNFTablesReady)
}
ticker.Stop()
break
}
}
}()
}
}
var nftablesInstance *NFTablesFirewall
var nftablesIsReady = false
var nftablesFilters = []*nftablesTableDefinition{
// we shorten the name for table name length restriction
{Name: "edge_dft_v4", IsIPv4: true},
{Name: "edge_dft_v6", IsIPv6: true},
}
var nftablesChainName = "input"
type nftablesTableDefinition struct {
Name string
IsIPv4 bool
IsIPv6 bool
}
func (this *nftablesTableDefinition) protocol() string {
if this.IsIPv6 {
return "ip6"
}
return "ip"
}
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
var firewall = &NFTablesFirewall{
conn: nftables.NewConn(),
}
err := firewall.init()
if err != nil {
return nil, err
}
return firewall, nil
}
type NFTablesFirewall struct {
conn *nftables.Conn
isReady bool
version string
allowIPv4Set *nftables.Set
allowIPv6Set *nftables.Set
denyIPv4Set *nftables.Set
denyIPv6Set *nftables.Set
firewalld *Firewalld
}
func (this *NFTablesFirewall) init() error {
// check nft
nftPath, err := exec.LookPath("nft")
if err != nil {
return errors.New("nft not found")
}
this.version = this.readVersion(nftPath)
// table
for _, tableDef := range nftablesFilters {
var family nftables.TableFamily
if tableDef.IsIPv4 {
family = nftables.TableFamilyIPv4
} else if tableDef.IsIPv6 {
family = nftables.TableFamilyIPv6
} else {
return errors.New("invalid table family: " + types.String(tableDef))
}
table, err := this.conn.GetTable(tableDef.Name, family)
if err != nil {
if nftables.IsNotFound(err) {
if tableDef.IsIPv4 {
table, err = this.conn.AddIPv4Table(tableDef.Name)
} else if tableDef.IsIPv6 {
table, err = this.conn.AddIPv6Table(tableDef.Name)
}
if err != nil {
return errors.New("create table '" + tableDef.Name + "' failed: " + err.Error())
}
} else {
return errors.New("get table '" + tableDef.Name + "' failed: " + err.Error())
}
}
if table == nil {
return errors.New("can not create table '" + tableDef.Name + "'")
}
// chain
var chainName = nftablesChainName
chain, err := table.GetChain(chainName)
if err != nil {
if nftables.IsNotFound(err) {
chain, err = table.AddAcceptChain(chainName)
if err != nil {
return errors.New("create chain '" + chainName + "' failed: " + err.Error())
}
} else {
return errors.New("get chain '" + chainName + "' failed: " + err.Error())
}
}
if chain == nil {
return errors.New("can not create chain '" + chainName + "'")
}
// allow lo
var loRuleName = []byte("lo")
_, err = chain.GetRuleWithUserData(loRuleName)
if err != nil {
if nftables.IsNotFound(err) {
_, err = chain.AddAcceptInterfaceRule("lo", loRuleName)
}
if err != nil {
return errors.New("add 'lo' rule failed: " + err.Error())
}
}
// allow set
// "allow" should be always first
for _, setAction := range []string{"allow", "deny"} {
var setName = setAction + "_set"
set, err := table.GetSet(setName)
if err != nil {
if nftables.IsNotFound(err) {
var keyType nftables.SetDataType
if tableDef.IsIPv4 {
keyType = nftables.TypeIPAddr
} else if tableDef.IsIPv6 {
keyType = nftables.TypeIP6Addr
}
set, err = table.AddSet(setName, &nftables.SetOptions{
KeyType: keyType,
HasTimeout: true,
})
if err != nil {
return errors.New("create set '" + setName + "' failed: " + err.Error())
}
} else {
return errors.New("get set '" + setName + "' failed: " + err.Error())
}
}
if set == nil {
return errors.New("can not create set '" + setName + "'")
}
if tableDef.IsIPv4 {
if setAction == "allow" {
this.allowIPv4Set = set
} else {
this.denyIPv4Set = set
}
} else if tableDef.IsIPv6 {
if setAction == "allow" {
this.allowIPv6Set = set
} else {
this.denyIPv6Set = set
}
}
// rule
var ruleName = []byte(setAction)
rule, err := chain.GetRuleWithUserData(ruleName)
if err != nil {
if nftables.IsNotFound(err) {
if tableDef.IsIPv4 {
if setAction == "allow" {
rule, err = chain.AddAcceptIPv4SetRule(setName, ruleName)
} else {
rule, err = chain.AddDropIPv4SetRule(setName, ruleName)
}
} else if tableDef.IsIPv6 {
if setAction == "allow" {
rule, err = chain.AddAcceptIPv6SetRule(setName, ruleName)
} else {
rule, err = chain.AddDropIPv6SetRule(setName, ruleName)
}
}
if err != nil {
return errors.New("add rule failed: " + err.Error())
}
} else {
return errors.New("get rule failed: " + err.Error())
}
}
if rule == nil {
return errors.New("can not create rule '" + string(ruleName) + "'")
}
}
}
this.isReady = true
nftablesIsReady = true
nftablesInstance = this
// load firewalld
var firewalld = NewFirewalld()
if firewalld.IsReady() {
this.firewalld = firewalld
}
return nil
}
// Name 名称
func (this *NFTablesFirewall) Name() string {
return "nftables"
}
// IsReady 是否已准备被调用
func (this *NFTablesFirewall) IsReady() bool {
return this.isReady
}
// IsMock 是否为模拟
func (this *NFTablesFirewall) IsMock() bool {
return false
}
// AllowPort 允许端口
func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
if this.firewalld != nil {
return this.firewalld.AllowPort(port, protocol)
}
return nil
}
// RemovePort 删除端口
func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
if this.firewalld != nil {
return this.firewalld.RemovePort(port, protocol)
}
return nil
}
// AllowSourceIP Allow把IP加入白名单
func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
var data = net.ParseIP(ip)
if data == nil {
return errors.New("invalid ip '" + ip + "'")
}
if strings.Contains(ip, ":") { // ipv6
if this.allowIPv6Set == nil {
return errors.New("ipv6 ip set is nil")
}
return this.allowIPv6Set.AddElement(data.To16(), nil)
}
// ipv4
if this.allowIPv4Set == nil {
return errors.New("ipv4 ip set is nil")
}
return this.allowIPv4Set.AddElement(data.To4(), nil)
}
// RejectSourceIP 拒绝某个源IP连接
// we did not create set for drop ip, so we reuse DropSourceIP() method here
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
return this.DropSourceIP(ip, timeoutSeconds)
}
// DropSourceIP 丢弃某个源IP数据
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
var data = net.ParseIP(ip)
if data == nil {
return errors.New("invalid ip '" + ip + "'")
}
if strings.Contains(ip, ":") { // ipv6
if this.denyIPv6Set == nil {
return errors.New("ipv6 ip set is nil")
}
return this.denyIPv6Set.AddElement(data.To16(), &nftables.ElementOptions{
Timeout: time.Duration(timeoutSeconds) * time.Second,
})
}
// ipv4
if this.denyIPv4Set == nil {
return errors.New("ipv4 ip set is nil")
}
return this.denyIPv4Set.AddElement(data.To4(), &nftables.ElementOptions{
Timeout: time.Duration(timeoutSeconds) * time.Second,
})
}
// RemoveSourceIP 删除某个源IP
func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
var data = net.ParseIP(ip)
if data == nil {
return errors.New("invalid ip '" + ip + "'")
}
if strings.Contains(ip, ":") { // ipv6
if this.denyIPv6Set != nil {
err := this.denyIPv6Set.DeleteElement(data.To16())
if err != nil {
return err
}
}
if this.allowIPv6Set != nil {
err := this.allowIPv6Set.DeleteElement(data.To16())
if err != nil {
return err
}
}
return nil
}
// ipv4
if this.allowIPv4Set != nil {
err := this.denyIPv4Set.DeleteElement(data.To4())
if err != nil {
return err
}
err = this.allowIPv4Set.DeleteElement(data.To4())
if err != nil {
return err
}
}
return nil
}
// 读取版本号
func (this *NFTablesFirewall) readVersion(nftPath string) string {
var cmd = exec.Command(nftPath, "--version")
var output = &bytes.Buffer{}
cmd.Stdout = output
err := cmd.Run()
if err != nil {
return ""
}
var outputString = output.String()
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
if len(versionMatches) <= 1 {
return ""
}
return versionMatches[1]
}

View File

@@ -0,0 +1,61 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !linux
// +build !linux
package firewalls
import (
"errors"
)
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
return nil, errors.New("not implemented")
}
type NFTablesFirewall struct {
}
// Name 名称
func (this *NFTablesFirewall) Name() string {
return "nftables"
}
// IsReady 是否已准备被调用
func (this *NFTablesFirewall) IsReady() bool {
return false
}
// IsMock 是否为模拟
func (this *NFTablesFirewall) IsMock() bool {
return true
}
// AllowPort 允许端口
func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
return nil
}
// RemovePort 删除端口
func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
return nil
}
// AllowSourceIP Allow把IP加入白名单
func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
return nil
}
// RejectSourceIP 拒绝某个源IP连接
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
return nil
}
// DropSourceIP 丢弃某个源IP数据
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
return nil
}
// RemoveSourceIP 删除某个源IP
func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
return nil
}

View File

@@ -0,0 +1 @@
build_remote.sh

View File

@@ -0,0 +1,370 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables
import (
"bytes"
"errors"
nft "github.com/google/nftables"
"github.com/google/nftables/expr"
)
const MaxChainNameLength = 31
type RuleOptions struct {
Exprs []expr.Any
UserData []byte
}
// Chain chain object in table
type Chain struct {
conn *Conn
rawTable *nft.Table
rawChain *nft.Chain
}
func NewChain(conn *Conn, rawTable *nft.Table, rawChain *nft.Chain) *Chain {
return &Chain{
conn: conn,
rawTable: rawTable,
rawChain: rawChain,
}
}
func (this *Chain) Raw() *nft.Chain {
return this.rawChain
}
func (this *Chain) Name() string {
return this.rawChain.Name
}
func (this *Chain) AddRule(options *RuleOptions) (*Rule, error) {
var rawRule = this.conn.Raw().AddRule(&nft.Rule{
Table: this.rawTable,
Chain: this.rawChain,
Exprs: options.Exprs,
UserData: options.UserData,
})
err := this.conn.Commit()
if err != nil {
return nil, err
}
return NewRule(rawRule), nil
}
func (this *Chain) AddAcceptIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: userData,
})
}
func (this *Chain) AddAcceptIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: userData,
})
}
func (this *Chain) AddDropIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip,
},
&expr.Verdict{
Kind: expr.VerdictDrop,
},
},
UserData: userData,
})
}
func (this *Chain) AddDropIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip,
},
&expr.Verdict{
Kind: expr.VerdictDrop,
},
},
UserData: userData,
})
}
func (this *Chain) AddRejectIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip,
},
&expr.Reject{},
},
UserData: userData,
})
}
func (this *Chain) AddRejectIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip,
},
&expr.Reject{},
},
UserData: userData,
})
}
func (this *Chain) AddAcceptIPv4SetRule(setName string, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: setName,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: userData,
})
}
func (this *Chain) AddAcceptIPv6SetRule(setName string, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Lookup{
SourceRegister: 1,
SetName: setName,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: userData,
})
}
func (this *Chain) AddDropIPv4SetRule(setName string, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: setName,
},
&expr.Verdict{
Kind: expr.VerdictDrop,
},
},
UserData: userData,
})
}
func (this *Chain) AddDropIPv6SetRule(setName string, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Lookup{
SourceRegister: 1,
SetName: setName,
},
&expr.Verdict{
Kind: expr.VerdictDrop,
},
},
UserData: userData,
})
}
func (this *Chain) AddRejectIPv4SetRule(setName string, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: setName,
},
&expr.Reject{},
},
UserData: userData,
})
}
func (this *Chain) AddRejectIPv6SetRule(setName string, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Lookup{
SourceRegister: 1,
SetName: setName,
},
&expr.Reject{},
},
UserData: userData,
})
}
func (this *Chain) AddAcceptInterfaceRule(interfaceName string, userData []byte) (*Rule, error) {
if len(interfaceName) >= 16 {
return nil, errors.New("invalid interface name '" + interfaceName + "'")
}
var ifname = make([]byte, 16)
copy(ifname, interfaceName+"\x00")
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: userData,
})
}
func (this *Chain) GetRuleWithUserData(userData []byte) (*Rule, error) {
rawRules, err := this.conn.Raw().GetRule(this.rawTable, this.rawChain)
if err != nil {
return nil, err
}
for _, rawRule := range rawRules {
if bytes.Compare(rawRule.UserData, userData) == 0 {
return NewRule(rawRule), nil
}
}
return nil, ErrRuleNotFound
}
func (this *Chain) GetRules() ([]*Rule, error) {
rawRules, err := this.conn.Raw().GetRule(this.rawTable, this.rawChain)
if err != nil {
return nil, err
}
var result = []*Rule{}
for _, rawRule := range rawRules {
result = append(result, NewRule(rawRule))
}
return result, nil
}
func (this *Chain) DeleteRule(rule *Rule) error {
err := this.conn.Raw().DelRule(rule.Raw())
if err != nil {
return err
}
return this.conn.Commit()
}
func (this *Chain) Flush() error {
this.conn.Raw().FlushChain(this.rawChain)
return this.conn.Commit()
}

View File

@@ -0,0 +1,13 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nftables
import nft "github.com/google/nftables"
type ChainPolicy = nft.ChainPolicy
// Possible ChainPolicy values.
const (
ChainPolicyDrop = nft.ChainPolicyDrop
ChainPolicyAccept = nft.ChainPolicyAccept
)

View File

@@ -0,0 +1,130 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables_test
import (
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
"net"
"testing"
)
func getIPv4Chain(t *testing.T) *nftables.Chain {
var conn = nftables.NewConn()
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
if err != nil {
if err == nftables.ErrTableNotFound {
table, err = conn.AddIPv4Table("test_ipv4")
if err != nil {
t.Fatal(err)
}
} else {
t.Fatal(err)
}
}
chain, err := table.GetChain("test_chain")
if err != nil {
if err == nftables.ErrChainNotFound {
chain, err = table.AddAcceptChain("test_chain")
}
}
if err != nil {
t.Fatal(err)
}
return chain
}
func TestChain_AddAcceptIPRule(t *testing.T) {
var chain = getIPv4Chain(t)
_, err := chain.AddAcceptIPv4Rule(net.ParseIP("192.168.2.40").To4(), nil)
if err != nil {
t.Fatal(err)
}
}
func TestChain_AddDropIPRule(t *testing.T) {
var chain = getIPv4Chain(t)
_, err := chain.AddDropIPv4Rule(net.ParseIP("192.168.2.31").To4(), nil)
if err != nil {
t.Fatal(err)
}
}
func TestChain_AddAcceptSetRule(t *testing.T) {
var chain = getIPv4Chain(t)
_, err := chain.AddAcceptIPv4SetRule("ipv4_black_set", nil)
if err != nil {
t.Fatal(err)
}
}
func TestChain_AddDropSetRule(t *testing.T) {
var chain = getIPv4Chain(t)
_, err := chain.AddDropIPv4SetRule("ipv4_black_set", nil)
if err != nil {
t.Fatal(err)
}
}
func TestChain_AddRejectSetRule(t *testing.T) {
var chain = getIPv4Chain(t)
_, err := chain.AddRejectIPv4SetRule("ipv4_black_set", nil)
if err != nil {
t.Fatal(err)
}
}
func TestChain_GetRuleWithUserData(t *testing.T) {
var chain = getIPv4Chain(t)
rule, err := chain.GetRuleWithUserData([]byte("test"))
if err != nil {
if err == nftables.ErrRuleNotFound {
t.Log("rule not found")
return
} else {
t.Fatal(err)
}
}
t.Log("rule:", rule)
}
func TestChain_GetRules(t *testing.T) {
var chain = getIPv4Chain(t)
rules, err := chain.GetRules()
if err != nil {
t.Fatal(err)
}
for _, rule := range rules {
t.Log("handle:", rule.Handle(), "set name:", rule.LookupSetName(),
"verdict:", rule.VerDict(), "user data:", string(rule.UserData()))
}
}
func TestChain_DeleteRule(t *testing.T) {
var chain = getIPv4Chain(t)
rule, err := chain.GetRuleWithUserData([]byte("test"))
if err != nil {
if err == nftables.ErrRuleNotFound {
t.Log("rule not found")
return
}
t.Fatal(err)
}
err = chain.DeleteRule(rule)
if err != nil {
t.Fatal(err)
}
}
func TestChain_Flush(t *testing.T) {
var chain = getIPv4Chain(t)
err := chain.Flush()
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -0,0 +1,84 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables
import (
"errors"
nft "github.com/google/nftables"
"github.com/iwind/TeaGo/types"
)
const MaxTableNameLength = 27
type Conn struct {
rawConn *nft.Conn
}
func NewConn() *Conn {
return &Conn{
rawConn: &nft.Conn{},
}
}
func (this *Conn) Raw() *nft.Conn {
return this.rawConn
}
func (this *Conn) GetTable(name string, family TableFamily) (*Table, error) {
rawTables, err := this.rawConn.ListTables()
if err != nil {
return nil, err
}
for _, rawTable := range rawTables {
if rawTable.Name == name && rawTable.Family == family {
return NewTable(this, rawTable), nil
}
}
return nil, ErrTableNotFound
}
func (this *Conn) AddTable(name string, family TableFamily) (*Table, error) {
if len(name) > MaxTableNameLength {
return nil, errors.New("table name too long (max " + types.String(MaxTableNameLength) + ")")
}
var rawTable = this.rawConn.AddTable(&nft.Table{
Family: family,
Name: name,
})
err := this.Commit()
if err != nil {
return nil, err
}
return NewTable(this, rawTable), nil
}
func (this *Conn) AddIPv4Table(name string) (*Table, error) {
return this.AddTable(name, TableFamilyIPv4)
}
func (this *Conn) AddIPv6Table(name string) (*Table, error) {
return this.AddTable(name, TableFamilyIPv6)
}
func (this *Conn) DeleteTable(name string, family TableFamily) error {
table, err := this.GetTable(name, family)
if err != nil {
if err == ErrTableNotFound {
return nil
}
return err
}
this.rawConn.DelTable(table.Raw())
return this.Commit()
}
func (this *Conn) Commit() error {
return this.rawConn.Flush()
}

View File

@@ -0,0 +1,78 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables_test
import (
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
"os/exec"
"testing"
)
func TestConn_Test(t *testing.T) {
_, err := exec.LookPath("nft")
if err == nil {
t.Log("ok")
return
}
t.Log(err)
}
func TestConn_GetTable_NotFound(t *testing.T) {
var conn = nftables.NewConn()
table, err := conn.GetTable("a", nftables.TableFamilyIPv4)
if err != nil {
if err == nftables.ErrTableNotFound {
t.Log("table not found")
} else {
t.Fatal(err)
}
} else {
t.Log("table:", table)
}
}
func TestConn_GetTable(t *testing.T) {
var conn = nftables.NewConn()
table, err := conn.GetTable("myFilter", nftables.TableFamilyIPv4)
if err != nil {
if err == nftables.ErrTableNotFound {
t.Log("table not found")
} else {
t.Fatal(err)
}
} else {
t.Log("table:", table)
}
}
func TestConn_AddTable(t *testing.T) {
var conn = nftables.NewConn()
{
table, err := conn.AddIPv4Table("test_ipv4")
if err != nil {
t.Fatal(err)
}
t.Log(table.Name())
}
{
table, err := conn.AddIPv6Table("test_ipv6")
if err != nil {
t.Fatal(err)
}
t.Log(table.Name())
}
}
func TestConn_DeleteTable(t *testing.T) {
var conn = nftables.NewConn()
err := conn.DeleteTable("test_ipv4", nftables.TableFamilyIPv4)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -0,0 +1,8 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables
type Element struct {
}

View File

@@ -0,0 +1,19 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables
import "errors"
var ErrTableNotFound = errors.New("table not found")
var ErrChainNotFound = errors.New("chain not found")
var ErrSetNotFound = errors.New("set not found")
var ErrRuleNotFound = errors.New("rule not found")
func IsNotFound(err error) bool {
if err == nil {
return false
}
return err == ErrTableNotFound || err == ErrChainNotFound || err == ErrSetNotFound || err == ErrRuleNotFound
}

View File

@@ -0,0 +1,18 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nftables
import (
nft "github.com/google/nftables"
)
type TableFamily = nft.TableFamily
const (
TableFamilyINet TableFamily = nft.TableFamilyINet
TableFamilyIPv4 TableFamily = nft.TableFamilyIPv4
TableFamilyIPv6 TableFamily = nft.TableFamilyIPv6
TableFamilyARP TableFamily = nft.TableFamilyARP
TableFamilyNetdev TableFamily = nft.TableFamilyNetdev
TableFamilyBridge TableFamily = nft.TableFamilyBridge
)

View File

@@ -0,0 +1,51 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nftables
import (
nft "github.com/google/nftables"
"github.com/google/nftables/expr"
)
type Rule struct {
rawRule *nft.Rule
}
func NewRule(rawRule *nft.Rule) *Rule {
return &Rule{
rawRule: rawRule,
}
}
func (this *Rule) Raw() *nft.Rule {
return this.rawRule
}
func (this *Rule) LookupSetName() string {
for _, e := range this.rawRule.Exprs {
exp, ok := e.(*expr.Lookup)
if ok {
return exp.SetName
}
}
return ""
}
func (this *Rule) VerDict() expr.VerdictKind {
for _, e := range this.rawRule.Exprs {
exp, ok := e.(*expr.Verdict)
if ok {
return exp.Kind
}
}
return -100
}
func (this *Rule) Handle() uint64 {
return this.rawRule.Handle
}
func (this *Rule) UserData() []byte {
return this.rawRule.UserData
}

View File

@@ -0,0 +1,161 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables
import (
"errors"
"github.com/TeaOSLab/EdgeNode/internal/utils"
nft "github.com/google/nftables"
"net"
"strings"
"time"
)
const MaxSetNameLength = 15
type SetOptions struct {
Id uint32
HasTimeout bool
Timeout time.Duration
KeyType SetDataType
DataType SetDataType
Constant bool
Interval bool
Anonymous bool
IsMap bool
}
type ElementOptions struct {
Timeout time.Duration
}
type Set struct {
conn *Conn
rawSet *nft.Set
batch *SetBatch
}
func NewSet(conn *Conn, rawSet *nft.Set) *Set {
return &Set{
conn: conn,
rawSet: rawSet,
batch: &SetBatch{
conn: conn,
rawSet: rawSet,
},
}
}
func (this *Set) Raw() *nft.Set {
return this.rawSet
}
func (this *Set) Name() string {
return this.rawSet.Name
}
func (this *Set) AddElement(key []byte, options *ElementOptions) error {
var rawElement = nft.SetElement{
Key: key,
}
if options != nil {
rawElement.Timeout = options.Timeout
}
err := this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
rawElement,
})
if err != nil {
return err
}
err = this.conn.Commit()
if err != nil {
// retry if exists
if strings.Contains(err.Error(), "file exists") {
deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
{
Key: key,
},
})
if deleteErr == nil {
err = this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
rawElement,
})
if err == nil {
err = this.conn.Commit()
}
}
}
}
return err
}
func (this *Set) AddIPElement(ip string, options *ElementOptions) error {
var ipObj = net.ParseIP(ip)
if ipObj == nil {
return errors.New("invalid ip '" + ip + "'")
}
if utils.IsIPv4(ip) {
return this.AddElement(ipObj.To4(), options)
} else {
return this.AddElement(ipObj.To16(), options)
}
}
func (this *Set) DeleteElement(key []byte) error {
err := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
{
Key: key,
},
})
if err != nil {
return err
}
err = this.conn.Commit()
if err != nil {
if strings.Contains(err.Error(), "no such file or directory") {
err = nil
}
}
return err
}
func (this *Set) DeleteIPElement(ip string) error {
var ipObj = net.ParseIP(ip)
if ipObj == nil {
return errors.New("invalid ip '" + ip + "'")
}
if utils.IsIPv4(ip) {
return this.DeleteElement(ipObj.To4())
} else {
return this.DeleteElement(ipObj.To16())
}
}
func (this *Set) Batch() *SetBatch {
return this.batch
}
func (this *Set) GetIPElements() ([]string, error) {
elements, err := this.conn.Raw().GetSetElements(this.rawSet)
if err != nil {
return nil, err
}
var result = []string{}
for _, element := range elements {
result = append(result, net.IP(element.Key).String())
}
return result, nil
}
// not work current time
/**func (this *Set) Flush() error {
this.conn.Raw().FlushSet(this.rawSet)
return this.conn.Commit()
}**/

View File

@@ -0,0 +1,36 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nftables
import (
nft "github.com/google/nftables"
)
type SetBatch struct {
conn *Conn
rawSet *nft.Set
}
func (this *SetBatch) AddElement(key []byte, options *ElementOptions) error {
var rawElement = nft.SetElement{
Key: key,
}
if options != nil {
rawElement.Timeout = options.Timeout
}
return this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
rawElement,
})
}
func (this *SetBatch) DeleteElement(key []byte) error {
return this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
{
Key: key,
},
})
}
func (this *SetBatch) Commit() error {
return this.conn.Commit()
}

View File

@@ -0,0 +1,57 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nftables
import nft "github.com/google/nftables"
type SetDataType = nft.SetDatatype
var (
TypeInvalid = nft.TypeInvalid
TypeVerdict = nft.TypeVerdict
TypeNFProto = nft.TypeNFProto
TypeBitmask = nft.TypeBitmask
TypeInteger = nft.TypeInteger
TypeString = nft.TypeString
TypeLLAddr = nft.TypeLLAddr
TypeIPAddr = nft.TypeIPAddr
TypeIP6Addr = nft.TypeIP6Addr
TypeEtherAddr = nft.TypeEtherAddr
TypeEtherType = nft.TypeEtherType
TypeARPOp = nft.TypeARPOp
TypeInetProto = nft.TypeInetProto
TypeInetService = nft.TypeInetService
TypeICMPType = nft.TypeICMPType
TypeTCPFlag = nft.TypeTCPFlag
TypeDCCPPktType = nft.TypeDCCPPktType
TypeMHType = nft.TypeMHType
TypeTime = nft.TypeTime
TypeMark = nft.TypeMark
TypeIFIndex = nft.TypeIFIndex
TypeARPHRD = nft.TypeARPHRD
TypeRealm = nft.TypeRealm
TypeClassID = nft.TypeClassID
TypeUID = nft.TypeUID
TypeGID = nft.TypeGID
TypeCTState = nft.TypeCTState
TypeCTDir = nft.TypeCTDir
TypeCTStatus = nft.TypeCTStatus
TypeICMP6Type = nft.TypeICMP6Type
TypeCTLabel = nft.TypeCTLabel
TypePktType = nft.TypePktType
TypeICMPCode = nft.TypeICMPCode
TypeICMPV6Code = nft.TypeICMPV6Code
TypeICMPXCode = nft.TypeICMPXCode
TypeDevGroup = nft.TypeDevGroup
TypeDSCP = nft.TypeDSCP
TypeECN = nft.TypeECN
TypeFIBAddr = nft.TypeFIBAddr
TypeBoolean = nft.TypeBoolean
TypeCTEventBit = nft.TypeCTEventBit
TypeIFName = nft.TypeIFName
TypeIGMPType = nft.TypeIGMPType
TypeTimeDate = nft.TypeTimeDate
TypeTimeHour = nft.TypeTimeHour
TypeTimeDay = nft.TypeTimeDay
TypeCGroupV2 = nft.TypeCGroupV2
)

View File

@@ -0,0 +1,110 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nftables_test
import (
"errors"
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
"github.com/iwind/TeaGo/types"
"github.com/mdlayher/netlink"
"net"
"testing"
"time"
)
func getIPv4Set(t *testing.T) *nftables.Set {
var table = getIPv4Table(t)
set, err := table.GetSet("test_ipv4_set")
if err != nil {
if err == nftables.ErrSetNotFound {
set, err = table.AddSet("test_ipv4_set", &nftables.SetOptions{
KeyType: nftables.TypeIPAddr,
HasTimeout: true,
})
if err != nil {
t.Fatal(err)
}
} else {
t.Fatal(err)
}
}
return set
}
func TestSet_AddElement(t *testing.T) {
var set = getIPv4Set(t)
err := set.AddElement(net.ParseIP("192.168.2.31").To4(), &nftables.ElementOptions{Timeout: 86400 * time.Second})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestSet_DeleteElement(t *testing.T) {
var set = getIPv4Set(t)
err := set.DeleteElement(net.ParseIP("192.168.2.31").To4())
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestSet_Batch(t *testing.T) {
var batch = getIPv4Set(t).Batch()
for _, ip := range []string{"192.168.2.30", "192.168.2.31", "192.168.2.32", "192.168.2.33", "192.168.2.34"} {
var ipData = net.ParseIP(ip).To4()
//err := batch.DeleteElement(ipData)
//if err != nil {
// t.Fatal(err)
//}
err := batch.AddElement(ipData, &nftables.ElementOptions{Timeout: 10 * time.Second})
if err != nil {
t.Fatal(err)
}
}
err := batch.Commit()
if err != nil {
t.Logf("%#v", errors.Unwrap(err).(*netlink.OpError))
t.Fatal(err)
}
t.Log("ok")
}
func TestSet_Add_Many(t *testing.T) {
var set = getIPv4Set(t)
for i := 0; i < 255; i++ {
t.Log(i)
for j := 0; j < 255; j++ {
var ip = "192.167." + types.String(i) + "." + types.String(j)
var ipData = net.ParseIP(ip).To4()
err := set.Batch().AddElement(ipData, &nftables.ElementOptions{Timeout: 3600 * time.Second})
if err != nil {
t.Fatal(err)
}
if j%10 == 0 {
err = set.Batch().Commit()
if err != nil {
t.Fatal(err)
}
}
}
err := set.Batch().Commit()
if err != nil {
t.Fatal(err)
}
}
t.Log("ok")
}
/**func TestSet_Flush(t *testing.T) {
var set = getIPv4Set(t)
err := set.Flush()
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}**/

View File

@@ -0,0 +1,157 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables
import (
"errors"
nft "github.com/google/nftables"
"github.com/iwind/TeaGo/types"
"strings"
)
type Table struct {
conn *Conn
rawTable *nft.Table
}
func NewTable(conn *Conn, rawTable *nft.Table) *Table {
return &Table{
conn: conn,
rawTable: rawTable,
}
}
func (this *Table) Raw() *nft.Table {
return this.rawTable
}
func (this *Table) Name() string {
return this.rawTable.Name
}
func (this *Table) Family() TableFamily {
return this.rawTable.Family
}
func (this *Table) GetChain(name string) (*Chain, error) {
rawChains, err := this.conn.Raw().ListChains()
if err != nil {
return nil, err
}
for _, rawChain := range rawChains {
// must compare table name
if rawChain.Name == name && rawChain.Table.Name == this.rawTable.Name {
return NewChain(this.conn, this.rawTable, rawChain), nil
}
}
return nil, ErrChainNotFound
}
func (this *Table) AddChain(name string, chainPolicy *ChainPolicy) (*Chain, error) {
if len(name) > MaxChainNameLength {
return nil, errors.New("chain name too long (max " + types.String(MaxChainNameLength) + ")")
}
var rawChain = this.conn.Raw().AddChain(&nft.Chain{
Name: name,
Table: this.rawTable,
Hooknum: nft.ChainHookInput,
Priority: nft.ChainPriorityFilter,
Type: nft.ChainTypeFilter,
Policy: chainPolicy,
})
err := this.conn.Commit()
if err != nil {
return nil, err
}
return NewChain(this.conn, this.rawTable, rawChain), nil
}
func (this *Table) AddAcceptChain(name string) (*Chain, error) {
var policy = ChainPolicyAccept
return this.AddChain(name, &policy)
}
func (this *Table) AddDropChain(name string) (*Chain, error) {
var policy = ChainPolicyDrop
return this.AddChain(name, &policy)
}
func (this *Table) DeleteChain(name string) error {
chain, err := this.GetChain(name)
if err != nil {
if err == ErrChainNotFound {
return nil
}
return err
}
this.conn.Raw().DelChain(chain.Raw())
return this.conn.Commit()
}
func (this *Table) GetSet(name string) (*Set, error) {
rawSet, err := this.conn.Raw().GetSetByName(this.rawTable, name)
if err != nil {
if strings.Contains(err.Error(), "no such file or directory") {
return nil, ErrSetNotFound
}
return nil, err
}
return NewSet(this.conn, rawSet), nil
}
func (this *Table) AddSet(name string, options *SetOptions) (*Set, error) {
if len(name) > MaxSetNameLength {
return nil, errors.New("set name too long (max " + types.String(MaxSetNameLength) + ")")
}
if options == nil {
options = &SetOptions{}
}
var rawSet = &nft.Set{
Table: this.rawTable,
ID: options.Id,
Name: name,
Anonymous: options.Anonymous,
Constant: options.Constant,
Interval: options.Interval,
IsMap: options.IsMap,
HasTimeout: options.HasTimeout,
Timeout: options.Timeout,
KeyType: options.KeyType,
DataType: options.DataType,
}
err := this.conn.Raw().AddSet(rawSet, nil)
if err != nil {
return nil, err
}
err = this.conn.Commit()
if err != nil {
return nil, err
}
return NewSet(this.conn, rawSet), nil
}
func (this *Table) DeleteSet(name string) error {
set, err := this.GetSet(name)
if err != nil {
if err == ErrSetNotFound {
return nil
}
return err
}
this.conn.Raw().DelSet(set.Raw())
return this.conn.Commit()
}
func (this *Table) Flush() error {
this.conn.Raw().FlushTable(this.rawTable)
return this.conn.Commit()
}

View File

@@ -0,0 +1,140 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables_test
import (
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
"testing"
)
func getIPv4Table(t *testing.T) *nftables.Table {
var conn = nftables.NewConn()
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
if err != nil {
if err == nftables.ErrTableNotFound {
table, err = conn.AddIPv4Table("test_ipv4")
if err != nil {
t.Fatal(err)
}
} else {
t.Fatal(err)
}
}
return table
}
func TestTable_AddChain(t *testing.T) {
var table = getIPv4Table(t)
{
chain, err := table.AddChain("test_default_chain", nil)
if err != nil {
t.Fatal(err)
}
t.Log("created:", chain.Name())
}
{
chain, err := table.AddAcceptChain("test_accept_chain")
if err != nil {
t.Fatal(err)
}
t.Log("created:", chain.Name())
}
// Do not test drop chain before adding accept rule, you will drop yourself!!!!!!!
/**{
chain, err := table.AddDropChain("test_drop_chain")
if err != nil {
t.Fatal(err)
}
t.Log("created:", chain.Name())
}**/
}
func TestTable_GetChain(t *testing.T) {
var table = getIPv4Table(t)
for _, chainName := range []string{"not_found_chain", "test_default_chain"} {
chain, err := table.GetChain(chainName)
if err != nil {
if err == nftables.ErrChainNotFound {
t.Log(chainName, ":", "not found")
} else {
t.Fatal(err)
}
} else {
t.Log(chainName, ":", chain)
}
}
}
func TestTable_DeleteChain(t *testing.T) {
var table = getIPv4Table(t)
err := table.DeleteChain("test_default_chain")
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestTable_AddSet(t *testing.T) {
var table = getIPv4Table(t)
{
set, err := table.AddSet("ipv4_black_set", &nftables.SetOptions{
HasTimeout: false,
KeyType: nftables.TypeIPAddr,
})
if err != nil {
t.Fatal(err)
}
t.Log(set.Name())
}
{
set, err := table.AddSet("ipv6_black_set", &nftables.SetOptions{
HasTimeout: true,
//Timeout: 3600 * time.Second,
KeyType: nftables.TypeIP6Addr,
})
if err != nil {
t.Fatal(err)
}
t.Log(set.Name())
}
}
func TestTable_GetSet(t *testing.T) {
var table = getIPv4Table(t)
for _, setName := range []string{"not_found_set", "ipv4_black_set"} {
set, err := table.GetSet(setName)
if err != nil {
if err == nftables.ErrSetNotFound {
t.Log(setName, ": not found")
} else {
t.Fatal(err)
}
} else {
t.Log(setName, ":", set)
}
}
}
func TestTable_DeleteSet(t *testing.T) {
var table = getIPv4Table(t)
err := table.DeleteSet("ipv4_black_set")
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestTable_Flush(t *testing.T) {
var table = getIPv4Table(t)
err := table.Flush()
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -48,7 +48,7 @@ func (this *IPListDB) init() error {
if err != nil {
return err
}
remotelogs.Println("CACHE", "create cache dir '"+this.dir+"'")
remotelogs.Println("IP_LIST_DB", "create data dir '"+this.dir+"'")
}
db, err := sql.Open("sqlite3", "file:"+this.dir+"/ip_list.db?cache=shared&mode=rwc&_journal_mode=WAL&_sync=OFF")

View File

@@ -3,17 +3,29 @@
package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
)
// AllowIP 检查IP是否被允许访问
// 如果一个IP不在任何名单中则允许访问
func AllowIP(ip string, serverId int64) (canGoNext bool, inAllowList bool) {
// 放行lo
if ip == "127.0.0.1" {
return true, true
}
var ipLong = utils.IP2Long(ip)
if ipLong == 0 {
return false, false
}
// check node
nodeConfig, err := nodeconfigs.SharedNodeConfig()
if err == nil && nodeConfig.IPIsAutoAllowed(ip) {
return true, true
}
// check white lists
if GlobalWhiteIPList.Contains(ipLong) {
return true, true

View File

@@ -23,7 +23,7 @@ import (
"time"
)
const MaxQueueSize = 10240
const MaxQueueSize = 2048
// Task 单个指标任务
// 数据库存储:
@@ -58,7 +58,7 @@ type Task struct {
timeMap map[string]zero.Zero // time => bool
serverIdMapLocker sync.Mutex
statsMap map[string]*Stat
statsMap map[string]*Stat // 待写入队列hash => *Stat
statsLocker sync.Mutex
statsTicker *utils.Ticker
}

View File

@@ -1,32 +1,30 @@
package nodes
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/messageconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/compressions"
"github.com/TeaOSLab/EdgeNode/internal/configs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/errors"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/Tea"
"io"
"net"
"net/http"
"github.com/iwind/TeaGo/maps"
"net/url"
"os/exec"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
"time"
)
@@ -111,14 +109,12 @@ func (this *APIStream) loop() error {
err = this.handleStatCache(message)
case messageconfigs.MessageCodeCleanCache: // 清理缓存
err = this.handleCleanCache(message)
case messageconfigs.MessageCodePurgeCache: // 删除缓存
err = this.handlePurgeCache(message)
case messageconfigs.MessageCodePreheatCache: // 预热缓存
err = this.handlePreheatCache(message)
case messageconfigs.MessageCodeNewNodeTask: // 有新的任务
err = this.handleNewNodeTask(message)
case messageconfigs.MessageCodeCheckSystemdService: // 检查Systemd服务
err = this.handleCheckSystemdService(message)
case messageconfigs.MessageCodeCheckLocalFirewall: // 检查本地防火墙
err = this.handleCheckLocalFirewall(message)
case messageconfigs.MessageCodeChangeAPINode: // 修改API节点地址
err = this.handleChangeAPINode(message)
default:
@@ -328,224 +324,6 @@ func (this *APIStream) handleCleanCache(message *pb.NodeStreamMessage) error {
return nil
}
// 删除缓存
func (this *APIStream) handlePurgeCache(message *pb.NodeStreamMessage) error {
msg := &messageconfigs.PurgeCacheMessage{}
err := json.Unmarshal(message.DataJSON, msg)
if err != nil {
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
return err
}
storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON)
if err != nil {
return err
}
if shouldStop {
defer func() {
storage.Stop()
}()
}
// WEBP缓存
if msg.Type == "file" {
var keys = msg.Keys
for _, key := range keys {
keys = append(keys,
key+caches.SuffixMethod+"HEAD",
key+caches.SuffixWebP,
key+caches.SuffixPartial,
)
// TODO 根据实际缓存的内容进行组合
for _, encoding := range compressions.AllEncodings() {
keys = append(keys, key+caches.SuffixCompression+encoding)
keys = append(keys, key+caches.SuffixWebP+caches.SuffixCompression+encoding)
}
}
msg.Keys = keys
}
err = storage.Purge(msg.Keys, msg.Type)
if err != nil {
this.replyFail(message.RequestId, "purge keys failed: "+err.Error())
return err
}
this.replyOk(message.RequestId, "ok")
return nil
}
// 预热缓存
func (this *APIStream) handlePreheatCache(message *pb.NodeStreamMessage) error {
msg := &messageconfigs.PreheatCacheMessage{}
err := json.Unmarshal(message.DataJSON, msg)
if err != nil {
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
return err
}
storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON)
if err != nil {
return err
}
if shouldStop {
defer func() {
storage.Stop()
}()
}
if len(msg.Keys) == 0 {
this.replyOk(message.RequestId, "ok")
return nil
}
wg := sync.WaitGroup{}
wg.Add(len(msg.Keys))
client := &http.Client{
Timeout: 30 * time.Second, // TODO 可以设置请求超时时间
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
_, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
return net.Dial(network, "127.0.0.1:"+port)
},
MaxIdleConns: 4096,
MaxIdleConnsPerHost: 32,
MaxConnsPerHost: 32,
IdleConnTimeout: 2 * time.Minute,
ExpectContinueTimeout: 1 * time.Second,
TLSHandshakeTimeout: 0,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}
defer client.CloseIdleConnections()
errorMessages := []string{}
locker := sync.Mutex{}
for _, key := range msg.Keys {
go func(key string) {
defer wg.Done()
req, err := http.NewRequest(http.MethodGet, key, nil)
if err != nil {
locker.Lock()
errorMessages = append(errorMessages, "invalid url: "+key+": "+err.Error())
locker.Unlock()
return
}
// TODO 可以在管理界面自定义Header
req.Header.Set("X-Cache-Action", "preheat")
req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/85.0.4183.121 Safari/537.36")
req.Header.Set("Accept-Encoding", "gzip, deflate, br") // TODO 这里需要记录下缓存是否为gzip的
resp, err := client.Do(req)
if err != nil {
locker.Lock()
errorMessages = append(errorMessages, "request failed: "+key+": "+err.Error())
locker.Unlock()
return
}
if resp.StatusCode != 200 {
locker.Lock()
errorMessages = append(errorMessages, "request failed: "+key+": status code '"+strconv.Itoa(resp.StatusCode)+"'")
locker.Unlock()
return
}
defer func() {
_ = resp.Body.Close()
}()
// 检查最大内容长度
// TODO 需要解决Chunked Transfer Encoding的长度判断问题
maxSize := storage.Policy().MaxSizeBytes()
if maxSize > 0 && resp.ContentLength > maxSize {
locker.Lock()
errorMessages = append(errorMessages, "request failed: the content is too larger than policy setting")
locker.Unlock()
return
}
expiredAt := time.Now().Unix() + 8600
writer, err := storage.OpenWriter(key, expiredAt, 200, resp.ContentLength, -1, false) // TODO 可以设置缓存过期时间
if err != nil {
locker.Lock()
errorMessages = append(errorMessages, "open cache writer failed: "+key+": "+err.Error())
locker.Unlock()
return
}
buf := make([]byte, 16*1024)
isClosed := false
// 写入Header
for k, v := range resp.Header {
for _, v1 := range v {
_, err = writer.WriteHeader([]byte(k + ":" + v1 + "\n"))
if err != nil {
locker.Lock()
errorMessages = append(errorMessages, "write failed: "+key+": "+err.Error())
locker.Unlock()
return
}
}
}
// 写入Body
for {
n, err := resp.Body.Read(buf)
if n > 0 {
_, writerErr := writer.Write(buf[:n])
if writerErr != nil {
locker.Lock()
errorMessages = append(errorMessages, "write failed: "+key+": "+writerErr.Error())
locker.Unlock()
break
}
}
if err != nil {
if err == io.EOF {
err = writer.Close()
if err == nil {
storage.AddToList(&caches.Item{
Type: writer.ItemType(),
Key: key,
ExpiredAt: expiredAt,
})
}
isClosed = true
} else {
locker.Lock()
errorMessages = append(errorMessages, "read url failed: "+key+": "+err.Error())
locker.Unlock()
}
break
}
}
if !isClosed {
_ = writer.Close()
}
}(key)
}
wg.Wait()
if len(errorMessages) > 0 {
this.replyFail(message.RequestId, strings.Join(errorMessages, ", "))
return nil
}
this.replyOk(message.RequestId, "ok")
return nil
}
// 处理配置变化
func (this *APIStream) handleNewNodeTask(message *pb.NodeStreamMessage) error {
select {
@@ -569,7 +347,7 @@ func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage)
return nil
}
cmd := utils.NewCommandExecutor()
var cmd = utils.NewCommandExecutor()
shortName := teaconst.SystemdServiceName
cmd.Add(systemctl, "is-enabled", shortName)
output, err := cmd.Run()
@@ -585,6 +363,63 @@ func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage)
return nil
}
// 检查本地防火墙
func (this *APIStream) handleCheckLocalFirewall(message *pb.NodeStreamMessage) error {
var dataMessage = &messageconfigs.CheckLocalFirewallMessage{}
err := json.Unmarshal(message.DataJSON, dataMessage)
if err != nil {
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
return nil
}
// nft
if dataMessage.Name == "nftables" {
if runtime.GOOS != "linux" {
this.replyFail(message.RequestId, "not Linux system")
return nil
}
nft, err := exec.LookPath("nft")
if err != nil {
this.replyFail(message.RequestId, "'nft' not found: "+err.Error())
return nil
}
var cmd = exec.Command(nft, "--version")
var output = &bytes.Buffer{}
cmd.Stdout = output
err = cmd.Run()
if err != nil {
this.replyFail(message.RequestId, "get version failed: "+err.Error())
return nil
}
var outputString = output.String()
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
if len(versionMatches) <= 1 {
this.replyFail(message.RequestId, "can not get nft version")
return nil
}
var version = versionMatches[1]
var result = maps.Map{
"version": version,
}
var protectionConfig = sharedNodeConfig.DDOSProtection
err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig)
if err != nil {
this.replyFail(message.RequestId, dataMessage.Name+"was installed, but apply DDoS protection config failed: "+err.Error())
} else {
this.replyOk(message.RequestId, string(result.AsJSON()))
}
} else {
this.replyFail(message.RequestId, "invalid firewall name '"+dataMessage.Name+"'")
}
return nil
}
// 修改API地址
func (this *APIStream) handleChangeAPINode(message *pb.NodeStreamMessage) error {
config, err := configs.LoadAPIConfig()
@@ -660,6 +495,11 @@ func (this *APIStream) replyOk(requestId int64, message string) {
_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message})
}
// 回复成功并包含数据
func (this *APIStream) replyOkData(requestId int64, message string, dataJSON []byte) {
_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message, DataJSON: dataJSON})
}
// 获取缓存存取对象
func (this *APIStream) cacheStorage(message *pb.NodeStreamMessage, cachePolicyJSON []byte) (storage caches.StorageInterface, shouldStop bool, err error) {
cachePolicy := &serverconfigs.HTTPCachePolicy{}

View File

@@ -7,7 +7,6 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf"
@@ -21,8 +20,7 @@ import (
// ClientConn 客户端连接
type ClientConn struct {
once sync.Once
globalLimiter *ratelimit.Counter
once sync.Once
isTLS bool
hasDeadline bool
@@ -33,7 +31,7 @@ type ClientConn struct {
BaseClientConn
}
func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ratelimit.Counter) net.Conn {
func NewClientConn(conn net.Conn, isTLS bool, quickClose bool) net.Conn {
if quickClose {
// TCP
tcpConn, ok := conn.(*net.TCPConn)
@@ -43,7 +41,7 @@ func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ra
}
}
return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, isTLS: isTLS, globalLimiter: globalLimiter}
return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, isTLS: isTLS}
}
func (this *ClientConn) Read(b []byte) (n int, err error) {
@@ -96,13 +94,6 @@ func (this *ClientConn) Close() error {
err := this.rawConn.Close()
// 全局并发数限制
this.once.Do(func() {
if this.globalLimiter != nil {
this.globalLimiter.Release()
}
})
// 单个服务并发数限制
sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String())
@@ -137,7 +128,7 @@ func (this *ClientConn) increaseSYNFlood(synFloodConfig *firewallconfigs.SYNFloo
var ip = this.RawIP()
if len(ip) > 0 && !iplibrary.IsInWhiteList(ip) && (!synFloodConfig.IgnoreLocal || !utils.IsLocalIP(ip)) {
var timestamp = utils.NextMinuteUnixTime()
var result = ttlcache.SharedCache.IncreaseInt64("SYN_FLOOD:"+ip, 1, timestamp)
var result = ttlcache.SharedCache.IncreaseInt64("SYN_FLOOD:"+ip, 1, timestamp, true)
var minAttempts = synFloodConfig.MinAttempts
if minAttempts < 5 {
minAttempts = 5

View File

@@ -3,16 +3,13 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"net"
)
var sharedConnectionsLimiter = ratelimit.NewCounter(nodeconfigs.DefaultTCPMaxConnections)
// ClientListener 客户端网络监听
type ClientListener struct {
rawListener net.Listener
@@ -36,13 +33,8 @@ func (this *ClientListener) IsTLS() bool {
}
func (this *ClientListener) Accept() (net.Conn, error) {
// 限制并发连接数
var limiter = sharedConnectionsLimiter
limiter.Ack()
conn, err := this.rawListener.Accept()
if err != nil {
limiter.Release()
return nil, err
}
@@ -50,22 +42,30 @@ func (this *ClientListener) Accept() (net.Conn, error) {
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err == nil {
canGoNext, _ := iplibrary.AllowIP(ip, 0)
var beingDenied = !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) &&
waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
if !canGoNext ||
(!waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) &&
waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)) {
if !canGoNext || beingDenied {
tcpConn, ok := conn.(*net.TCPConn)
if ok {
_ = tcpConn.SetLinger(0)
}
_ = conn.Close()
limiter.Release()
// 使用本地防火墙延长封禁
if beingDenied {
var fw = firewalls.Firewall()
if fw != nil && !fw.IsMock() {
_ = fw.DropSourceIP(ip, 60)
}
}
return this.Accept()
}
}
return NewClientConn(conn, this.isTLS, this.quickClose, limiter), nil
return NewClientConn(conn, this.isTLS, this.quickClose), nil
}
func (this *ClientListener) Close() error {

View File

@@ -0,0 +1,7 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
type LingerConn interface {
SetLinger(sec int) error
}

View File

@@ -80,6 +80,13 @@ Loop:
return nil
}
// 发送到本地
if sharedHTTPAccessLogViewer.HasConns() {
for _, accessLog := range accessLogs {
sharedHTTPAccessLogViewer.Send(accessLog)
}
}
// 发送到API
if this.rpcClient == nil {
client, err := rpc.SharedRPC()

View File

@@ -0,0 +1,116 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/types"
"net"
"os"
"sync"
"sync/atomic"
)
var sharedHTTPAccessLogViewer = NewHTTPAccessLogViewer()
// HTTPAccessLogViewer 本地访问日志浏览器
type HTTPAccessLogViewer struct {
sockFile string
listener net.Listener
connMap map[int64]net.Conn // connId => net.Conn
connId int64
locker sync.Mutex
}
// NewHTTPAccessLogViewer 获取新对象
func NewHTTPAccessLogViewer() *HTTPAccessLogViewer {
return &HTTPAccessLogViewer{
sockFile: os.TempDir() + "/" + teaconst.AccessLogSockName,
connMap: map[int64]net.Conn{},
}
}
// Start 启动
func (this *HTTPAccessLogViewer) Start() error {
this.locker.Lock()
defer this.locker.Unlock()
if this.listener == nil {
// remove if exists
_ = os.Remove(this.sockFile)
// start listening
listener, err := net.Listen("unix", this.sockFile)
if err != nil {
return err
}
this.listener = listener
go func() {
for {
conn, err := this.listener.Accept()
if err != nil {
remotelogs.Error("ACCESS_LOG", "start local reading failed: "+err.Error())
break
}
this.locker.Lock()
var connId = this.nextConnId()
this.connMap[connId] = conn
go func() {
this.startReading(conn, connId)
}()
this.locker.Unlock()
}
}()
}
return nil
}
// HasConns 检查是否有连接
func (this *HTTPAccessLogViewer) HasConns() bool {
this.locker.Lock()
defer this.locker.Unlock()
return len(this.connMap) > 0
}
// Send 发送日志
func (this *HTTPAccessLogViewer) Send(accessLog *pb.HTTPAccessLog) {
var conns = []net.Conn{}
this.locker.Lock()
for _, conn := range this.connMap {
conns = append(conns, conn)
}
this.locker.Unlock()
if len(conns) == 0 {
return
}
for _, conn := range conns {
// ignore error
_, _ = conn.Write([]byte(accessLog.RemoteAddr + " [" + accessLog.TimeLocal + "] \"" + accessLog.RequestMethod + " " + accessLog.Scheme + "://" + accessLog.Host + accessLog.RequestURI + " " + accessLog.Proto + "\" " + types.String(accessLog.Status) + " - " + fmt.Sprintf("%.2fms", accessLog.RequestTime*1000) + "\n"))
}
}
func (this *HTTPAccessLogViewer) nextConnId() int64 {
return atomic.AddInt64(&this.connId, 1)
}
func (this *HTTPAccessLogViewer) startReading(conn net.Conn, connId int64) {
var buf = make([]byte, 1024)
for {
_, err := conn.Read(buf)
if err != nil {
this.locker.Lock()
delete(this.connMap, connId)
this.locker.Unlock()
break
}
}
}

View File

@@ -0,0 +1,247 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"context"
"crypto/tls"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/compressions"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/iwind/TeaGo/Tea"
"io"
"io/ioutil"
"net"
"net/http"
"regexp"
"strings"
"time"
)
func init() {
events.On(events.EventStart, func() {
goman.New(func() {
SharedHTTPCacheTaskManager.Start()
})
})
}
var SharedHTTPCacheTaskManager = NewHTTPCacheTaskManager()
// HTTPCacheTaskManager 缓存任务管理
type HTTPCacheTaskManager struct {
ticker *time.Ticker
httpClient *http.Client
protocolReg *regexp.Regexp
taskQueue chan *pb.PurgeServerCacheRequest
}
func NewHTTPCacheTaskManager() *HTTPCacheTaskManager {
var duration = 30 * time.Second
if Tea.IsTesting() {
duration = 10 * time.Second
}
return &HTTPCacheTaskManager{
ticker: time.NewTicker(duration),
httpClient: &http.Client{
Timeout: 10 * time.Minute, // TODO 可以设置请求超时时间
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
_, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
return net.Dial(network, "127.0.0.1:"+port)
},
MaxIdleConns: 128,
MaxIdleConnsPerHost: 32,
MaxConnsPerHost: 32,
IdleConnTimeout: 2 * time.Minute,
ExpectContinueTimeout: 1 * time.Second,
TLSHandshakeTimeout: 0,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
},
protocolReg: regexp.MustCompile(`^(?i)(http|https)://`),
taskQueue: make(chan *pb.PurgeServerCacheRequest, 1024),
}
}
func (this *HTTPCacheTaskManager) Start() {
// task queue
goman.New(func() {
rpcClient, _ := rpc.SharedRPC()
if rpcClient != nil {
for taskReq := range this.taskQueue {
_, err := rpcClient.ServerRPC().PurgeServerCache(rpcClient.Context(), taskReq)
if err != nil {
remotelogs.Error("HTTP_CACHE_TASK_MANAGER", "create purge task failed: "+err.Error())
}
}
}
})
// Loop
for range this.ticker.C {
err := this.Loop()
if err != nil {
remotelogs.Error("HTTP_CACHE_TASK_MANAGER", "execute task failed: "+err.Error())
}
}
}
func (this *HTTPCacheTaskManager) Loop() error {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
resp, err := rpcClient.HTTPCacheTaskKeyRPC().FindDoingHTTPCacheTaskKeys(rpcClient.Context(), &pb.FindDoingHTTPCacheTaskKeysRequest{})
if err != nil {
return err
}
var keys = resp.HttpCacheTaskKeys
if len(keys) == 0 {
return nil
}
var pbResults = []*pb.UpdateHTTPCacheTaskKeysStatusRequest_KeyResult{}
for _, key := range keys {
err = this.processKey(key)
var pbResult = &pb.UpdateHTTPCacheTaskKeysStatusRequest_KeyResult{
Id: key.Id,
NodeClusterId: key.NodeClusterId,
Error: "",
}
if err != nil {
pbResult.Error = err.Error()
}
pbResults = append(pbResults, pbResult)
}
_, err = rpcClient.HTTPCacheTaskKeyRPC().UpdateHTTPCacheTaskKeysStatus(rpcClient.Context(), &pb.UpdateHTTPCacheTaskKeysStatusRequest{KeyResults: pbResults})
if err != nil {
return err
}
return nil
}
func (this *HTTPCacheTaskManager) PushTaskKeys(keys []string) {
select {
case this.taskQueue <- &pb.PurgeServerCacheRequest{
Keys: keys,
Prefixes: nil,
}:
default:
}
}
func (this *HTTPCacheTaskManager) processKey(key *pb.HTTPCacheTaskKey) error {
switch key.Type {
case "purge":
var storages = caches.SharedManager.FindAllStorages()
for _, storage := range storages {
switch key.KeyType {
case "key":
var cacheKeys = []string{key.Key}
if strings.HasPrefix(key.Key, "http://") {
cacheKeys = append(cacheKeys, strings.Replace(key.Key, "http://", "https://", 1))
} else if strings.HasPrefix(key.Key, "https://") {
cacheKeys = append(cacheKeys, strings.Replace(key.Key, "https://", "http://", 1))
}
// TODO 提升效率
for _, cacheKey := range cacheKeys {
var subKeys = []string{
cacheKey,
cacheKey + caches.SuffixMethod + "HEAD",
cacheKey + caches.SuffixWebP,
cacheKey + caches.SuffixPartial,
}
// TODO 根据实际缓存的内容进行组合
for _, encoding := range compressions.AllEncodings() {
subKeys = append(subKeys, cacheKey+caches.SuffixCompression+encoding)
subKeys = append(subKeys, cacheKey+caches.SuffixWebP+caches.SuffixCompression+encoding)
}
err := storage.Purge(subKeys, "file")
if err != nil {
return err
}
}
case "prefix":
var prefixes = []string{key.Key}
if strings.HasPrefix(key.Key, "http://") {
prefixes = append(prefixes, strings.Replace(key.Key, "http://", "https://", 1))
} else if strings.HasPrefix(key.Key, "https://") {
prefixes = append(prefixes, strings.Replace(key.Key, "https://", "http://", 1))
}
err := storage.Purge(prefixes, "dir")
if err != nil {
return err
}
}
}
case "fetch":
err := this.fetchKey(key)
if err != nil {
return err
}
default:
return errors.New("invalid operation type '" + key.Type + "'")
}
return nil
}
// TODO 增加失败重试
// TODO 使用并发操作
func (this *HTTPCacheTaskManager) fetchKey(key *pb.HTTPCacheTaskKey) error {
var fullKey = key.Key
if !this.protocolReg.MatchString(fullKey) {
fullKey = "https://" + fullKey
}
req, err := http.NewRequest(http.MethodGet, fullKey, nil)
if err != nil {
return errors.New("invalid url: " + fullKey + ": " + err.Error())
}
// TODO 可以在管理界面自定义Header
req.Header.Set("X-Edge-Cache-Action", "fetch")
req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/85.0.4183.121 Safari/537.36") // TODO 可以定义
req.Header.Set("Accept-Encoding", "gzip, deflate, br")
resp, err := this.httpClient.Do(req)
if err != nil {
return errors.New("request failed: " + fullKey + ": " + err.Error())
}
defer func() {
_ = resp.Body.Close()
}()
// 读取内容,以便于生成缓存
_, _ = io.Copy(ioutil.Discard, resp.Body)
// 处理502
if resp.StatusCode == http.StatusBadGateway {
return errors.New("read origin site timeout")
}
return nil
}

View File

@@ -0,0 +1,25 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes_test
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/nodes"
"testing"
)
func TestHTTPCacheTaskManager_Loop(t *testing.T) {
// initialize cache policies
config, err := nodeconfigs.SharedNodeConfig()
if err != nil {
t.Fatal(err)
}
caches.SharedManager.UpdatePolicies(config.HTTPCachePolicies)
var manager = nodes.NewHTTPCacheTaskManager()
err = manager.Loop()
if err != nil {
t.Fatal(err)
}
}

View File

@@ -104,39 +104,41 @@ func (this *HTTPClientPool) Client(req *HTTPRequest,
}
}
var transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
// 支持TOA的连接
conn, err := this.handleTOA(req, ctx, network, originAddr, connectionTimeout)
if conn != nil || err != nil {
return conn, err
}
var transport = &HTTPClientTransport{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
// 支持TOA的连接
conn, err := this.handleTOA(req, ctx, network, originAddr, connectionTimeout)
if conn != nil || err != nil {
return conn, err
}
// 普通的连接
conn, err = (&net.Dialer{
Timeout: connectionTimeout,
KeepAlive: 1 * time.Minute,
}).DialContext(ctx, network, originAddr)
if err != nil {
return nil, err
}
// 普通的连接
conn, err = (&net.Dialer{
Timeout: connectionTimeout,
KeepAlive: 1 * time.Minute,
}).DialContext(ctx, network, originAddr)
if err != nil {
return nil, err
}
// 处理PROXY protocol
err = this.handlePROXYProtocol(conn, req, proxyProtocol)
if err != nil {
return nil, err
}
// 处理PROXY protocol
err = this.handlePROXYProtocol(conn, req, proxyProtocol)
if err != nil {
return nil, err
}
return conn, nil
return conn, nil
},
MaxIdleConns: 0,
MaxIdleConnsPerHost: idleConns,
MaxConnsPerHost: maxConnections,
IdleConnTimeout: idleTimeout,
ExpectContinueTimeout: 1 * time.Second,
TLSHandshakeTimeout: 3 * time.Second,
TLSClientConfig: tlsConfig,
Proxy: nil,
},
MaxIdleConns: 0,
MaxIdleConnsPerHost: idleConns,
MaxConnsPerHost: maxConnections,
IdleConnTimeout: idleTimeout,
ExpectContinueTimeout: 1 * time.Second,
TLSHandshakeTimeout: 3 * time.Second,
TLSClientConfig: tlsConfig,
Proxy: nil,
}
rawClient = &http.Client{

View File

@@ -0,0 +1,26 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"net/http"
)
const emptyHTTPLocation = "/$EmptyHTTPLocation$"
type HTTPClientTransport struct {
*http.Transport
}
func (this *HTTPClientTransport) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := this.Transport.RoundTrip(req)
if err != nil {
return resp, err
}
// 检查在跳转相关状态中Location是否存在
if httpStatusIsRedirect(resp.StatusCode) && len(resp.Header.Get("Location")) == 0 {
resp.Header.Set("Location", emptyHTTPLocation)
}
return resp, nil
}

View File

@@ -170,9 +170,10 @@ func (this *HTTPRequest) Do() {
// ACME
// TODO 需要配置是否启用ACME检测
if strings.HasPrefix(this.rawURI, "/.well-known/acme-challenge/") {
this.doACME()
this.doEnd()
return
if this.doACME() {
this.doEnd()
return
}
}
}
@@ -265,10 +266,19 @@ func (this *HTTPRequest) doBegin() {
}
// UAM
if !isHealthCheck && this.ReqServer.UAM != nil && this.ReqServer.UAM.IsOn {
if this.doUAM() {
this.doEnd()
return
if !isHealthCheck {
if this.web.UAM != nil {
if this.web.UAM.IsOn {
if this.doUAM() {
this.doEnd()
return
}
}
} else if this.ReqServer.UAM != nil && this.ReqServer.UAM.IsOn {
if this.doUAM() {
this.doEnd()
return
}
}
}
@@ -521,6 +531,11 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
}
}
// UAM
if web.UAM != nil && (web.UAM.IsPrior || isTop) {
this.web.UAM = web.UAM
}
// 重写规则
if len(web.RewriteRefs) > 0 {
for index, ref := range web.RewriteRefs {
@@ -1033,7 +1048,7 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string {
}
// X-Forwarded-For
forwardedFor := this.RawReq.Header.Get("X-Forwarded-For")
var forwardedFor = this.RawReq.Header.Get("X-Forwarded-For")
if len(forwardedFor) > 0 {
commaIndex := strings.Index(forwardedFor, ",")
if commaIndex > 0 {
@@ -1447,7 +1462,12 @@ func (this *HTTPRequest) setForwardHeaders(header http.Header) {
header["X-Forwarded-For"] = []string{strings.Join(forwardedFor, ", ") + ", " + remoteAddr}
}
} else {
header["X-Forwarded-For"] = []string{remoteAddr}
var clientRemoteAddr = this.requestRemoteAddr(true)
if len(clientRemoteAddr) > 0 && clientRemoteAddr != remoteAddr {
header["X-Forwarded-For"] = []string{clientRemoteAddr + ", " + remoteAddr}
} else {
header["X-Forwarded-For"] = []string{remoteAddr}
}
}
}

View File

@@ -4,34 +4,36 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"net/http"
"path/filepath"
)
func (this *HTTPRequest) doACME() {
func (this *HTTPRequest) doACME() (shouldStop bool) {
// TODO 对请求进行校验,防止恶意攻击
token := filepath.Base(this.RawReq.URL.Path)
var token = filepath.Base(this.RawReq.URL.Path)
if token == "acme-challenge" || len(token) <= 32 {
this.writer.WriteHeader(http.StatusNotFound)
return
return false
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
remotelogs.Error("RPC", "[ACME]rpc failed: "+err.Error())
return
return false
}
keyResp, err := rpcClient.ACMEAuthenticationRPC().FindACMEAuthenticationKeyWithToken(rpcClient.Context(), &pb.FindACMEAuthenticationKeyWithTokenRequest{Token: token})
if err != nil {
remotelogs.Error("RPC", "[ACME]read key for token failed: "+err.Error())
return
return false
}
if len(keyResp.Key) == 0 {
this.writer.WriteHeader(http.StatusNotFound)
} else {
this.writer.Header().Set("Content-Type", "text/plain")
_, _ = this.writer.WriteString(keyResp.Key)
return false
}
this.tags = append(this.tags, "ACME")
this.writer.Header().Set("Content-Type", "text/plain")
_, _ = this.writer.WriteString(keyResp.Key)
return true
}

View File

@@ -3,12 +3,9 @@ package nodes
import (
"bytes"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/compressions"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils"
rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
"github.com/iwind/TeaGo/types"
@@ -33,11 +30,6 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
return
}
// 判断是否在预热
if (strings.HasPrefix(this.RawReq.RemoteAddr, "127.") || strings.HasPrefix(this.RawReq.RemoteAddr, "[::1]")) && this.RawReq.Header.Get("X-Cache-Action") == "preheat" {
return
}
// 添加 X-Cache Header
var addStatusHeader = this.web.Cache.AddStatusHeader
if addStatusHeader {
@@ -89,6 +81,12 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
return
}
// 是否正在Purge
var isPurging = this.web.Cache.PurgeIsOn && strings.ToUpper(this.RawReq.Method) == "PURGE" && this.RawReq.Header.Get("X-Edge-Purge-Key") == this.web.Cache.PurgeKey
if isPurging {
this.RawReq.Method = http.MethodGet
}
// 校验请求
if !this.cacheRef.MatchRequest(this.RawReq) {
this.cacheRef = nil
@@ -136,8 +134,13 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
}
this.writer.cacheStorage = storage
// 如果正在预热,则不读取缓存,等待下一个步骤重新生成
if (strings.HasPrefix(this.RawReq.RemoteAddr, "127.") || strings.HasPrefix(this.RawReq.RemoteAddr, "[::1]")) && this.RawReq.Header.Get("X-Edge-Cache-Action") == "fetch" {
return
}
// 判断是否在Purge
if this.web.Cache.PurgeIsOn && strings.ToUpper(this.RawReq.Method) == "PURGE" && this.RawReq.Header.Get("X-Edge-Purge-Key") == this.web.Cache.PurgeKey {
if isPurging {
this.varMapping["cache.status"] = "PURGE"
var subKeys = []string{
@@ -159,22 +162,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
}
// 通过API节点清除别节点上的的Key
// TODO 改为队列不需要每个请求都使用goroutine
goman.New(func() {
rpcClient, err := rpc.SharedRPC()
if err == nil {
for _, rpcServerService := range rpcClient.ServerRPCList() {
_, err = rpcServerService.PurgeServerCache(rpcClient.Context(), &pb.PurgeServerCacheRequest{
Domains: []string{this.ReqHost},
Keys: []string{key},
Prefixes: nil,
})
if err != nil {
remotelogs.Error("HTTP_REQUEST_CACHE", "purge failed: "+err.Error())
}
}
}
})
SharedHTTPCacheTaskManager.PushTaskKeys([]string{key})
return true
}
@@ -248,6 +236,11 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
if reader == nil {
reader, err = storage.OpenReader(key, useStale, false)
if err != nil && this.cacheRef.AllowPartialContent {
// 尝试读取分片的缓存内容
if len(rangeHeader) == 0 {
// 默认读取开头
rangeHeader = "bytes=0-"
}
pReader, ranges := this.tryPartialReader(storage, key, useStale, rangeHeader)
if pReader != nil {
isPartialCache = true

View File

@@ -21,8 +21,13 @@ func (this *HTTPRequest) doHealthCheck(key string, isHealthCheck *bool) (stop bo
}
*isHealthCheck = true
if !data.GetBool("accessLogIsOn") {
this.disableLog = true
}
if data.GetBool("onlyBasicRequest") {
return true
}
return
}

View File

@@ -2,9 +2,18 @@
package nodes
import "net/http"
import (
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"net/http"
)
func (this *HTTPRequest) doRequestLimit() (shouldStop bool) {
// 是否在全局名单中
_, isInAllowedList := iplibrary.AllowIP(this.RemoteAddr(), this.ReqServer.Id)
if isInAllowedList {
return false
}
// 检查请求Body尺寸
// TODO 处理分片提交的内容
if this.web.RequestLimit.MaxBodyBytes() > 0 &&

View File

@@ -282,16 +282,32 @@ func (this *HTTPRequest) doReverseProxy() {
// 替换Location中的源站地址
var locationHeader = resp.Header.Get("Location")
if len(locationHeader) > 0 {
locationURL, err := url.Parse(locationHeader)
if err == nil && (locationURL.Host == originAddr || strings.HasPrefix(originAddr, locationURL.Host+":")) {
locationURL.Host = this.ReqHost
if this.IsHTTP {
locationURL.Scheme = "http"
} else if this.IsHTTPS {
locationURL.Scheme = "https"
}
// 空Location处理
if locationHeader == emptyHTTPLocation {
resp.Header.Del("Location")
} else {
// 自动修正Location中的源站地址
locationURL, err := url.Parse(locationHeader)
if err == nil && locationURL.Host != this.ReqHost && (locationURL.Host == originAddr || strings.HasPrefix(originAddr, locationURL.Host+":")) {
locationURL.Host = this.ReqHost
resp.Header.Set("Location", locationURL.String())
var oldScheme = locationURL.Scheme
// 尝试和当前Scheme一致
if this.IsHTTP {
locationURL.Scheme = "http"
} else if this.IsHTTPS {
locationURL.Scheme = "https"
}
// 如果和当前URL一样则可能是http -> https防止无限循环
if locationURL.String() == this.URL() {
locationURL.Scheme = oldScheme
resp.Header.Set("Location", locationURL.String())
} else {
resp.Header.Set("Location", locationURL.String())
}
}
}
}
@@ -312,7 +328,12 @@ func (this *HTTPRequest) doReverseProxy() {
// 是否有内容
if resp.ContentLength == 0 && len(resp.TransferEncoding) == 0 {
// 即使内容为0也需要读取一次以便于触发相关事件
var buf = utils.BytePool4k.Get()
_, _ = io.CopyBuffer(this.writer, resp.Body, buf)
utils.BytePool4k.Put(buf)
_ = resp.Body.Close()
this.writer.SetOk()
return
}

View File

@@ -51,7 +51,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
return true
}
rootDir := this.web.Root.Dir
var rootDir = this.web.Root.Dir
if this.web.Root.HasVariables() {
rootDir = this.Format(rootDir)
}
@@ -59,9 +59,9 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
rootDir = Tea.Root + Tea.DS + rootDir
}
requestPath := this.uri
var requestPath = this.uri
questionMarkIndex := strings.Index(this.uri, "?")
var questionMarkIndex = strings.Index(this.uri, "?")
if questionMarkIndex > -1 {
requestPath = this.uri[:questionMarkIndex]
}
@@ -75,7 +75,9 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
if err == nil {
requestPath = p
} else {
logs.Error(err)
if !this.canIgnore(err) {
logs.Error(err)
}
}
}
@@ -92,8 +94,8 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
}
}
filename := strings.Replace(requestPath, "/", Tea.DS, -1)
filePath := ""
var filename = strings.Replace(requestPath, "/", Tea.DS, -1)
var filePath = ""
if len(filename) > 0 && filename[0:1] == Tea.DS {
filePath = rootDir + filename
} else {
@@ -113,7 +115,9 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
return
} else {
this.write50x(err, http.StatusInternalServerError, true)
logs.Error(err)
if !this.canIgnore(err) {
logs.Error(err)
}
return true
}
}
@@ -142,7 +146,9 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
return
} else {
this.write50x(err, http.StatusInternalServerError, true)
logs.Error(err)
if !this.canIgnore(err) {
logs.Error(err)
}
return true
}
}
@@ -152,24 +158,24 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
}
// 响应header
respHeader := this.writer.Header()
var respHeader = this.writer.Header()
// mime type
contentType := ""
var contentType = ""
if this.web.ResponseHeaderPolicy == nil || !this.web.ResponseHeaderPolicy.IsOn || !this.web.ResponseHeaderPolicy.ContainsHeader("CONTENT-TYPE") {
ext := filepath.Ext(filePath)
var ext = filepath.Ext(filePath)
if len(ext) > 0 {
mimeType := mime.TypeByExtension(ext)
if len(mimeType) > 0 {
semicolonIndex := strings.Index(mimeType, ";")
mimeTypeKey := mimeType
var semicolonIndex = strings.Index(mimeType, ";")
var mimeTypeKey = mimeType
if semicolonIndex > 0 {
mimeTypeKey = mimeType[:semicolonIndex]
}
if _, found := textMimeMap[mimeTypeKey]; found {
if this.web.Charset != nil && this.web.Charset.IsOn && len(this.web.Charset.Charset) > 0 {
charset := this.web.Charset.Charset
var charset = this.web.Charset.Charset
if this.web.Charset.IsUpper {
charset = strings.ToUpper(charset)
}
@@ -197,7 +203,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
}
// 支持 ETag
eTag := "\"e" + fmt.Sprintf("%0x", xxhash.Sum64String(filename+strconv.FormatInt(stat.ModTime().UnixNano(), 10)+strconv.FormatInt(stat.Size(), 10))) + "\""
var eTag = "\"e" + fmt.Sprintf("%0x", xxhash.Sum64String(filename+strconv.FormatInt(stat.ModTime().UnixNano(), 10)+strconv.FormatInt(stat.Size(), 10))) + "\""
if len(respHeader.Get("ETag")) == 0 {
respHeader.Set("ETag", eTag)
}
@@ -227,7 +233,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
// 支持Range
respHeader.Set("Accept-Ranges", "bytes")
ifRangeHeaders, ok := this.RawReq.Header["If-Range"]
supportRange := true
var supportRange = true
if ok {
supportRange = false
for _, v := range ifRangeHeaders {
@@ -244,7 +250,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
// 支持Range
var ranges = []rangeutils.Range{}
if supportRange {
contentRange := this.RawReq.Header.Get("Range")
var contentRange = this.RawReq.Header.Get("Range")
if len(contentRange) > 0 {
if fileSize == 0 {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
@@ -277,7 +283,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
respHeader.Set("Content-Length", strconv.FormatInt(fileSize, 10))
}
reader, err := os.OpenFile(filePath, os.O_RDONLY, 0444)
fileReader, err := os.OpenFile(filePath, os.O_RDONLY, 0444)
if err != nil {
this.write50x(err, http.StatusInternalServerError, true)
return true
@@ -291,12 +297,16 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
this.cacheRef = nil // 不支持缓存
}
this.writer.Prepare(nil, fileSize, http.StatusOK, true)
var resp = &http.Response{
ContentLength: fileSize,
Body: fileReader,
StatusCode: http.StatusOK,
}
this.writer.Prepare(resp, fileSize, http.StatusOK, true)
pool := this.bytePool(fileSize)
buf := pool.Get()
var pool = this.bytePool(fileSize)
var buf = pool.Get()
defer func() {
_ = reader.Close()
pool.Put(buf)
}()
@@ -304,12 +314,14 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
respHeader.Set("Content-Range", ranges[0].ComposeContentRangeHeader(types.String(fileSize)))
this.writer.WriteHeader(http.StatusPartialContent)
ok, err := httpRequestReadRange(reader, buf, ranges[0].Start(), ranges[0].End(), func(buf []byte, n int) error {
ok, err := httpRequestReadRange(resp.Body, buf, ranges[0].Start(), ranges[0].End(), func(buf []byte, n int) error {
_, err := this.writer.Write(buf[:n])
return err
})
if err != nil {
logs.Error(err)
if !this.canIgnore(err) {
logs.Error(err)
}
return true
}
if !ok {
@@ -318,7 +330,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
return true
}
} else if len(ranges) > 1 {
boundary := httpRequestGenBoundary()
var boundary = httpRequestGenBoundary()
respHeader.Set("Content-Type", "multipart/byteranges; boundary="+boundary)
this.writer.WriteHeader(http.StatusPartialContent)
@@ -330,30 +342,38 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
_, err = this.writer.WriteString("\r\n--" + boundary + "\r\n")
}
if err != nil {
logs.Error(err)
if !this.canIgnore(err) {
logs.Error(err)
}
return true
}
_, err = this.writer.WriteString("Content-Range: " + r.ComposeContentRangeHeader(types.String(fileSize)) + "\r\n")
if err != nil {
logs.Error(err)
if !this.canIgnore(err) {
logs.Error(err)
}
return true
}
if len(contentType) > 0 {
_, err = this.writer.WriteString("Content-Type: " + contentType + "\r\n\r\n")
if err != nil {
logs.Error(err)
if !this.canIgnore(err) {
logs.Error(err)
}
return true
}
}
ok, err := httpRequestReadRange(reader, buf, r.Start(), r.End(), func(buf []byte, n int) error {
ok, err := httpRequestReadRange(resp.Body, buf, r.Start(), r.End(), func(buf []byte, n int) error {
_, err := this.writer.Write(buf[:n])
return err
})
if err != nil {
logs.Error(err)
if !this.canIgnore(err) {
logs.Error(err)
}
return true
}
if !ok {
@@ -365,14 +385,17 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
_, err = this.writer.WriteString("\r\n--" + boundary + "--\r\n")
if err != nil {
logs.Error(err)
if !this.canIgnore(err) {
logs.Error(err)
}
return true
}
} else {
_, err = io.CopyBuffer(this.writer, reader, buf)
_, err = io.CopyBuffer(this.writer, resp.Body, buf)
if err != nil {
logs.Error(err)
if !this.canIgnore(err) {
logs.Error(err)
}
return true
}
}
@@ -400,7 +423,9 @@ func (this *HTTPRequest) findIndexFile(dir string) (filename string, stat os.Fil
if strings.Contains(index, "*") {
indexFiles, err := filepath.Glob(dir + Tea.DS + index)
if err != nil {
logs.Error(err)
if !this.canIgnore(err) {
logs.Error(err)
}
this.addError(err)
continue
}

View File

@@ -194,7 +194,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
}
// 规则测试
w := sharedWAFManager.FindWAF(firewallPolicy.Id)
w := waf.SharedWAFManager.FindWAF(firewallPolicy.Id)
if w == nil {
return
}
@@ -261,7 +261,7 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi
return
}
w := sharedWAFManager.FindWAF(firewallPolicy.Id)
w := waf.SharedWAFManager.FindWAF(firewallPolicy.Id)
if w == nil {
return
}

View File

@@ -2,7 +2,6 @@ package nodes
import (
"errors"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"io"
"net/http"
@@ -20,7 +19,7 @@ func (this *HTTPRequest) doWebsocket() {
// TODO 实现handshakeTimeout
// 校验来源
requestOrigin := this.RawReq.Header.Get("Origin")
var requestOrigin = this.RawReq.Header.Get("Origin")
if len(requestOrigin) > 0 {
u, err := url.Parse(requestOrigin)
if err == nil {
@@ -34,7 +33,7 @@ func (this *HTTPRequest) doWebsocket() {
// 设置指定的来源域
if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
newRequestOrigin := this.web.Websocket.RequestOrigin
var newRequestOrigin = this.web.Websocket.RequestOrigin
if this.web.Websocket.RequestOriginHasVariables() {
newRequestOrigin = this.Format(newRequestOrigin)
}
@@ -45,8 +44,21 @@ func (this *HTTPRequest) doWebsocket() {
originConn, err := OriginConnect(this.origin, this.RawReq.RemoteAddr)
if err != nil {
this.write50x(err, http.StatusBadGateway, false)
// 增加失败次数
SharedOriginStateManager.Fail(this.origin, this.reverseProxy, func() {
this.reverseProxy.ResetScheduling()
})
return
}
if !this.origin.IsOk {
SharedOriginStateManager.Success(this.origin, func() {
this.reverseProxy.ResetScheduling()
})
}
defer func() {
_ = originConn.Close()
}()
@@ -66,7 +78,7 @@ func (this *HTTPRequest) doWebsocket() {
_ = clientConn.Close()
}()
goman.New(func() {
go func() {
var buf = utils.BytePool4k.Get()
defer utils.BytePool4k.Put(buf)
for {
@@ -84,6 +96,6 @@ func (this *HTTPRequest) doWebsocket() {
}
_ = clientConn.Close()
_ = originConn.Close()
})
}()
_, _ = io.Copy(originConn, clientConn)
}

View File

@@ -312,9 +312,10 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
if !caches.CanIgnoreErr(err) {
remotelogs.Error("HTTP_WRITER", "write cache failed: "+err.Error())
this.Header().Set("X-Cache", "BYPASS, write cache failed")
} else {
this.Header().Set("X-Cache", "BYPASS, "+err.Error())
}
this.Header().Set("X-Cache", "BYPASS, too many requests")
return
}
this.cacheWriter = cacheWriter
@@ -448,7 +449,9 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
this.rawReader = cacheReader
cacheReader.OnFail(func(err error) {
_ = this.cacheWriter.Discard()
if this.cacheWriter != nil {
_ = this.cacheWriter.Discard()
}
this.cacheWriter = nil
})
cacheReader.OnEOF(func() {
@@ -836,7 +839,7 @@ func (this *HTTPWriter) HeaderData() []byte {
return nil
}
resp := &http.Response{}
var resp = &http.Response{}
resp.Header = this.Header()
if this.statusCode == 0 {
this.statusCode = http.StatusOK
@@ -859,6 +862,70 @@ func (this *HTTPWriter) SetOk() {
// Close 关闭
func (this *HTTPWriter) Close() {
this.finishWebP()
this.finishRequest()
this.finishCache()
this.finishCompression()
// 统计
if this.sentBodyBytes == 0 {
this.sentBodyBytes = this.counterWriter.TotalBytes()
}
}
// Hijack Hijack
func (this *HTTPWriter) Hijack() (conn net.Conn, buf *bufio.ReadWriter, err error) {
hijack, ok := this.rawWriter.(http.Hijacker)
if ok {
return hijack.Hijack()
}
return
}
// Flush Flush
func (this *HTTPWriter) Flush() {
flusher, ok := this.rawWriter.(http.Flusher)
if ok {
flusher.Flush()
}
}
// DelayRead 是否延迟读取Reader
func (this *HTTPWriter) DelayRead() bool {
return this.delayRead
}
// 计算stale时长
func (this *HTTPWriter) calculateStaleLife() int {
var staleLife = 600 // TODO 可以在缓存策略里设置此时间
var staleConfig = this.req.web.Cache.Stale
if staleConfig != nil && staleConfig.IsOn {
// 从Header中读取stale-if-error
var isDefinedInHeader = false
if staleConfig.SupportStaleIfErrorHeader {
var cacheControl = this.GetHeader("Cache-Control")
var pieces = strings.Split(cacheControl, ",")
for _, piece := range pieces {
var eqIndex = strings.Index(piece, "=")
if eqIndex > 0 && strings.TrimSpace(piece[:eqIndex]) == "stale-if-error" {
// 这里预示着如果stale-if-error=0可以关闭stale功能
staleLife = types.Int(strings.TrimSpace(piece[eqIndex+1:]))
isDefinedInHeader = true
break
}
}
}
// 自定义
if !isDefinedInHeader && staleConfig.Life != nil {
staleLife = types.Int(staleConfig.Life.Duration().Seconds())
}
}
return staleLife
}
// 结束WebP
func (this *HTTPWriter) finishWebP() {
// 处理WebP
if this.webpIsEncoding {
var webpCacheWriter caches.Writer
@@ -919,6 +986,7 @@ func (this *HTTPWriter) Close() {
}
if err != nil {
// 发生了错误终止处理
return
}
@@ -948,7 +1016,7 @@ func (this *HTTPWriter) Close() {
//webpConfig.SetLossless(1)
webpConfig.SetQuality(f)
timeline := 0
var timeline = 0
for i, img := range gifImage.Image {
err = anim.AddFrame(img, timeline, webpConfig)
@@ -988,15 +1056,10 @@ func (this *HTTPWriter) Close() {
}
}
}
}
if this.writer != nil {
_ = this.writer.Close()
}
if this.rawReader != nil {
_ = this.rawReader.Close()
}
// 结束缓存相关处理
func (this *HTTPWriter) finishCache() {
// 缓存
if this.cacheWriter != nil {
if this.isOk && this.cacheIsFinished {
@@ -1054,7 +1117,10 @@ func (this *HTTPWriter) Close() {
}
}
}
}
// 结束压缩相关处理
func (this *HTTPWriter) finishCompression() {
if this.compressionCacheWriter != nil {
if this.isOk {
err := this.compressionCacheWriter.Close()
@@ -1075,59 +1141,15 @@ func (this *HTTPWriter) Close() {
_ = this.compressionCacheWriter.Discard()
}
}
}
if this.sentBodyBytes == 0 {
this.sentBodyBytes = this.counterWriter.TotalBytes()
// 最终关闭
func (this *HTTPWriter) finishRequest() {
if this.writer != nil {
_ = this.writer.Close()
}
if this.rawReader != nil {
_ = this.rawReader.Close()
}
}
// Hijack Hijack
func (this *HTTPWriter) Hijack() (conn net.Conn, buf *bufio.ReadWriter, err error) {
hijack, ok := this.rawWriter.(http.Hijacker)
if ok {
return hijack.Hijack()
}
return
}
// Flush Flush
func (this *HTTPWriter) Flush() {
flusher, ok := this.rawWriter.(http.Flusher)
if ok {
flusher.Flush()
}
}
// DelayRead 是否延迟读取Reader
func (this *HTTPWriter) DelayRead() bool {
return this.delayRead
}
// 计算stale时长
func (this *HTTPWriter) calculateStaleLife() int {
var staleLife = 600 // TODO 可以在缓存策略里设置此时间
var staleConfig = this.req.web.Cache.Stale
if staleConfig != nil && staleConfig.IsOn {
// 从Header中读取stale-if-error
var isDefinedInHeader = false
if staleConfig.SupportStaleIfErrorHeader {
var cacheControl = this.GetHeader("Cache-Control")
var pieces = strings.Split(cacheControl, ",")
for _, piece := range pieces {
var eqIndex = strings.Index(piece, "=")
if eqIndex > 0 && strings.TrimSpace(piece[:eqIndex]) == "stale-if-error" {
// 这里预示着如果stale-if-error=0可以关闭stale功能
staleLife = types.Int(strings.TrimSpace(piece[eqIndex+1:]))
isDefinedInHeader = true
break
}
}
}
// 自定义
if !isDefinedInHeader && staleConfig.Life != nil {
staleLife = types.Int(staleConfig.Life.Duration().Seconds())
}
}
return staleLife
}

View File

@@ -3,6 +3,7 @@ package nodes
import (
"context"
"crypto/tls"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeutils"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/zero"
@@ -170,34 +171,8 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
server, serverName := this.findNamedServer(domain)
if server == nil {
server = this.findServerWithCNAME(domain)
if server == nil {
// 严格匹配域名模式下,我们拒绝用户访问
if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {
httpAllConfig := sharedNodeConfig.GlobalConfig.HTTPAll
mismatchAction := httpAllConfig.DomainMismatchAction
if mismatchAction != nil && mismatchAction.Code == "page" {
if mismatchAction.Options != nil {
rawWriter.Header().Set("Content-Type", "text/html; charset=utf-8")
rawWriter.WriteHeader(mismatchAction.Options.GetInt("statusCode"))
_, _ = rawWriter.Write([]byte(mismatchAction.Options.GetString("contentHTML")))
} else {
http.Error(rawWriter, "404 page not found: '"+rawReq.URL.String()+"'", http.StatusNotFound)
}
return
} else {
hijacker, ok := rawWriter.(http.Hijacker)
if ok {
conn, _, _ := hijacker.Hijack()
if conn != nil {
_ = conn.Close()
return
}
}
}
}
http.Error(rawWriter, "404 page not found: '"+rawReq.URL.String()+"'", http.StatusNotFound)
this.handleMismatch(rawReq, rawWriter)
return
} else {
serverName = domain
@@ -205,7 +180,7 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
}
// 包装新请求对象
req := &HTTPRequest{
var req = &HTTPRequest{
RawReq: rawReq,
RawWriter: rawWriter,
ReqServer: server,
@@ -220,6 +195,48 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
req.Do()
}
// 处理域名不匹配的情况
func (this *HTTPListener) handleMismatch(rawReq *http.Request, rawWriter http.ResponseWriter) {
// TODO 需要记录访问记录和防止CC
// 是否为健康检查
var healthCheckKey = rawReq.Header.Get(serverconfigs.HealthCheckHeaderName)
if len(healthCheckKey) > 0 {
_, err := nodeutils.Base64DecodeMap(healthCheckKey)
if err == nil {
rawWriter.WriteHeader(http.StatusOK)
return
}
}
// 严格匹配域名模式下,我们拒绝用户访问
if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {
var httpAllConfig = sharedNodeConfig.GlobalConfig.HTTPAll
var mismatchAction = httpAllConfig.DomainMismatchAction
if mismatchAction != nil && mismatchAction.Code == "page" {
if mismatchAction.Options != nil {
rawWriter.Header().Set("Content-Type", "text/html; charset=utf-8")
rawWriter.WriteHeader(mismatchAction.Options.GetInt("statusCode"))
_, _ = rawWriter.Write([]byte(mismatchAction.Options.GetString("contentHTML")))
} else {
http.Error(rawWriter, "404 page not found: '"+rawReq.URL.String()+"'", http.StatusNotFound)
}
return
} else {
hijacker, ok := rawWriter.(http.Hijacker)
if ok {
conn, _, _ := hijacker.Hijack()
if conn != nil {
_ = conn.Close()
return
}
}
}
}
http.Error(rawWriter, "404 page not found: '"+rawReq.URL.String()+"'", http.StatusNotFound)
}
func (this *HTTPListener) isIP(host string) bool {
// IPv6
if strings.Index(host, "[") > -1 {

View File

@@ -257,6 +257,12 @@ func (this *ListenerManager) addToFirewalld(groupAddrs []string) {
return
}
// 检查状态
err = exec.Command(firewallCmd, "--state").Run()
if err != nil {
return
}
remotelogs.Println("FIREWALLD", "open ports automatically")
for _, port := range ports {
{

View File

@@ -21,7 +21,7 @@ type TCPListener struct {
}
func (this *TCPListener) Serve() error {
listener := this.Listener
var listener = this.Listener
if this.Group.IsTLS() {
listener = tls.NewListener(listener, this.buildTLSConfig())
}
@@ -52,14 +52,29 @@ func (this *TCPListener) Reload(group *serverconfigs.ServerAddressGroup) {
}
func (this *TCPListener) handleConn(conn net.Conn) error {
firstServer := this.Group.FirstServer()
if firstServer == nil {
var server = this.Group.FirstServer()
if server == nil {
return errors.New("no server available")
}
if firstServer.ReverseProxy == nil {
if server.ReverseProxy == nil {
return errors.New("no ReverseProxy configured for the server")
}
// 是否已达到流量限制
if this.reachedTrafficLimit() {
// 关闭连接
tcpConn, ok := conn.(LingerConn)
if ok {
_ = tcpConn.SetLinger(0)
}
_ = conn.Close()
// TODO 使用系统防火墙drop当前端口的数据包一段时间1分钟
// 不能使用阻止IP的方法因为边缘节点只上有可能还有别的代理服务
return nil
}
// 记录域名排行
tlsConn, ok := conn.(*tls.Conn)
var recordStat = false
@@ -67,17 +82,17 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
var serverName = tlsConn.ConnectionState().ServerName
if len(serverName) > 0 {
// 统计
stats.SharedTrafficStatManager.Add(firstServer.Id, serverName, 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId())
stats.SharedTrafficStatManager.Add(server.Id, serverName, 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
recordStat = true
}
}
// 统计
if !recordStat {
stats.SharedTrafficStatManager.Add(firstServer.Id, "", 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId())
stats.SharedTrafficStatManager.Add(server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
}
originConn, err := this.connectOrigin(firstServer.Id, firstServer.ReverseProxy, conn.RemoteAddr().String())
originConn, err := this.connectOrigin(server.Id, server.ReverseProxy, conn.RemoteAddr().String())
if err != nil {
return err
}
@@ -88,17 +103,17 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
}
// PROXY Protocol
if firstServer.ReverseProxy != nil &&
firstServer.ReverseProxy.ProxyProtocol != nil &&
firstServer.ReverseProxy.ProxyProtocol.IsOn &&
(firstServer.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || firstServer.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) {
if server.ReverseProxy != nil &&
server.ReverseProxy.ProxyProtocol != nil &&
server.ReverseProxy.ProxyProtocol.IsOn &&
(server.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || server.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) {
var remoteAddr = conn.RemoteAddr()
var transportProtocol = proxyproto.TCPv4
if strings.Contains(remoteAddr.String(), "[") {
transportProtocol = proxyproto.TCPv6
}
header := proxyproto.Header{
Version: byte(firstServer.ReverseProxy.ProxyProtocol.Version),
var header = proxyproto.Header{
Version: byte(server.ReverseProxy.ProxyProtocol.Version),
Command: proxyproto.PROXY,
TransportProtocol: transportProtocol,
SourceAddr: remoteAddr,
@@ -113,7 +128,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
// 从源站读取
goman.New(func() {
originBuffer := utils.BytePool16k.Get()
var originBuffer = utils.BytePool16k.Get()
defer func() {
utils.BytePool16k.Put(originBuffer)
}()
@@ -127,8 +142,8 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
}
// 记录流量
if firstServer != nil {
stats.SharedTrafficStatManager.Add(firstServer.Id, "", int64(n), 0, 0, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId())
if server != nil {
stats.SharedTrafficStatManager.Add(server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
}
}
if err != nil {
@@ -139,11 +154,17 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
})
// 从客户端读取
clientBuffer := utils.BytePool16k.Get()
var clientBuffer = utils.BytePool16k.Get()
defer func() {
utils.BytePool16k.Put(clientBuffer)
}()
for {
// 是否已达到流量限制
if this.reachedTrafficLimit() {
closer()
return nil
}
n, err := conn.Read(clientBuffer)
if n > 0 {
_, err = originConn.Write(clientBuffer[:n])
@@ -188,3 +209,12 @@ func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfi
err = errors.New("no origin can be used")
return
}
// 检查是否已经达到流量限制
func (this *TCPListener) reachedTrafficLimit() bool {
var server = this.Group.FirstServer()
if server == nil {
return true
}
return server.TrafficLimitStatus != nil && server.TrafficLimitStatus.IsValid()
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/configs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
@@ -15,12 +16,12 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/metrics"
"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/trackers"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/andybalholm/brotli"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/lists"
@@ -115,6 +116,7 @@ func (this *Node) Start() {
this.checkDisk()
// 读取API配置
remotelogs.Println("NODE", "init config ...")
err = this.syncConfig(0)
if err != nil {
_, err := nodeconfigs.SharedNodeConfig()
@@ -368,6 +370,38 @@ func (this *Node) loop() error {
}
sharedNodeConfig.ParentNodes = parentNodes
// 修改为已同步
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
NodeTaskId: task.Id,
IsOk: true,
Error: "",
})
if err != nil {
return err
}
case "ddosProtectionChanged":
resp, err := rpcClient.NodeRPC().FindNodeDDoSProtection(nodeCtx, &pb.FindNodeDDoSProtectionRequest{})
if err != nil {
return err
}
if len(resp.DdosProtectionJSON) == 0 {
if sharedNodeConfig != nil {
sharedNodeConfig.DDOSProtection = nil
}
} else {
var ddosProtectionConfig = &ddosconfigs.ProtectionConfig{}
err = json.Unmarshal(resp.DdosProtectionJSON, ddosProtectionConfig)
if err != nil {
return errors.New("decode DDoS protection config failed: " + err.Error())
}
err = firewalls.SharedDDoSProtectionManager.Apply(ddosProtectionConfig)
if err != nil {
// 不阻塞
remotelogs.Error("NODE", "apply DDoS protection failed: "+err.Error())
}
}
// 修改为已同步
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
NodeTaskId: task.Id,
@@ -396,7 +430,7 @@ func (this *Node) syncConfig(taskVersion int64) error {
clusterErr := this.checkClusterConfig()
if clusterErr != nil {
if os.IsNotExist(clusterErr) {
return err
return errors.New("can not find config file 'configs/api.yaml'")
}
return errors.New("check cluster config failed: " + clusterErr.Error())
}
@@ -426,7 +460,7 @@ func (this *Node) syncConfig(taskVersion int64) error {
return nil
}
configJSON := configResp.NodeJSON
var configJSON = configResp.NodeJSON
if configResp.IsCompressed {
var reader = brotli.NewReader(bytes.NewReader(configJSON))
var configBuf = &bytes.Buffer{}
@@ -445,7 +479,7 @@ func (this *Node) syncConfig(taskVersion int64) error {
nodeConfigUpdatedAt = time.Now().Unix()
nodeConfig := &nodeconfigs.NodeConfig{}
var nodeConfig = &nodeconfigs.NodeConfig{}
err = json.Unmarshal(configJSON, nodeConfig)
if err != nil {
return errors.New("decode config failed: " + err.Error())
@@ -453,6 +487,15 @@ func (this *Node) syncConfig(taskVersion int64) error {
teaconst.NodeId = nodeConfig.Id
teaconst.NodeIdString = types.String(teaconst.NodeId)
// 检查时间是否一致
// 这个需要在 teaconst.NodeId 设置之后因为上报到API节点的时候需要节点ID
if configResp.Timestamp > 0 {
var timestampDelta = configResp.Timestamp - time.Now().Unix()
if timestampDelta > 60 || timestampDelta < -60 {
remotelogs.Error("NODE", "node timestamp ('"+types.String(time.Now().Unix())+"') is not same as api node ('"+types.String(configResp.Timestamp)+"'), please sync the time")
}
}
// 写入到文件中
err = nodeConfig.Save()
if err != nil {
@@ -721,7 +764,6 @@ func (this *Node) listenSock() error {
"ipConns": ipConns,
"serverConns": serverConns,
"total": sharedListenerManager.TotalActiveConnections(),
"limiter": sharedConnectionsLimiter.Len(),
},
})
case "dropIP":
@@ -780,6 +822,18 @@ func (this *Node) listenSock() error {
} else {
_ = cmd.ReplyOk()
}
case "accesslog":
err := sharedHTTPAccessLogViewer.Start()
if err != nil {
_ = cmd.Reply(&gosock.Command{
Code: "error",
Params: map[string]interface{}{
"message": "start failed: " + err.Error(),
},
})
} else {
_ = cmd.ReplyOk()
}
}
})
@@ -813,7 +867,7 @@ func (this *Node) onReload(config *nodeconfigs.NodeConfig) {
}
// WAF策略
sharedWAFManager.UpdatePolicies(config.FindAllFirewallPolicies())
waf.SharedWAFManager.UpdatePolicies(config.FindAllFirewallPolicies())
iplibrary.SharedActionManager.UpdateActions(config.FirewallActions)
// 统计指标
@@ -845,17 +899,6 @@ func (this *Node) onReload(config *nodeconfigs.NodeConfig) {
this.maxThreads = config.MaxThreads
}
// max tcp connections
if config.TCPMaxConnections <= 0 {
config.TCPMaxConnections = nodeconfigs.DefaultTCPMaxConnections
}
if config.TCPMaxConnections != sharedConnectionsLimiter.Count() {
remotelogs.Println("NODE", "[TCP]changed tcp max connections to '"+types.String(config.TCPMaxConnections)+"'")
sharedConnectionsLimiter.Close()
sharedConnectionsLimiter = ratelimit.NewCounter(config.TCPMaxConnections)
}
// timezone
var timeZone = config.TimeZone
if len(timeZone) == 0 {
@@ -878,6 +921,29 @@ func (this *Node) onReload(config *nodeconfigs.NodeConfig) {
if config.ProductConfig != nil {
teaconst.GlobalProductName = config.ProductConfig.Name
}
// DNS resolver
if config.DNSResolver != nil {
var err error
switch config.DNSResolver.Type {
case nodeconfigs.DNSResolverTypeGoNative:
err = os.Setenv("GODEBUG", "netdns=go")
case nodeconfigs.DNSResolverTypeCGO:
err = os.Setenv("GODEBUG", "netdns=cgo")
default:
// 默认使用go原生
err = os.Setenv("GODEBUG", "netdns=go")
}
if err != nil {
remotelogs.Error("NODE", "[DNS_RESOLVER]set env failed: "+err.Error())
}
} else {
// 默认使用go原生
err := os.Setenv("GODEBUG", "netdns=go")
if err != nil {
remotelogs.Error("NODE", "[DNS_RESOLVER]set env failed: "+err.Error())
}
}
}
// reload server config

View File

@@ -109,6 +109,7 @@ func (this *NodeStatusExecutor) update() {
cacheSpaceTR.End()
status.UpdatedAt = time.Now().Unix()
status.Timestamp = status.UpdatedAt
// 发送数据
jsonData, err := json.Marshal(status)

View File

@@ -45,6 +45,8 @@ func NewOriginStateManager() *OriginStateManager {
// Start 启动
func (this *OriginStateManager) Start() {
events.OnKey(events.EventReload, this, func() {
// TODO 检查源站是否有变化
this.locker.Lock()
this.stateMap = map[int64]*OriginState{}
this.locker.Unlock()
@@ -143,7 +145,7 @@ func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, reverse
state.UpdatedAt = timestamp
if origin.IsOk {
origin.IsOk = state.CountFails > 5 // 超过 N 次之后认为是异常
origin.IsOk = state.CountFails < 5 // 超过 N 次之后认为是异常
if !origin.IsOk {
if callback != nil {

View File

@@ -13,7 +13,6 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/logs"
stringutil "github.com/iwind/TeaGo/utils/string"
"os"
"os/exec"
@@ -64,7 +63,7 @@ func (this *UpgradeManager) Start() {
goman.New(func() {
err = this.restart()
if err != nil {
logs.Println("UPGRADE_MANAGER", err.Error())
remotelogs.Error("UPGRADE_MANAGER", err.Error())
}
})
}

View File

@@ -67,6 +67,10 @@ func (this *RPCClient) HTTPAccessLogRPC() pb.HTTPAccessLogServiceClient {
return pb.NewHTTPAccessLogServiceClient(this.pickConn())
}
func (this *RPCClient) HTTPCacheTaskKeyRPC() pb.HTTPCacheTaskKeyServiceClient {
return pb.NewHTTPCacheTaskKeyServiceClient(this.pickConn())
}
func (this *RPCClient) APINodeRPC() pb.APINodeServiceClient {
return pb.NewAPINodeServiceClient(this.pickConn())
}
@@ -115,18 +119,6 @@ func (this *RPCClient) ServerRPC() pb.ServerServiceClient {
return pb.NewServerServiceClient(this.pickConn())
}
func (this *RPCClient) ServerRPCList() []pb.ServerServiceClient {
this.locker.Lock()
defer this.locker.Unlock()
var clients = []pb.ServerServiceClient{}
for _, conn := range this.conns {
clients = append(clients, pb.NewServerServiceClient(conn))
}
return clients
}
func (this *RPCClient) ServerDailyStatRPC() pb.ServerDailyStatServiceClient {
return pb.NewServerDailyStatServiceClient(this.pickConn())
}

View File

@@ -80,7 +80,7 @@ func (this *TrafficStatManager) Start(configFunc func() *nodeconfigs.NodeConfig)
remotelogs.Println("TRAFFIC_STAT_MANAGER", "quit")
ticker.Stop()
})
remotelogs.Println("TRAFFIC_STA_MANAGER", "start ...")
remotelogs.Println("TRAFFIC_STAT_MANAGER", "start ...")
for range ticker.C {
err := this.Upload()
if err != nil {

View File

@@ -91,7 +91,7 @@ func (this *Cache) Write(key string, value interface{}, expiredAt int64) (ok boo
})
}
func (this *Cache) IncreaseInt64(key string, delta int64, expiredAt int64) int64 {
func (this *Cache) IncreaseInt64(key string, delta int64, expiredAt int64, extend bool) int64 {
if this.isDestroyed {
return 0
}
@@ -107,7 +107,7 @@ func (this *Cache) IncreaseInt64(key string, delta int64, expiredAt int64) int64
}
uint64Key := HashKey([]byte(key))
pieceIndex := uint64Key % this.countPieces
return this.pieces[pieceIndex].IncreaseInt64(uint64Key, delta, expiredAt)
return this.pieces[pieceIndex].IncreaseInt64(uint64Key, delta, expiredAt, extend)
}
func (this *Cache) Read(key string) (item *Item) {

View File

@@ -65,14 +65,14 @@ func TestCache_IncreaseInt64(t *testing.T) {
var unixTime = time.Now().Unix()
{
cache.IncreaseInt64("a", 1, unixTime+3600)
cache.IncreaseInt64("a", 1, unixTime+3600, false)
var item = cache.Read("a")
t.Log(item)
a.IsTrue(item.Value == int64(1))
a.IsTrue(item.expiredAt == unixTime+3600)
}
{
cache.IncreaseInt64("a", 1, unixTime+3600+1)
cache.IncreaseInt64("a", 1, unixTime+3600+1, true)
var item = cache.Read("a")
t.Log(item)
a.IsTrue(item.Value == int64(2))
@@ -83,7 +83,7 @@ func TestCache_IncreaseInt64(t *testing.T) {
t.Log(cache.Read("b"))
}
{
cache.IncreaseInt64("b", 1, time.Now().Unix()+3600+3)
cache.IncreaseInt64("b", 1, time.Now().Unix()+3600+3, false)
t.Log(cache.Read("b"))
}
}

View File

@@ -39,13 +39,15 @@ func (this *Piece) Add(key uint64, item *Item) (ok bool) {
return true
}
func (this *Piece) IncreaseInt64(key uint64, delta int64, expiredAt int64) (result int64) {
func (this *Piece) IncreaseInt64(key uint64, delta int64, expiredAt int64, extend bool) (result int64) {
this.locker.Lock()
item, ok := this.m[key]
if ok && item.expiredAt > time.Now().Unix() {
result = types.Int64(item.Value) + delta
item.Value = result
item.expiredAt = expiredAt
if extend {
item.expiredAt = expiredAt
}
this.expiresList.Add(key, expiredAt)
} else {
if len(this.m) < this.maxItems {

View File

@@ -5,9 +5,12 @@ import (
"github.com/cespare/xxhash"
"math"
"net"
"regexp"
"strings"
)
var ipv4Reg = regexp.MustCompile(`\d+\.`)
// IP2Long 将IP转换为整型
// 注意IPv6没有顺序
func IP2Long(ip string) uint64 {
@@ -54,3 +57,24 @@ func IsLocalIP(ipString string) bool {
return false
}
// IsIPv4 是否为IPv4
func IsIPv4(ip string) bool {
var data = net.ParseIP(ip)
if data == nil {
return false
}
if strings.Contains(ip, ":") {
return false
}
return data.To4() != nil
}
// IsIPv6 是否为IPv6
func IsIPv6(ip string) bool {
var data = net.ParseIP(ip)
if data == nil {
return false
}
return !IsIPv4(ip)
}

View File

@@ -26,3 +26,26 @@ func TestIsLocalIP(t *testing.T) {
a.IsFalse(IsLocalIP("::1:2:3"))
a.IsFalse(IsLocalIP("8.8.8.8"))
}
func TestIsIPv4(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(IsIPv4("192.168.1.1"))
a.IsTrue(IsIPv4("0.0.0.0"))
a.IsFalse(IsIPv4("192.168.1.256"))
a.IsFalse(IsIPv4("192.168.1"))
a.IsFalse(IsIPv4("::1"))
a.IsFalse(IsIPv4("2001:0db8:85a3:0000:0000:8a2e:0370:7334"))
a.IsFalse(IsIPv4("::ffff:192.168.0.1"))
}
func TestIsIPv6(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsFalse(IsIPv6("192.168.1.1"))
a.IsFloat32(IsIPv6("0.0.0.0"))
a.IsFalse(IsIPv6("192.168.1.256"))
a.IsFalse(IsIPv6("192.168.1"))
a.IsTrue(IsIPv6("::1"))
a.IsTrue(IsIPv6("2001:0db8:85a3:0000:0000:8a2e:0370:7334"))
a.IsTrue(IsIPv4("::ffff:192.168.0.1"))
a.IsTrue(IsIPv6("::ffff:192.168.0.1"))
}

View File

@@ -2,7 +2,9 @@
package readers
import "io"
import (
"io"
)
type TeeReaderCloser struct {
r io.Reader

View File

@@ -1,6 +1,7 @@
package utils
import (
"sort"
"strings"
"unsafe"
)
@@ -36,6 +37,22 @@ func FormatAddressList(addrList []string) []string {
return result
}
// ToValidUTF8string 去除字符串中的非UTF-8字符
func ToValidUTF8string(v string) string {
return strings.ToValidUTF8(v, "")
}
// ContainsSameStrings 检查两个字符串slice内容是否一致
func ContainsSameStrings(s1 []string, s2 []string) bool {
if len(s1) != len(s2) {
return false
}
sort.Strings(s1)
sort.Strings(s2)
for index, v1 := range s1 {
if v1 != s2[index] {
return false
}
}
return true
}

View File

@@ -1,56 +1,67 @@
package utils
package utils_test
import (
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/assert"
"strings"
"testing"
)
func TestBytesToString(t *testing.T) {
t.Log(UnsafeBytesToString([]byte("Hello,World")))
t.Log(utils.UnsafeBytesToString([]byte("Hello,World")))
}
func TestStringToBytes(t *testing.T) {
t.Log(string(UnsafeStringToBytes("Hello,World")))
t.Log(string(utils.UnsafeStringToBytes("Hello,World")))
}
func BenchmarkBytesToString(b *testing.B) {
data := []byte("Hello,World")
var data = []byte("Hello,World")
for i := 0; i < b.N; i++ {
_ = UnsafeBytesToString(data)
_ = utils.UnsafeBytesToString(data)
}
}
func BenchmarkBytesToString2(b *testing.B) {
data := []byte("Hello,World")
var data = []byte("Hello,World")
for i := 0; i < b.N; i++ {
_ = string(data)
}
}
func BenchmarkStringToBytes(b *testing.B) {
s := strings.Repeat("Hello,World", 1024)
var s = strings.Repeat("Hello,World", 1024)
for i := 0; i < b.N; i++ {
_ = UnsafeStringToBytes(s)
_ = utils.UnsafeStringToBytes(s)
}
}
func BenchmarkStringToBytes2(b *testing.B) {
s := strings.Repeat("Hello,World", 1024)
var s = strings.Repeat("Hello,World", 1024)
for i := 0; i < b.N; i++ {
_ = []byte(s)
}
}
func TestFormatAddress(t *testing.T) {
t.Log(FormatAddress("127.0.0.1:1234"))
t.Log(FormatAddress("127.0.0.1 : 1234"))
t.Log(FormatAddress("127.0.0.11234"))
t.Log(utils.FormatAddress("127.0.0.1:1234"))
t.Log(utils.FormatAddress("127.0.0.1 : 1234"))
t.Log(utils.FormatAddress("127.0.0.11234"))
}
func TestFormatAddressList(t *testing.T) {
t.Log(FormatAddressList([]string{
t.Log(utils.FormatAddressList([]string{
"127.0.0.1:1234",
"127.0.0.1 : 1234",
"127.0.0.11234",
}))
}
func TestContainsSameStrings(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsFalse(utils.ContainsSameStrings([]string{"a"}, []string{"b"}))
a.IsFalse(utils.ContainsSameStrings([]string{"a", "b"}, []string{"b"}))
a.IsFalse(utils.ContainsSameStrings([]string{"a", "b"}, []string{"a", "b", "c"}))
a.IsTrue(utils.ContainsSameStrings([]string{"a", "b"}, []string{"a", "b"}))
a.IsTrue(utils.ContainsSameStrings([]string{"a", "b"}, []string{"b", "a"}))
}

View File

@@ -6,6 +6,7 @@ import (
)
type AllowAction struct {
BaseAction
}
func (this *AllowAction) Init(waf *WAF) error {

View File

@@ -7,6 +7,17 @@ import (
)
type BaseAction struct {
currentActionId int64
}
// ActionId 读取ActionId
func (this *BaseAction) ActionId() int64 {
return this.currentActionId
}
// SetActionId 设置Id
func (this *BaseAction) SetActionId(actionId int64) {
this.currentActionId = actionId
}
// CloseConn 关闭连接

View File

@@ -20,6 +20,8 @@ var urlPrefixReg = regexp.MustCompile("^(?i)(http|https)://")
var httpClient = utils.SharedHttpClient(5 * time.Second)
type BlockAction struct {
BaseAction
StatusCode int `yaml:"statusCode" json:"statusCode"`
Body string `yaml:"body" json:"body"` // supports HTML
URL string `yaml:"url" json:"url"`

View File

@@ -18,16 +18,71 @@ const (
)
type CaptchaAction struct {
Life int32 `yaml:"life" json:"life"`
MaxFails int `yaml:"maxFails" json:"maxFails"` // 最大失败次数
FailBlockTimeout int `yaml:"failBlockTimeout" json:"failBlockTimeout"` // 失败拦截时间
BaseAction
Language string `yaml:"language" json:"language"` // 语言zh-CN, en-US ...
Life int32 `yaml:"life" json:"life"`
MaxFails int `yaml:"maxFails" json:"maxFails"` // 最大失败次数
FailBlockTimeout int `yaml:"failBlockTimeout" json:"failBlockTimeout"` // 失败拦截时间
FailBlockScopeAll bool `yaml:"failBlockScopeAll" json:"failBlockScopeAll"` // 是否全局有效
CountLetters int8 `yaml:"countLetters" json:"countLetters"`
UIIsOn bool `yaml:"uiIsOn" json:"uiIsOn"` // 是否使用自定义UI
UITitle string `yaml:"uiTitle" json:"uiTitle"` // 消息标题
UIPrompt string `yaml:"uiPrompt" json:"uiPrompt"` // 消息提示
UIButtonTitle string `yaml:"uiButtonTitle" json:"uiButtonTitle"` // 按钮标题
UIShowRequestId bool `yaml:"uiShowRequestId" json:"uiShowRequestId"` // 是否显示请求ID
UICss string `yaml:"uiCss" json:"uiCss"` // CSS样式
UIFooter string `yaml:"uiFooter" json:"uiFooter"` // 页脚
UIBody string `yaml:"uiBody" json:"uiBody"` // 内容轮廓
Lang string `yaml:"lang" json:"lang"` // 语言zh-CN, en-US ...
AddToWhiteList bool `yaml:"addToWhiteList" json:"addToWhiteList"` // 是否加入到白名单
Scope string `yaml:"scope" json:"scope"`
}
func (this *CaptchaAction) Init(waf *WAF) error {
if waf.DefaultCaptchaAction != nil {
if this.Life <= 0 {
this.Life = waf.DefaultCaptchaAction.Life
}
if this.MaxFails <= 0 {
this.MaxFails = waf.DefaultCaptchaAction.MaxFails
}
if this.FailBlockTimeout <= 0 {
this.FailBlockTimeout = waf.DefaultCaptchaAction.FailBlockTimeout
}
this.FailBlockScopeAll = waf.DefaultCaptchaAction.FailBlockScopeAll
if this.CountLetters <= 0 {
this.CountLetters = waf.DefaultCaptchaAction.CountLetters
}
this.UIIsOn = waf.DefaultCaptchaAction.UIIsOn
if len(this.UITitle) == 0 {
this.UITitle = waf.DefaultCaptchaAction.UITitle
}
if len(this.UIPrompt) == 0 {
this.UIPrompt = waf.DefaultCaptchaAction.UIPrompt
}
if len(this.UIButtonTitle) == 0 {
this.UIButtonTitle = waf.DefaultCaptchaAction.UIButtonTitle
}
this.UIShowRequestId = waf.DefaultCaptchaAction.UIShowRequestId
if len(this.UICss) == 0 {
this.UICss = waf.DefaultCaptchaAction.UICss
}
if len(this.UIFooter) == 0 {
this.UIFooter = waf.DefaultCaptchaAction.UIFooter
}
if len(this.UIBody) == 0 {
this.UIBody = waf.DefaultCaptchaAction.UIBody
}
if len(this.Lang) == 0 {
this.Lang = waf.DefaultCaptchaAction.Lang
}
}
return nil
}
@@ -43,17 +98,17 @@ func (this *CaptchaAction) WillChange() bool {
return true
}
func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) (allow bool) {
// 是否在白名单中
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, req.WAFServerId(), req.WAFRemoteIP()) {
return true
}
refURL := request.WAFRaw().URL.String()
var refURL = req.WAFRaw().URL.String()
// 覆盖配置
if strings.HasPrefix(refURL, CaptchaPath) {
info := request.WAFRaw().URL.Query().Get("info")
info := req.WAFRaw().URL.Query().Get("info")
if len(info) > 0 {
m, err := utils.SimpleDecryptMap(info)
if err == nil && m != nil {
@@ -63,14 +118,12 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
}
var captchaConfig = maps.Map{
"action": this,
"timestamp": time.Now().Unix(),
"maxFails": this.MaxFails,
"failBlockTimeout": this.FailBlockTimeout,
"url": refURL,
"policyId": waf.Id,
"groupId": group.Id,
"setId": set.Id,
"actionId": this.ActionId(),
"timestamp": time.Now().Unix(),
"url": refURL,
"policyId": waf.Id,
"groupId": group.Id,
"setId": set.Id,
}
info, err := utils.SimpleEncryptMap(captchaConfig)
if err != nil {
@@ -78,7 +131,10 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
return true
}
http.Redirect(writer, request.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect)
// 占用一次失败次数
CaptchaIncreaseFails(req, this, waf.Id, group.Id, set.Id, CaptchaPageCodeInit)
http.Redirect(writer, req.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect)
return false
}

View File

@@ -8,6 +8,8 @@ import (
)
type GoGroupAction struct {
BaseAction
GroupId string `yaml:"groupId" json:"groupId"`
}

View File

@@ -8,6 +8,8 @@ import (
)
type GoSetAction struct {
BaseAction
GroupId string `yaml:"groupId" json:"groupId"`
SetId string `yaml:"setId" json:"setId"`
}

View File

@@ -11,6 +11,12 @@ type ActionInterface interface {
// Init 初始化
Init(waf *WAF) error
// ActionId 读取ActionId
ActionId() int64
// SetActionId 设置ID
SetActionId(id int64)
// Code 代号
Code() string
@@ -20,6 +26,6 @@ type ActionInterface interface {
// WillChange determine if the action will change the request
WillChange() bool
// Perform perform the action
// Perform the action
Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool)
}

View File

@@ -6,6 +6,7 @@ import (
)
type LogAction struct {
BaseAction
}
func (this *LogAction) Init(waf *WAF) error {

View File

@@ -50,6 +50,7 @@ func init() {
}
type NotifyAction struct {
BaseAction
}
func (this *NotifyAction) Init(waf *WAF) error {
@@ -69,7 +70,7 @@ func (this *NotifyAction) WillChange() bool {
return false
}
// Perform perform the action
// Perform the action
func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
select {
case notifyChan <- &notifyTask{

View File

@@ -6,6 +6,8 @@ import (
)
type TagAction struct {
BaseAction
Tags []string `yaml:"tags" json:"tags"`
}

View File

@@ -5,15 +5,19 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/maps"
"reflect"
"sync/atomic"
)
var seedActionId int64 = 1
func FindActionInstance(action ActionString, options maps.Map) ActionInterface {
for _, def := range AllActions {
if def.Code == action {
if def.Type != nil {
// create new instance
ptrValue := reflect.New(def.Type)
instance := ptrValue.Interface().(ActionInterface)
var ptrValue = reflect.New(def.Type)
var instance = ptrValue.Interface().(ActionInterface)
instance.SetActionId(atomic.AddInt64(&seedActionId, 1))
if len(options) > 0 {
optionsJSON, err := json.Marshal(options)

View File

@@ -0,0 +1,54 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package waf
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/types"
"time"
)
type CaptchaPageCode = string
const (
CaptchaPageCodeInit CaptchaPageCode = "init"
CaptchaPageCodeShow CaptchaPageCode = "show"
CaptchaPageCodeSubmit CaptchaPageCode = "submit"
)
// CaptchaIncreaseFails 增加Captcha失败次数以便后续操作
func CaptchaIncreaseFails(req requests.Request, actionConfig *CaptchaAction, policyId int64, groupId int64, setId int64, pageCode CaptchaPageCode) (goNext bool) {
var maxFails = actionConfig.MaxFails
var failBlockTimeout = actionConfig.FailBlockTimeout
if maxFails > 0 && failBlockTimeout > 0 {
if maxFails <= 3 {
maxFails = 3 // 不能小于3防止意外刷新出现
}
var countFails = ttlcache.SharedCache.IncreaseInt64(CaptchaCacheKey(req, pageCode), 1, time.Now().Unix()+300, true)
if int(countFails) >= maxFails {
var useLocalFirewall = false
if actionConfig.FailBlockScopeAll {
useLocalFirewall = true
}
SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, useLocalFirewall, groupId, setId, "CAPTCHA验证连续失败")
return false
}
}
return true
}
// CaptchaDeleteCacheKey 清除计数
func CaptchaDeleteCacheKey(req requests.Request) {
ttlcache.SharedCache.Delete(CaptchaCacheKey(req, CaptchaPageCodeInit))
ttlcache.SharedCache.Delete(CaptchaCacheKey(req, CaptchaPageCodeShow))
ttlcache.SharedCache.Delete(CaptchaCacheKey(req, CaptchaPageCodeSubmit))
}
// CaptchaCacheKey 获取Captcha缓存Key
func CaptchaCacheKey(req requests.Request, pageCode CaptchaPageCode) string {
return "CAPTCHA:FAILS:" + pageCode + ":" + req.WAFRemoteIP() + ":" + types.String(req.WAFServerId())
}

View File

@@ -3,10 +3,7 @@ package waf
import (
"bytes"
"encoding/base64"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/jsonutils"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/dchest/captcha"
"github.com/iwind/TeaGo/logs"
@@ -26,8 +23,8 @@ func NewCaptchaValidator() *CaptchaValidator {
return &CaptchaValidator{}
}
func (this *CaptchaValidator) Run(request requests.Request, writer http.ResponseWriter) {
var info = request.WAFRaw().URL.Query().Get("info")
func (this *CaptchaValidator) Run(req requests.Request, writer http.ResponseWriter) {
var info = req.WAFRaw().URL.Query().Get("info")
if len(info) == 0 {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid request"))
@@ -39,42 +36,60 @@ func (this *CaptchaValidator) Run(request requests.Request, writer http.Response
return
}
timestamp := m.GetInt64("timestamp")
var timestamp = m.GetInt64("timestamp")
if timestamp < time.Now().Unix()-600 { // 10分钟之后信息过期
http.Redirect(writer, request.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect)
return
}
var actionConfig = &CaptchaAction{}
err = jsonutils.MapToObject(m.GetMap("action"), actionConfig)
if err != nil {
http.Redirect(writer, request.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect)
http.Redirect(writer, req.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect)
return
}
var actionId = m.GetInt64("actionId")
var setId = m.GetInt64("setId")
var originURL = m.GetString("url")
var policyId = m.GetInt64("policyId")
var groupId = m.GetInt64("groupId")
if request.WAFRaw().Method == http.MethodPost && len(request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")) > 0 {
this.validate(actionConfig, m.GetInt("maxFails"), m.GetInt("failBlockTimeout"), m.GetInt64("policyId"), m.GetInt64("groupId"), setId, originURL, request, writer)
var waf = SharedWAFManager.FindWAF(policyId)
if waf == nil {
http.Redirect(writer, req.WAFRaw(), originURL, http.StatusTemporaryRedirect)
return
}
var actionConfig = waf.FindAction(actionId)
if actionConfig == nil {
http.Redirect(writer, req.WAFRaw(), originURL, http.StatusTemporaryRedirect)
return
}
captchaActionConfig, ok := actionConfig.(*CaptchaAction)
if !ok {
http.Redirect(writer, req.WAFRaw(), originURL, http.StatusTemporaryRedirect)
return
}
if req.WAFRaw().Method == http.MethodPost && len(req.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")) > 0 {
this.validate(captchaActionConfig, policyId, groupId, setId, originURL, req, writer)
} else {
this.show(actionConfig, request, writer)
// 增加计数
CaptchaIncreaseFails(req, captchaActionConfig, policyId, groupId, setId, CaptchaPageCodeShow)
this.show(captchaActionConfig, req, writer)
}
}
func (this *CaptchaValidator) show(actionConfig *CaptchaAction, request requests.Request, writer http.ResponseWriter) {
func (this *CaptchaValidator) show(actionConfig *CaptchaAction, req requests.Request, writer http.ResponseWriter) {
// show captcha
captchaId := captcha.NewLen(6)
buf := bytes.NewBuffer([]byte{})
var countLetters = 6
if actionConfig.CountLetters > 0 && actionConfig.CountLetters <= 10 {
countLetters = int(actionConfig.CountLetters)
}
var captchaId = captcha.NewLen(countLetters)
var buf = bytes.NewBuffer([]byte{})
err := captcha.WriteImage(buf, captchaId, 200, 100)
if err != nil {
logs.Error(err)
return
}
var lang = actionConfig.Language
var lang = actionConfig.Lang
if len(lang) == 0 {
acceptLanguage := request.WAFRaw().Header.Get("Accept-Language")
var acceptLanguage = req.WAFRaw().Header.Get("Accept-Language")
if len(acceptLanguage) > 0 {
langIndex := strings.Index(acceptLanguage, ",")
if langIndex > 0 {
@@ -109,12 +124,67 @@ func (this *CaptchaValidator) show(actionConfig *CaptchaAction, request requests
msgRequestId = "Request ID"
}
var msgCss = ""
var requestIdBox = `<address>` + msgRequestId + `: ` + req.Format("${requestId}") + `</address>`
var msgFooter = ""
// 默认设置
if actionConfig.UIIsOn {
if len(actionConfig.UIPrompt) > 0 {
msgPrompt = actionConfig.UIPrompt
}
if len(actionConfig.UIButtonTitle) > 0 {
msgButtonTitle = actionConfig.UIButtonTitle
}
if len(actionConfig.UITitle) > 0 {
msgTitle = actionConfig.UITitle
}
if len(actionConfig.UICss) > 0 {
msgCss = actionConfig.UICss
}
if !actionConfig.UIShowRequestId {
requestIdBox = ""
}
if len(actionConfig.UIFooter) > 0 {
msgFooter = actionConfig.UIFooter
}
}
var body = `<form method="POST">
<input type="hidden" name="GOEDGE_WAF_CAPTCHA_ID" value="` + captchaId + `"/>
<div class="ui-image">
<img src="data:image/png;base64, ` + base64.StdEncoding.EncodeToString(buf.Bytes()) + `"/>` + `
</div>
<div class="ui-input">
<p>` + msgPrompt + `</p>
<input type="text" name="GOEDGE_WAF_CAPTCHA_CODE" id="GOEDGE_WAF_CAPTCHA_CODE" maxlength="6" autocomplete="off" z-index="1" class="input"/>
</div>
<div class="ui-button">
<button type="submit" style="line-height:24px;margin-top:10px">` + msgButtonTitle + `</button>
</div>
</form>
` + requestIdBox + `
` + msgFooter + ``
// Body
if actionConfig.UIIsOn {
if len(actionConfig.UIBody) > 0 {
var index = strings.Index(actionConfig.UIBody, "${body}")
if index < 0 {
body = actionConfig.UIBody + body
} else {
body = actionConfig.UIBody[:index] + body + actionConfig.UIBody[index+7:] // 7是"${body}"的长度
}
}
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
_, _ = writer.Write([]byte(`<!DOCTYPE html>
<html>
<head>
<title>` + msgTitle + `</title>
<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=0">
<meta charset="UTF-8"/>
<script type="text/javascript">
if (window.addEventListener != null) {
window.addEventListener("load", function () {
@@ -126,32 +196,22 @@ func (this *CaptchaValidator) show(actionConfig *CaptchaAction, request requests
form { width: 20em; margin: 0 auto; text-align: center; }
.input { font-size:16px;line-height:24px; letter-spacing: 15px; padding-left: 10px; width: 140px; }
address { margin-top: 1em; padding-top: 0.5em; border-top: 1px #ccc solid; text-align: center; }
` + msgCss + `
</style>
</head>
<body>
<form method="POST">
<input type="hidden" name="GOEDGE_WAF_CAPTCHA_ID" value="` + captchaId + `"/>
<img src="data:image/png;base64, ` + base64.StdEncoding.EncodeToString(buf.Bytes()) + `"/>` + `
<div>
<p>` + msgPrompt + `</p>
<input type="text" name="GOEDGE_WAF_CAPTCHA_CODE" id="GOEDGE_WAF_CAPTCHA_CODE" maxlength="6" autocomplete="off" z-index="1" class="input"/>
</div>
<div>
<button type="submit" style="line-height:24px;margin-top:10px">` + msgButtonTitle + `</button>
</div>
</form>
<address>` + msgRequestId + `: ` + request.Format("${requestId}") + `</address>
<body>` + body + `
</body>
</html>`))
}
func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, maxFails int, failBlockTimeout int, policyId int64, groupId int64, setId int64, originURL string, request requests.Request, writer http.ResponseWriter) (allow bool) {
captchaId := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")
func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, policyId int64, groupId int64, setId int64, originURL string, req requests.Request, writer http.ResponseWriter) (allow bool) {
var captchaId = req.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")
if len(captchaId) > 0 {
captchaCode := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_CODE")
var captchaCode = req.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_CODE")
if captcha.VerifyString(captchaId, captchaCode) {
// 除计数
ttlcache.SharedCache.Delete("CAPTCHA:FAILS:" + request.WAFRemoteIP())
// 除计数
CaptchaDeleteCacheKey(req)
var life = CaptchaSeconds
if actionConfig.Life > 0 {
@@ -159,22 +219,18 @@ func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, maxFails int
}
// 加入到白名单
SharedIPWhiteList.RecordIP("set:"+strconv.FormatInt(setId, 10), actionConfig.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(life), policyId, false, groupId, setId, "")
SharedIPWhiteList.RecordIP("set:"+strconv.FormatInt(setId, 10), actionConfig.Scope, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(life), policyId, false, groupId, setId, "")
http.Redirect(writer, request.WAFRaw(), originURL, http.StatusSeeOther)
http.Redirect(writer, req.WAFRaw(), originURL, http.StatusSeeOther)
return false
} else {
// 增加计数
if maxFails > 0 && failBlockTimeout > 0 {
var countFails = ttlcache.SharedCache.IncreaseInt64("CAPTCHA:FAILS:"+request.WAFRemoteIP(), 1, time.Now().Unix()+300)
if int(countFails) >= maxFails {
SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, false, groupId, setId, "CAPTCHA验证连续失败")
return false
}
if !CaptchaIncreaseFails(req, actionConfig, policyId, groupId, setId, CaptchaPageCodeSubmit) {
return false
}
http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusSeeOther)
http.Redirect(writer, req.WAFRaw(), req.WAFRaw().URL.String(), http.StatusSeeOther)
}
}

View File

@@ -114,7 +114,7 @@ func (this *CCCheckpoint) RequestValue(req requests.Request, param string, optio
if len(key) == 0 {
key = req.WAFRemoteIP()
}
value = this.cache.IncreaseInt64(key, int64(1), time.Now().Unix()+period)
value = this.cache.IncreaseInt64(key, int64(1), time.Now().Unix()+period, false)
}
return

View File

@@ -38,7 +38,7 @@ func (this *CC2Checkpoint) RequestValue(req requests.Request, param string, opti
threshold = 1000
}
value = ccCache.IncreaseInt64("WAF-CC-"+strings.Join(keyValues, "@"), 1, time.Now().Unix()+period)
value = ccCache.IncreaseInt64("WAF-CC-"+strings.Join(keyValues, "@"), 1, time.Now().Unix()+period, false)
return
}

View File

@@ -0,0 +1,21 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/maps"
)
type RequestURLCheckpoint struct {
Checkpoint
}
func (this *RequestURLCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
return req.Format("${requestURL}"), nil, nil
}
func (this *RequestURLCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -40,17 +40,24 @@ var AllCheckpoints = []*CheckpointDefinition{
{
Name: "请求URI",
Prefix: "requestURI",
Description: "包含URL参数的请求URI比如/hello/world?lang=go",
Description: "包含URL参数的请求URI类似于 /hello/world?lang=go",
HasParams: false,
Instance: new(RequestURICheckpoint),
},
{
Name: "请求路径",
Prefix: "requestPath",
Description: "不包含URL参数的请求路径比如/hello/world",
Description: "不包含URL参数的请求路径类似于 /hello/world",
HasParams: false,
Instance: new(RequestPathCheckpoint),
},
{
Name: "请求URL",
Prefix: "requestURL",
Description: "完整的请求URL包含协议、域名、请求路径、参数等类似于 https://example.com/hello?name=lily",
HasParams: false,
Instance: new(RequestURLCheckpoint),
},
{
Name: "请求内容长度",
Prefix: "requestLength",

View File

@@ -130,8 +130,8 @@ func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope,
}
this.locker.RLock()
defer this.locker.RUnlock()
_, ok := this.ipMap[ip]
this.locker.RUnlock()
return ok
}

View File

@@ -66,7 +66,7 @@ func (this *RuleSet) Init(waf *WAF) error {
// action instances
this.actionInstances = []ActionInterface{}
for _, action := range this.Actions {
instance := FindActionInstance(action.Code, action.Options)
var instance = FindActionInstance(action.Code, action.Options)
if instance == nil {
remotelogs.Error("WAF_RULE_SET", "can not find instance for action '"+action.Code+"'")
continue
@@ -79,6 +79,7 @@ func (this *RuleSet) Init(waf *WAF) error {
}
this.actionInstances = append(this.actionInstances, instance)
waf.AddAction(instance)
}
// sort actions

View File

@@ -26,12 +26,14 @@ type WAF struct {
UseLocalFirewall bool `yaml:"useLocalFirewall" json:"useLocalFirewall"`
SYNFlood *firewallconfigs.SYNFloodConfig `yaml:"synFlood" json:"synFlood"`
DefaultBlockAction *BlockAction
DefaultBlockAction *BlockAction
DefaultCaptchaAction *CaptchaAction
hasInboundRules bool
hasOutboundRules bool
checkpointsMap map[string]checkpoints.CheckpointInterface // prefix => checkpoint
actionMap map[int64]ActionInterface // actionId => ActionInterface
}
func NewWAF() *WAF {
@@ -74,6 +76,9 @@ func (this *WAF) Init() (resultErrors []error) {
this.checkpointsMap[def.Prefix] = instance
}
// action map
this.actionMap = map[int64]ActionInterface{}
// rules
this.hasInboundRules = len(this.Inbound) > 0
this.hasOutboundRules = len(this.Outbound) > 0
@@ -324,8 +329,16 @@ func (this *WAF) ContainsGroupCode(code string) bool {
return false
}
func (this *WAF) AddAction(action ActionInterface) {
this.actionMap[action.ActionId()] = action
}
func (this *WAF) FindAction(actionId int64) ActionInterface {
return this.actionMap[actionId]
}
func (this *WAF) Copy() *WAF {
waf := &WAF{
var waf = &WAF{
Id: this.Id,
IsOn: this.IsOn,
Name: this.Name,

Some files were not shown because too many files have changed in this diff Show More