Compare commits
45 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
77bb1cf14e | ||
|
|
274284dbe1 | ||
|
|
cd6d7221e8 | ||
|
|
b1d0c8852e | ||
|
|
4c4033bb56 | ||
|
|
6d642b75f6 | ||
|
|
eb47e3a08c | ||
|
|
d82d16e28d | ||
|
|
b2fc785543 | ||
|
|
189e3342ce | ||
|
|
885defbf31 | ||
|
|
74f1bf330d | ||
|
|
ad843d9d10 | ||
|
|
13e718742d | ||
|
|
771eff8fb1 | ||
|
|
20d7e0b1bf | ||
|
|
e6c7bbec06 | ||
|
|
be61ef89fe | ||
|
|
3d7d8f1e63 | ||
|
|
a4fb465a19 | ||
|
|
96c725c13d | ||
|
|
7635def2fa | ||
|
|
b704a73338 | ||
|
|
123b5f5969 | ||
|
|
eea2037444 | ||
|
|
4e6d2fa5ea | ||
|
|
14bb131e8d | ||
|
|
31814bb54c | ||
|
|
49b8fd6e97 | ||
|
|
a9d31a2e35 | ||
|
|
298cef7f05 | ||
|
|
9bdd9a433c | ||
|
|
45620dcdb7 | ||
|
|
84a5d38b0b | ||
|
|
e812b3fcf6 | ||
|
|
1bd16fa1d3 | ||
|
|
f3ea4957be | ||
|
|
04da107c94 | ||
|
|
e77de69a15 | ||
|
|
e88eda56f5 | ||
|
|
d6ceccc52e | ||
|
|
cd948ac68c | ||
|
|
9eac8afa3d | ||
|
|
fb3610966a | ||
|
|
a6673449db |
29
LICENSE
Normal file
29
LICENSE
Normal file
@@ -0,0 +1,29 @@
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2020, LiuXiangChao
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
@@ -22,7 +22,7 @@ func main() {
|
||||
app := apps.NewAppCmd().
|
||||
Version(teaconst.Version).
|
||||
Product(teaconst.ProductName).
|
||||
Usage(teaconst.ProcessName + " [-v|start|stop|restart|status|quit|test|reload|service|daemon|pprof]").
|
||||
Usage(teaconst.ProcessName + " [-v|start|stop|restart|status|quit|test|reload|service|daemon|pprof|accesslog]").
|
||||
Usage(teaconst.ProcessName + " [trackers|goman|conns|gc]").
|
||||
Usage(teaconst.ProcessName + " [ip.drop|ip.reject|ip.remove] IP")
|
||||
|
||||
@@ -258,6 +258,53 @@ func main() {
|
||||
}
|
||||
}
|
||||
})
|
||||
app.On("accesslog", func() {
|
||||
// local sock
|
||||
var tmpDir = os.TempDir()
|
||||
var sockFile = tmpDir + "/" + teaconst.AccessLogSockName
|
||||
_, err := os.Stat(sockFile)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
fmt.Println("[ERROR]" + err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var processSock = gosock.NewTmpSock(teaconst.ProcessName)
|
||||
reply, err := processSock.Send(&gosock.Command{
|
||||
Code: "accesslog",
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Println("[ERROR]" + err.Error())
|
||||
return
|
||||
}
|
||||
if reply.Code == "error" {
|
||||
var errString = maps.NewMap(reply.Params).GetString("error")
|
||||
if len(errString) > 0 {
|
||||
fmt.Println("[ERROR]" + errString)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := net.Dial("unix", sockFile)
|
||||
if err != nil {
|
||||
fmt.Println("[ERROR]start reading access log failed: " + err.Error())
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
var buf = make([]byte, 1024)
|
||||
for {
|
||||
n, err := conn.Read(buf)
|
||||
if n > 0 {
|
||||
fmt.Print(string(buf[:n]))
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
app.Run(func() {
|
||||
node := nodes.NewNode()
|
||||
node.Start()
|
||||
|
||||
5
go.mod
5
go.mod
@@ -2,7 +2,9 @@ module github.com/TeaOSLab/EdgeNode
|
||||
|
||||
go 1.15
|
||||
|
||||
replace github.com/TeaOSLab/EdgeCommon => ../EdgeCommon
|
||||
replace (
|
||||
github.com/TeaOSLab/EdgeCommon => ../EdgeCommon
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000
|
||||
@@ -20,6 +22,7 @@ require (
|
||||
github.com/iwind/gosock v0.0.0-20211103081026-ee4652210ca4
|
||||
github.com/iwind/gowebp v0.0.0-20211029040624-7331ecc78ed8
|
||||
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
|
||||
github.com/klauspost/compress v1.15.2 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.9
|
||||
github.com/miekg/dns v1.1.43
|
||||
|
||||
2
go.sum
2
go.sum
@@ -126,6 +126,8 @@ github.com/jsimonetti/rtnetlink v0.0.0-20211022192332-93da33804786/go.mod h1:v4h
|
||||
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e h1:LvL4XsI70QxOGHed6yhQtAU34Kx3Qq2wwBzGFKY8zKk=
|
||||
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
|
||||
github.com/klauspost/compress v1.15.2 h1:3WH+AG7s2+T8o3nrM/8u2rdqUEcQhmga7smjrT41nAw=
|
||||
github.com/klauspost/compress v1.15.2/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
|
||||
@@ -217,7 +217,10 @@ func (this *AppCmd) runStop() {
|
||||
if runtime.GOOS == "linux" {
|
||||
systemctl, _ := exec.LookPath("systemctl")
|
||||
if len(systemctl) > 0 {
|
||||
_ = exec.Command(systemctl, "stop", teaconst.SystemdServiceName).Run()
|
||||
go func() {
|
||||
// 有可能会长时间执行,这里不阻塞进程
|
||||
_ = exec.Command(systemctl, "stop", teaconst.SystemdServiceName).Run()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -214,3 +214,15 @@ func (this *Manager) FindAllCachePaths() []string {
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// FindAllStorages 读取所有缓存存储
|
||||
func (this *Manager) FindAllStorages() []StorageInterface {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
var storages = []StorageInterface{}
|
||||
for _, storage := range this.storageMap {
|
||||
storages = append(storages, storage)
|
||||
}
|
||||
return storages
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
minOpenFilesValue int32 = 2
|
||||
minOpenFilesValue int32 = 4
|
||||
maxOpenFilesValue int32 = 65535
|
||||
|
||||
modeSlow int32 = 1
|
||||
@@ -35,7 +35,7 @@ func NewMaxOpenFiles(step int32) *MaxOpenFiles {
|
||||
}
|
||||
var f = &MaxOpenFiles{
|
||||
step: step,
|
||||
maxOpenFiles: 2,
|
||||
maxOpenFiles: minOpenFilesValue,
|
||||
}
|
||||
if teaconst.DiskIsFast {
|
||||
f.maxOpenFiles = 32
|
||||
@@ -68,7 +68,7 @@ func (this *MaxOpenFiles) init() {
|
||||
atomic.StoreInt32(&this.currentOpens, 0)
|
||||
}
|
||||
|
||||
// reset mod
|
||||
// reset mode
|
||||
atomic.StoreInt32(&this.mode, 0)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -428,7 +428,7 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, siz
|
||||
return nil, ErrFileIsWriting
|
||||
}
|
||||
|
||||
if len(sharedWritingFileKeyMap) >= int(maxOpenFiles.Max()) {
|
||||
if !isFlushing && len(sharedWritingFileKeyMap) >= int(maxOpenFiles.Max()) {
|
||||
sharedWritingFileKeyLocker.Unlock()
|
||||
return nil, ErrTooManyOpenFiles
|
||||
}
|
||||
@@ -481,11 +481,16 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, siz
|
||||
openFileCache.Close(cachePath)
|
||||
}
|
||||
|
||||
// 查询当前已有缓存文件
|
||||
stat, err := os.Stat(cachePath)
|
||||
if err == nil && time.Now().Sub(stat.ModTime()) <= 1*time.Second {
|
||||
|
||||
// 检查两次写入缓存的时间是否过于相近,分片内容不受此限制
|
||||
if err == nil && !isPartial && time.Now().Sub(stat.ModTime()) <= 1*time.Second {
|
||||
// 防止并发连续写入
|
||||
return nil, ErrFileIsWriting
|
||||
}
|
||||
|
||||
// 构造文件名
|
||||
var tmpPath = cachePath
|
||||
var existsFile = false
|
||||
if stat != nil {
|
||||
@@ -534,10 +539,12 @@ func (this *FileStorage) openWriter(key string, expiredAt int64, status int, siz
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if time.Since(before) >= maxOpenFilesSlowCost {
|
||||
maxOpenFiles.Slow()
|
||||
} else {
|
||||
maxOpenFiles.Fast()
|
||||
if !isFlushing {
|
||||
if time.Since(before) >= maxOpenFilesSlowCost {
|
||||
maxOpenFiles.Slow()
|
||||
} else {
|
||||
maxOpenFiles.Fast()
|
||||
}
|
||||
}
|
||||
|
||||
var removeOnFailure = true
|
||||
@@ -767,9 +774,10 @@ func (this *FileStorage) Purge(keys []string, urlType string) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 文件
|
||||
// URL
|
||||
for _, key := range keys {
|
||||
hash, path := this.keyPath(key)
|
||||
err := this.removeCacheFile(path)
|
||||
|
||||
@@ -35,6 +35,7 @@ type StorageInterface interface {
|
||||
CleanAll() error
|
||||
|
||||
// Purge 批量删除缓存
|
||||
// urlType 值为file|dir
|
||||
Purge(keys []string, urlType string) error
|
||||
|
||||
// Stop 停止缓存策略
|
||||
|
||||
@@ -267,8 +267,10 @@ func (this *MemoryStorage) Purge(keys []string, urlType string) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// URL
|
||||
for _, key := range keys {
|
||||
err := this.Delete(key)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
package compressions
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"github.com/klauspost/compress/gzip"
|
||||
"io"
|
||||
)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
package compressions
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"github.com/klauspost/compress/gzip"
|
||||
"io"
|
||||
)
|
||||
|
||||
|
||||
@@ -34,3 +34,31 @@ func BenchmarkGzipWriter_Write(b *testing.B) {
|
||||
_ = writer.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGzipWriter_Write_Parallel(b *testing.B) {
|
||||
var data = []byte(strings.Repeat("A", 1024))
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
var buf = &bytes.Buffer{}
|
||||
writer, err := compressions.NewGzipWriter(buf, 5)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for j := 0; j < 100; j++ {
|
||||
_, err = writer.Write(data)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
/**err = writer.Flush()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}**/
|
||||
}
|
||||
|
||||
_ = writer.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build community
|
||||
// +build community
|
||||
//go:build !plus
|
||||
// +build !plus
|
||||
|
||||
package teaconst
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package teaconst
|
||||
|
||||
const (
|
||||
Version = "0.4.7"
|
||||
Version = "0.4.8.1"
|
||||
|
||||
ProductName = "Edge Node"
|
||||
ProcessName = "edge-node"
|
||||
@@ -13,4 +13,6 @@ const (
|
||||
|
||||
// SystemdServiceName systemd
|
||||
SystemdServiceName = "edge-node"
|
||||
|
||||
AccessLogSockName = "edge-node.accesslog.sock"
|
||||
)
|
||||
|
||||
@@ -3,9 +3,10 @@ package events
|
||||
type Event = string
|
||||
|
||||
const (
|
||||
EventStart Event = "start" // start loading
|
||||
EventLoaded Event = "loaded" // first load
|
||||
EventQuit Event = "quit" // quit node gracefully
|
||||
EventReload Event = "reload" // reload config
|
||||
EventTerminated Event = "terminated" // process terminated
|
||||
EventStart Event = "start" // start loading
|
||||
EventLoaded Event = "loaded" // first load
|
||||
EventQuit Event = "quit" // quit node gracefully
|
||||
EventReload Event = "reload" // reload config
|
||||
EventTerminated Event = "terminated" // process terminated
|
||||
EventNFTablesReady Event = "nftablesReady" // nftables ready
|
||||
)
|
||||
|
||||
1
internal/firewalls/.gitignore
vendored
1
internal/firewalls/.gitignore
vendored
@@ -1 +0,0 @@
|
||||
firewall_nftables_test.go
|
||||
502
internal/firewalls/ddos_protection.go
Normal file
502
internal/firewalls/ddos_protection.go
Normal file
@@ -0,0 +1,502 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/zero"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
|
||||
|
||||
func init() {
|
||||
events.On(events.EventReload, func() {
|
||||
if nftablesInstance == nil {
|
||||
return
|
||||
}
|
||||
|
||||
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
|
||||
if nodeConfig != nil {
|
||||
err := SharedDDoSProtectionManager.Apply(nodeConfig.DDOSProtection)
|
||||
if err != nil {
|
||||
remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
events.On(events.EventNFTablesReady, func() {
|
||||
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
|
||||
if nodeConfig != nil {
|
||||
err := SharedDDoSProtectionManager.Apply(nodeConfig.DDOSProtection)
|
||||
if err != nil {
|
||||
remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// DDoSProtectionManager DDoS防护
|
||||
type DDoSProtectionManager struct {
|
||||
nftPath string
|
||||
|
||||
lastAllowIPList []string
|
||||
lastConfig []byte
|
||||
}
|
||||
|
||||
// NewDDoSProtectionManager 获取新对象
|
||||
func NewDDoSProtectionManager() *DDoSProtectionManager {
|
||||
nftPath, _ := exec.LookPath("nft")
|
||||
|
||||
return &DDoSProtectionManager{
|
||||
nftPath: nftPath,
|
||||
}
|
||||
}
|
||||
|
||||
// Apply 应用配置
|
||||
func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error {
|
||||
// 同集群节点IP白名单
|
||||
var allowIPListChanged = false
|
||||
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
|
||||
if nodeConfig != nil {
|
||||
var allowIPList = nodeConfig.AllowedIPs
|
||||
if !utils.ContainsSameStrings(allowIPList, this.lastAllowIPList) {
|
||||
allowIPListChanged = true
|
||||
this.lastAllowIPList = allowIPList
|
||||
}
|
||||
}
|
||||
|
||||
// 对比配置
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return errors.New("encode config to json failed: " + err.Error())
|
||||
}
|
||||
if !allowIPListChanged && bytes.Equal(this.lastConfig, configJSON) {
|
||||
return nil
|
||||
}
|
||||
remotelogs.Println("FIREWALL", "change DDoS protection config")
|
||||
|
||||
if len(this.nftPath) == 0 {
|
||||
return errors.New("can not find nft command")
|
||||
}
|
||||
|
||||
if nftablesInstance == nil {
|
||||
return errors.New("nftables instance should not be nil")
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
// TCP
|
||||
err := this.removeTCPRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO other protocols
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TCP
|
||||
if config.TCP == nil {
|
||||
err := this.removeTCPRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// allow ip list
|
||||
var allowIPList = []string{}
|
||||
for _, ipConfig := range config.TCP.AllowIPList {
|
||||
allowIPList = append(allowIPList, ipConfig.IP)
|
||||
}
|
||||
for _, ip := range this.lastAllowIPList {
|
||||
if !lists.ContainsString(allowIPList, ip) {
|
||||
allowIPList = append(allowIPList, ip)
|
||||
}
|
||||
}
|
||||
err = this.updateAllowIPList(allowIPList)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// tcp
|
||||
if config.TCP.IsOn {
|
||||
err := this.addTCPRules(config.TCP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err := this.removeTCPRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.lastConfig = configJSON
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 添加TCP规则
|
||||
func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig) error {
|
||||
// 检查nft版本不能小于0.9
|
||||
if len(nftablesInstance.version) > 0 && stringutil.VersionCompare("0.9", nftablesInstance.version) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var ports = []int32{}
|
||||
for _, portConfig := range tcpConfig.Ports {
|
||||
if !lists.ContainsInt32(ports, portConfig.Port) {
|
||||
ports = append(ports, portConfig.Port)
|
||||
}
|
||||
}
|
||||
if len(ports) == 0 {
|
||||
ports = []int32{80, 443}
|
||||
}
|
||||
|
||||
for _, filter := range nftablesFilters {
|
||||
chain, oldRules, err := this.getRules(filter)
|
||||
if err != nil {
|
||||
return errors.New("get old rules failed: " + err.Error())
|
||||
}
|
||||
|
||||
var protocol = filter.protocol()
|
||||
|
||||
// max connections
|
||||
var maxConnections = tcpConfig.MaxConnections
|
||||
if maxConnections <= 0 {
|
||||
maxConnections = nodeconfigs.DefaultTCPMaxConnections
|
||||
if maxConnections <= 0 {
|
||||
maxConnections = 100000
|
||||
}
|
||||
}
|
||||
|
||||
// max connections per ip
|
||||
var maxConnectionsPerIP = tcpConfig.MaxConnectionsPerIP
|
||||
if maxConnectionsPerIP <= 0 {
|
||||
maxConnectionsPerIP = nodeconfigs.DefaultTCPMaxConnectionsPerIP
|
||||
if maxConnectionsPerIP <= 0 {
|
||||
maxConnectionsPerIP = 100000
|
||||
}
|
||||
}
|
||||
|
||||
// new connections rate
|
||||
var newConnectionsRate = tcpConfig.NewConnectionsRate
|
||||
if newConnectionsRate <= 0 {
|
||||
newConnectionsRate = nodeconfigs.DefaultTCPNewConnectionsRate
|
||||
if newConnectionsRate <= 0 {
|
||||
newConnectionsRate = 100000
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否有变化
|
||||
var hasChanges = false
|
||||
for _, port := range ports {
|
||||
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}) {
|
||||
hasChanges = true
|
||||
break
|
||||
}
|
||||
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}) {
|
||||
hasChanges = true
|
||||
break
|
||||
}
|
||||
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsRate)}) {
|
||||
hasChanges = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasChanges {
|
||||
// 检查是否有多余的端口
|
||||
var oldPorts = this.getTCPPorts(oldRules)
|
||||
if !this.eqPorts(ports, oldPorts) {
|
||||
hasChanges = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasChanges {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 先清空所有相关规则
|
||||
err = this.removeOldTCPRules(chain, oldRules)
|
||||
if err != nil {
|
||||
return errors.New("delete old rules failed: " + err.Error())
|
||||
}
|
||||
|
||||
// 添加新规则
|
||||
for _, port := range ports {
|
||||
if maxConnections > 0 {
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "count", "over", types.String(maxConnections), "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
|
||||
}
|
||||
}
|
||||
|
||||
if maxConnectionsPerIP > 0 {
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "meter", "meter-"+protocol+"-"+types.String(port)+"-max-connections", "{ "+protocol+" saddr ct count over "+types.String(maxConnectionsPerIP)+" }", "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
|
||||
}
|
||||
}
|
||||
|
||||
if newConnectionsRate > 0 {
|
||||
// TODO 思考是否有惩罚机制
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsRate)+"/minute burst "+types.String(newConnectionsRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsRate)}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 删除TCP规则
|
||||
func (this *DDoSProtectionManager) removeTCPRules() error {
|
||||
for _, filter := range nftablesFilters {
|
||||
chain, rules, err := this.getRules(filter)
|
||||
|
||||
// TCP
|
||||
err = this.removeOldTCPRules(chain, rules)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 组合user data
|
||||
// 数据中不能包含字母、数字、下划线以外的数据
|
||||
func (this *DDoSProtectionManager) encodeUserData(attrs []string) string {
|
||||
if attrs == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return "ZZ" + strings.Join(attrs, "_") + "ZZ"
|
||||
}
|
||||
|
||||
// 解码user data
|
||||
func (this *DDoSProtectionManager) decodeUserData(data []byte) []string {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var dataCopy = make([]byte, len(data))
|
||||
copy(dataCopy, data)
|
||||
|
||||
var separatorLen = 2
|
||||
var index1 = bytes.Index(dataCopy, []byte{'Z', 'Z'})
|
||||
if index1 < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
dataCopy = dataCopy[index1+separatorLen:]
|
||||
var index2 = bytes.LastIndex(dataCopy, []byte{'Z', 'Z'})
|
||||
if index2 < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var s = string(dataCopy[:index2])
|
||||
var pieces = strings.Split(s, "_")
|
||||
for index, piece := range pieces {
|
||||
pieces[index] = strings.TrimSpace(piece)
|
||||
}
|
||||
return pieces
|
||||
}
|
||||
|
||||
// 清除规则
|
||||
func (this *DDoSProtectionManager) removeOldTCPRules(chain *nftables.Chain, rules []*nftables.Rule) error {
|
||||
for _, rule := range rules {
|
||||
var pieces = this.decodeUserData(rule.UserData())
|
||||
if len(pieces) != 4 {
|
||||
continue
|
||||
}
|
||||
if pieces[0] != "tcp" {
|
||||
continue
|
||||
}
|
||||
switch pieces[2] {
|
||||
case "maxConnections", "maxConnectionsPerIP", "newConnectionsRate":
|
||||
err := chain.DeleteRule(rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 根据参数检查规则是否存在
|
||||
func (this *DDoSProtectionManager) existsRule(rules []*nftables.Rule, attrs []string) (exists bool) {
|
||||
if len(attrs) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, oldRule := range rules {
|
||||
var pieces = this.decodeUserData(oldRule.UserData())
|
||||
if len(attrs) != len(pieces) {
|
||||
continue
|
||||
}
|
||||
var isSame = true
|
||||
for index, piece := range pieces {
|
||||
if strings.TrimSpace(piece) != attrs[index] {
|
||||
isSame = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if isSame {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取规则中的端口号
|
||||
func (this *DDoSProtectionManager) getTCPPorts(rules []*nftables.Rule) []int32 {
|
||||
var ports = []int32{}
|
||||
for _, rule := range rules {
|
||||
var pieces = this.decodeUserData(rule.UserData())
|
||||
if len(pieces) != 4 {
|
||||
continue
|
||||
}
|
||||
if pieces[0] != "tcp" {
|
||||
continue
|
||||
}
|
||||
var port = types.Int32(pieces[1])
|
||||
if port > 0 && !lists.ContainsInt32(ports, port) {
|
||||
ports = append(ports, port)
|
||||
}
|
||||
}
|
||||
return ports
|
||||
}
|
||||
|
||||
// 检查端口是否一样
|
||||
func (this *DDoSProtectionManager) eqPorts(ports1 []int32, ports2 []int32) bool {
|
||||
if len(ports1) != len(ports2) {
|
||||
return false
|
||||
}
|
||||
|
||||
var portMap = map[int32]bool{}
|
||||
for _, port := range ports2 {
|
||||
portMap[port] = true
|
||||
}
|
||||
|
||||
for _, port := range ports1 {
|
||||
_, ok := portMap[port]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// 查找Table
|
||||
func (this *DDoSProtectionManager) getTable(filter *nftablesTableDefinition) (*nftables.Table, error) {
|
||||
var family nftables.TableFamily
|
||||
if filter.IsIPv4 {
|
||||
family = nftables.TableFamilyIPv4
|
||||
} else if filter.IsIPv6 {
|
||||
family = nftables.TableFamilyIPv6
|
||||
} else {
|
||||
return nil, errors.New("table '" + filter.Name + "' should be IPv4 or IPv6")
|
||||
}
|
||||
return nftablesInstance.conn.GetTable(filter.Name, family)
|
||||
}
|
||||
|
||||
// 查找所有规则
|
||||
func (this *DDoSProtectionManager) getRules(filter *nftablesTableDefinition) (*nftables.Chain, []*nftables.Rule, error) {
|
||||
table, err := this.getTable(filter)
|
||||
if err != nil {
|
||||
return nil, nil, errors.New("get table failed: " + err.Error())
|
||||
}
|
||||
chain, err := table.GetChain(nftablesChainName)
|
||||
if err != nil {
|
||||
return nil, nil, errors.New("get chain failed: " + err.Error())
|
||||
}
|
||||
rules, err := chain.GetRules()
|
||||
return chain, rules, err
|
||||
}
|
||||
|
||||
// 更新白名单
|
||||
func (this *DDoSProtectionManager) updateAllowIPList(allIPList []string) error {
|
||||
if nftablesInstance == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var allMap = map[string]zero.Zero{}
|
||||
for _, ip := range allIPList {
|
||||
allMap[ip] = zero.New()
|
||||
}
|
||||
|
||||
for _, set := range []*nftables.Set{nftablesInstance.allowIPv4Set, nftablesInstance.allowIPv6Set} {
|
||||
var isIPv4 = set == nftablesInstance.allowIPv4Set
|
||||
var isIPv6 = !isIPv4
|
||||
|
||||
// 现有的
|
||||
oldList, err := set.GetIPElements()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var oldMap = map[string]zero.Zero{} // ip=> zero
|
||||
for _, ip := range oldList {
|
||||
oldMap[ip] = zero.New()
|
||||
|
||||
if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) {
|
||||
_, ok := allMap[ip]
|
||||
if !ok {
|
||||
// 不存在则删除
|
||||
err = set.DeleteIPElement(ip)
|
||||
if err != nil {
|
||||
return errors.New("delete ip element '" + ip + "' failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 新增的
|
||||
for _, ip := range allIPList {
|
||||
var ipObj = net.ParseIP(ip)
|
||||
if ipObj == nil {
|
||||
continue
|
||||
}
|
||||
if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) {
|
||||
_, ok := oldMap[ip]
|
||||
if !ok {
|
||||
// 不存在则添加
|
||||
err = set.AddIPElement(ip, nil)
|
||||
if err != nil {
|
||||
return errors.New("add ip '" + ip + "' failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
23
internal/firewalls/ddos_protection_others.go
Normal file
23
internal/firewalls/ddos_protection_others.go
Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
|
||||
)
|
||||
|
||||
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
|
||||
|
||||
type DDoSProtectionManager struct {
|
||||
nftPath string
|
||||
}
|
||||
|
||||
func NewDDoSProtectionManager() *DDoSProtectionManager {
|
||||
return &DDoSProtectionManager{}
|
||||
}
|
||||
|
||||
func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,6 +1,4 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build !plus
|
||||
// +build !plus
|
||||
|
||||
package firewalls
|
||||
|
||||
@@ -8,9 +6,11 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"runtime"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var currentFirewall FirewallInterface
|
||||
var firewallLocker = &sync.Mutex{}
|
||||
|
||||
// 初始化
|
||||
func init() {
|
||||
@@ -24,10 +24,28 @@ func init() {
|
||||
|
||||
// Firewall 查找当前系统中最适合的防火墙
|
||||
func Firewall() FirewallInterface {
|
||||
firewallLocker.Lock()
|
||||
defer firewallLocker.Unlock()
|
||||
if currentFirewall != nil {
|
||||
return currentFirewall
|
||||
}
|
||||
|
||||
// nftables
|
||||
if runtime.GOOS == "linux" {
|
||||
nftables, err := NewNFTablesFirewall()
|
||||
if err != nil {
|
||||
remotelogs.Warn("FIREWALL", "'nftables' should be installed on the system to enhance security (init failed: "+err.Error()+")")
|
||||
} else {
|
||||
if nftables.IsReady() {
|
||||
currentFirewall = nftables
|
||||
events.Notify(events.EventNFTablesReady)
|
||||
return nftables
|
||||
} else {
|
||||
remotelogs.Warn("FIREWALL", "'nftables' should be enabled on the system to enhance security")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// firewalld
|
||||
if runtime.GOOS == "linux" {
|
||||
var firewalld = NewFirewalld()
|
||||
|
||||
@@ -23,12 +23,13 @@ func NewFirewalld() *Firewalld {
|
||||
|
||||
path, err := exec.LookPath("firewall-cmd")
|
||||
if err == nil && len(path) > 0 {
|
||||
var cmd = exec.Command(path, "-V")
|
||||
var cmd = exec.Command(path, "--state")
|
||||
err := cmd.Run()
|
||||
if err == nil {
|
||||
firewalld.exe = path
|
||||
// TODO check firewalld status with 'firewall-cmd --state' (running or not running),
|
||||
// but we should recover the state when firewalld state changes, maybe check it every minutes
|
||||
|
||||
firewalld.isReady = true
|
||||
firewalld.init()
|
||||
}
|
||||
|
||||
395
internal/firewalls/firewall_nftables.go
Normal file
395
internal/firewalls/firewall_nftables.go
Normal file
@@ -0,0 +1,395 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// check nft status, if being enabled we load it automatically
|
||||
func init() {
|
||||
if runtime.GOOS == "linux" {
|
||||
var ticker = time.NewTicker(3 * time.Minute)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
// if already ready, we break
|
||||
if nftablesIsReady {
|
||||
ticker.Stop()
|
||||
break
|
||||
}
|
||||
_, err := exec.LookPath("nft")
|
||||
if err == nil {
|
||||
nftablesFirewall, err := NewNFTablesFirewall()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
currentFirewall = nftablesFirewall
|
||||
remotelogs.Println("FIREWALL", "nftables is ready")
|
||||
|
||||
// fire event
|
||||
if nftablesFirewall.IsReady() {
|
||||
events.Notify(events.EventNFTablesReady)
|
||||
}
|
||||
|
||||
ticker.Stop()
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
var nftablesInstance *NFTablesFirewall
|
||||
var nftablesIsReady = false
|
||||
var nftablesFilters = []*nftablesTableDefinition{
|
||||
// we shorten the name for table name length restriction
|
||||
{Name: "edge_dft_v4", IsIPv4: true},
|
||||
{Name: "edge_dft_v6", IsIPv6: true},
|
||||
}
|
||||
var nftablesChainName = "input"
|
||||
|
||||
type nftablesTableDefinition struct {
|
||||
Name string
|
||||
IsIPv4 bool
|
||||
IsIPv6 bool
|
||||
}
|
||||
|
||||
func (this *nftablesTableDefinition) protocol() string {
|
||||
if this.IsIPv6 {
|
||||
return "ip6"
|
||||
}
|
||||
return "ip"
|
||||
}
|
||||
|
||||
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
|
||||
var firewall = &NFTablesFirewall{
|
||||
conn: nftables.NewConn(),
|
||||
}
|
||||
err := firewall.init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return firewall, nil
|
||||
}
|
||||
|
||||
type NFTablesFirewall struct {
|
||||
conn *nftables.Conn
|
||||
isReady bool
|
||||
version string
|
||||
|
||||
allowIPv4Set *nftables.Set
|
||||
allowIPv6Set *nftables.Set
|
||||
|
||||
denyIPv4Set *nftables.Set
|
||||
denyIPv6Set *nftables.Set
|
||||
|
||||
firewalld *Firewalld
|
||||
}
|
||||
|
||||
func (this *NFTablesFirewall) init() error {
|
||||
// check nft
|
||||
nftPath, err := exec.LookPath("nft")
|
||||
if err != nil {
|
||||
return errors.New("nft not found")
|
||||
}
|
||||
this.version = this.readVersion(nftPath)
|
||||
|
||||
// table
|
||||
for _, tableDef := range nftablesFilters {
|
||||
var family nftables.TableFamily
|
||||
if tableDef.IsIPv4 {
|
||||
family = nftables.TableFamilyIPv4
|
||||
} else if tableDef.IsIPv6 {
|
||||
family = nftables.TableFamilyIPv6
|
||||
} else {
|
||||
return errors.New("invalid table family: " + types.String(tableDef))
|
||||
}
|
||||
table, err := this.conn.GetTable(tableDef.Name, family)
|
||||
if err != nil {
|
||||
if nftables.IsNotFound(err) {
|
||||
if tableDef.IsIPv4 {
|
||||
table, err = this.conn.AddIPv4Table(tableDef.Name)
|
||||
} else if tableDef.IsIPv6 {
|
||||
table, err = this.conn.AddIPv6Table(tableDef.Name)
|
||||
}
|
||||
if err != nil {
|
||||
return errors.New("create table '" + tableDef.Name + "' failed: " + err.Error())
|
||||
}
|
||||
} else {
|
||||
return errors.New("get table '" + tableDef.Name + "' failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
if table == nil {
|
||||
return errors.New("can not create table '" + tableDef.Name + "'")
|
||||
}
|
||||
|
||||
// chain
|
||||
var chainName = nftablesChainName
|
||||
chain, err := table.GetChain(chainName)
|
||||
if err != nil {
|
||||
if nftables.IsNotFound(err) {
|
||||
chain, err = table.AddAcceptChain(chainName)
|
||||
if err != nil {
|
||||
return errors.New("create chain '" + chainName + "' failed: " + err.Error())
|
||||
}
|
||||
} else {
|
||||
return errors.New("get chain '" + chainName + "' failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
if chain == nil {
|
||||
return errors.New("can not create chain '" + chainName + "'")
|
||||
}
|
||||
|
||||
// allow lo
|
||||
var loRuleName = []byte("lo")
|
||||
_, err = chain.GetRuleWithUserData(loRuleName)
|
||||
if err != nil {
|
||||
if nftables.IsNotFound(err) {
|
||||
_, err = chain.AddAcceptInterfaceRule("lo", loRuleName)
|
||||
}
|
||||
if err != nil {
|
||||
return errors.New("add 'lo' rule failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// allow set
|
||||
// "allow" should be always first
|
||||
for _, setAction := range []string{"allow", "deny"} {
|
||||
var setName = setAction + "_set"
|
||||
|
||||
set, err := table.GetSet(setName)
|
||||
if err != nil {
|
||||
if nftables.IsNotFound(err) {
|
||||
var keyType nftables.SetDataType
|
||||
if tableDef.IsIPv4 {
|
||||
keyType = nftables.TypeIPAddr
|
||||
} else if tableDef.IsIPv6 {
|
||||
keyType = nftables.TypeIP6Addr
|
||||
}
|
||||
set, err = table.AddSet(setName, &nftables.SetOptions{
|
||||
KeyType: keyType,
|
||||
HasTimeout: true,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.New("create set '" + setName + "' failed: " + err.Error())
|
||||
}
|
||||
} else {
|
||||
return errors.New("get set '" + setName + "' failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
if set == nil {
|
||||
return errors.New("can not create set '" + setName + "'")
|
||||
}
|
||||
if tableDef.IsIPv4 {
|
||||
if setAction == "allow" {
|
||||
this.allowIPv4Set = set
|
||||
} else {
|
||||
this.denyIPv4Set = set
|
||||
}
|
||||
} else if tableDef.IsIPv6 {
|
||||
if setAction == "allow" {
|
||||
this.allowIPv6Set = set
|
||||
} else {
|
||||
this.denyIPv6Set = set
|
||||
}
|
||||
}
|
||||
|
||||
// rule
|
||||
var ruleName = []byte(setAction)
|
||||
rule, err := chain.GetRuleWithUserData(ruleName)
|
||||
if err != nil {
|
||||
if nftables.IsNotFound(err) {
|
||||
if tableDef.IsIPv4 {
|
||||
if setAction == "allow" {
|
||||
rule, err = chain.AddAcceptIPv4SetRule(setName, ruleName)
|
||||
} else {
|
||||
rule, err = chain.AddDropIPv4SetRule(setName, ruleName)
|
||||
}
|
||||
} else if tableDef.IsIPv6 {
|
||||
if setAction == "allow" {
|
||||
rule, err = chain.AddAcceptIPv6SetRule(setName, ruleName)
|
||||
} else {
|
||||
rule, err = chain.AddDropIPv6SetRule(setName, ruleName)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return errors.New("add rule failed: " + err.Error())
|
||||
}
|
||||
} else {
|
||||
return errors.New("get rule failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
if rule == nil {
|
||||
return errors.New("can not create rule '" + string(ruleName) + "'")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.isReady = true
|
||||
nftablesIsReady = true
|
||||
nftablesInstance = this
|
||||
|
||||
// load firewalld
|
||||
var firewalld = NewFirewalld()
|
||||
if firewalld.IsReady() {
|
||||
this.firewalld = firewalld
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Name 名称
|
||||
func (this *NFTablesFirewall) Name() string {
|
||||
return "nftables"
|
||||
}
|
||||
|
||||
// IsReady 是否已准备被调用
|
||||
func (this *NFTablesFirewall) IsReady() bool {
|
||||
return this.isReady
|
||||
}
|
||||
|
||||
// IsMock 是否为模拟
|
||||
func (this *NFTablesFirewall) IsMock() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// AllowPort 允许端口
|
||||
func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
|
||||
if this.firewalld != nil {
|
||||
return this.firewalld.AllowPort(port, protocol)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePort 删除端口
|
||||
func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
|
||||
if this.firewalld != nil {
|
||||
return this.firewalld.RemovePort(port, protocol)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowSourceIP Allow把IP加入白名单
|
||||
func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
|
||||
var data = net.ParseIP(ip)
|
||||
if data == nil {
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
if strings.Contains(ip, ":") { // ipv6
|
||||
if this.allowIPv6Set == nil {
|
||||
return errors.New("ipv6 ip set is nil")
|
||||
}
|
||||
return this.allowIPv6Set.AddElement(data.To16(), nil)
|
||||
}
|
||||
|
||||
// ipv4
|
||||
if this.allowIPv4Set == nil {
|
||||
return errors.New("ipv4 ip set is nil")
|
||||
}
|
||||
return this.allowIPv4Set.AddElement(data.To4(), nil)
|
||||
}
|
||||
|
||||
// RejectSourceIP 拒绝某个源IP连接
|
||||
// we did not create set for drop ip, so we reuse DropSourceIP() method here
|
||||
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
|
||||
return this.DropSourceIP(ip, timeoutSeconds)
|
||||
}
|
||||
|
||||
// DropSourceIP 丢弃某个源IP数据
|
||||
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
|
||||
var data = net.ParseIP(ip)
|
||||
if data == nil {
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
if strings.Contains(ip, ":") { // ipv6
|
||||
if this.denyIPv6Set == nil {
|
||||
return errors.New("ipv6 ip set is nil")
|
||||
}
|
||||
return this.denyIPv6Set.AddElement(data.To16(), &nftables.ElementOptions{
|
||||
Timeout: time.Duration(timeoutSeconds) * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
// ipv4
|
||||
if this.denyIPv4Set == nil {
|
||||
return errors.New("ipv4 ip set is nil")
|
||||
}
|
||||
return this.denyIPv4Set.AddElement(data.To4(), &nftables.ElementOptions{
|
||||
Timeout: time.Duration(timeoutSeconds) * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveSourceIP 删除某个源IP
|
||||
func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
|
||||
var data = net.ParseIP(ip)
|
||||
if data == nil {
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
if strings.Contains(ip, ":") { // ipv6
|
||||
if this.denyIPv6Set != nil {
|
||||
err := this.denyIPv6Set.DeleteElement(data.To16())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if this.allowIPv6Set != nil {
|
||||
err := this.allowIPv6Set.DeleteElement(data.To16())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ipv4
|
||||
if this.allowIPv4Set != nil {
|
||||
err := this.denyIPv4Set.DeleteElement(data.To4())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = this.allowIPv4Set.DeleteElement(data.To4())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 读取版本号
|
||||
func (this *NFTablesFirewall) readVersion(nftPath string) string {
|
||||
var cmd = exec.Command(nftPath, "--version")
|
||||
var output = &bytes.Buffer{}
|
||||
cmd.Stdout = output
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var outputString = output.String()
|
||||
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
|
||||
if len(versionMatches) <= 1 {
|
||||
return ""
|
||||
}
|
||||
return versionMatches[1]
|
||||
}
|
||||
61
internal/firewalls/firewall_nftables_others.go
Normal file
61
internal/firewalls/firewall_nftables_others.go
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type NFTablesFirewall struct {
|
||||
}
|
||||
|
||||
// Name 名称
|
||||
func (this *NFTablesFirewall) Name() string {
|
||||
return "nftables"
|
||||
}
|
||||
|
||||
// IsReady 是否已准备被调用
|
||||
func (this *NFTablesFirewall) IsReady() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsMock 是否为模拟
|
||||
func (this *NFTablesFirewall) IsMock() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// AllowPort 允许端口
|
||||
func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePort 删除端口
|
||||
func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowSourceIP Allow把IP加入白名单
|
||||
func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RejectSourceIP 拒绝某个源IP连接
|
||||
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropSourceIP 丢弃某个源IP数据
|
||||
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveSourceIP 删除某个源IP
|
||||
func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
|
||||
return nil
|
||||
}
|
||||
1
internal/firewalls/nftables/.gitignore
vendored
Normal file
1
internal/firewalls/nftables/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
build_remote.sh
|
||||
370
internal/firewalls/nftables/chain.go
Normal file
370
internal/firewalls/nftables/chain.go
Normal file
@@ -0,0 +1,370 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
nft "github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
)
|
||||
|
||||
const MaxChainNameLength = 31
|
||||
|
||||
type RuleOptions struct {
|
||||
Exprs []expr.Any
|
||||
UserData []byte
|
||||
}
|
||||
|
||||
// Chain chain object in table
|
||||
type Chain struct {
|
||||
conn *Conn
|
||||
rawTable *nft.Table
|
||||
rawChain *nft.Chain
|
||||
}
|
||||
|
||||
func NewChain(conn *Conn, rawTable *nft.Table, rawChain *nft.Chain) *Chain {
|
||||
return &Chain{
|
||||
conn: conn,
|
||||
rawTable: rawTable,
|
||||
rawChain: rawChain,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Chain) Raw() *nft.Chain {
|
||||
return this.rawChain
|
||||
}
|
||||
|
||||
func (this *Chain) Name() string {
|
||||
return this.rawChain.Name
|
||||
}
|
||||
|
||||
func (this *Chain) AddRule(options *RuleOptions) (*Rule, error) {
|
||||
var rawRule = this.conn.Raw().AddRule(&nft.Rule{
|
||||
Table: this.rawTable,
|
||||
Chain: this.rawChain,
|
||||
Exprs: options.Exprs,
|
||||
UserData: options.UserData,
|
||||
})
|
||||
err := this.conn.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewRule(rawRule), nil
|
||||
}
|
||||
|
||||
func (this *Chain) AddAcceptIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddAcceptIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddDropIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictDrop,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddDropIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictDrop,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddRejectIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip,
|
||||
},
|
||||
&expr.Reject{},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddRejectIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip,
|
||||
},
|
||||
&expr.Reject{},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddAcceptIPv4SetRule(setName string, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: setName,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddAcceptIPv6SetRule(setName string, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: setName,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddDropIPv4SetRule(setName string, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: setName,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictDrop,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddDropIPv6SetRule(setName string, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: setName,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictDrop,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddRejectIPv4SetRule(setName string, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: setName,
|
||||
},
|
||||
&expr.Reject{},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddRejectIPv6SetRule(setName string, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: setName,
|
||||
},
|
||||
&expr.Reject{},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddAcceptInterfaceRule(interfaceName string, userData []byte) (*Rule, error) {
|
||||
if len(interfaceName) >= 16 {
|
||||
return nil, errors.New("invalid interface name '" + interfaceName + "'")
|
||||
}
|
||||
var ifname = make([]byte, 16)
|
||||
copy(ifname, interfaceName+"\x00")
|
||||
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) GetRuleWithUserData(userData []byte) (*Rule, error) {
|
||||
rawRules, err := this.conn.Raw().GetRule(this.rawTable, this.rawChain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, rawRule := range rawRules {
|
||||
if bytes.Compare(rawRule.UserData, userData) == 0 {
|
||||
return NewRule(rawRule), nil
|
||||
}
|
||||
}
|
||||
return nil, ErrRuleNotFound
|
||||
}
|
||||
|
||||
func (this *Chain) GetRules() ([]*Rule, error) {
|
||||
rawRules, err := this.conn.Raw().GetRule(this.rawTable, this.rawChain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result = []*Rule{}
|
||||
for _, rawRule := range rawRules {
|
||||
result = append(result, NewRule(rawRule))
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (this *Chain) DeleteRule(rule *Rule) error {
|
||||
err := this.conn.Raw().DelRule(rule.Raw())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return this.conn.Commit()
|
||||
}
|
||||
|
||||
func (this *Chain) Flush() error {
|
||||
this.conn.Raw().FlushChain(this.rawChain)
|
||||
return this.conn.Commit()
|
||||
}
|
||||
13
internal/firewalls/nftables/chain_policy.go
Normal file
13
internal/firewalls/nftables/chain_policy.go
Normal file
@@ -0,0 +1,13 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nftables
|
||||
|
||||
import nft "github.com/google/nftables"
|
||||
|
||||
type ChainPolicy = nft.ChainPolicy
|
||||
|
||||
// Possible ChainPolicy values.
|
||||
const (
|
||||
ChainPolicyDrop = nft.ChainPolicyDrop
|
||||
ChainPolicyAccept = nft.ChainPolicyAccept
|
||||
)
|
||||
130
internal/firewalls/nftables/chain_test.go
Normal file
130
internal/firewalls/nftables/chain_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func getIPv4Chain(t *testing.T) *nftables.Chain {
|
||||
var conn = nftables.NewConn()
|
||||
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
if err == nftables.ErrTableNotFound {
|
||||
table, err = conn.AddIPv4Table("test_ipv4")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
chain, err := table.GetChain("test_chain")
|
||||
if err != nil {
|
||||
if err == nftables.ErrChainNotFound {
|
||||
chain, err = table.AddAcceptChain("test_chain")
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return chain
|
||||
}
|
||||
|
||||
func TestChain_AddAcceptIPRule(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
_, err := chain.AddAcceptIPv4Rule(net.ParseIP("192.168.2.40").To4(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_AddDropIPRule(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
_, err := chain.AddDropIPv4Rule(net.ParseIP("192.168.2.31").To4(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_AddAcceptSetRule(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
_, err := chain.AddAcceptIPv4SetRule("ipv4_black_set", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_AddDropSetRule(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
_, err := chain.AddDropIPv4SetRule("ipv4_black_set", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_AddRejectSetRule(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
_, err := chain.AddRejectIPv4SetRule("ipv4_black_set", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_GetRuleWithUserData(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
rule, err := chain.GetRuleWithUserData([]byte("test"))
|
||||
if err != nil {
|
||||
if err == nftables.ErrRuleNotFound {
|
||||
t.Log("rule not found")
|
||||
return
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
t.Log("rule:", rule)
|
||||
}
|
||||
|
||||
func TestChain_GetRules(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
rules, err := chain.GetRules()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
t.Log("handle:", rule.Handle(), "set name:", rule.LookupSetName(),
|
||||
"verdict:", rule.VerDict(), "user data:", string(rule.UserData()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_DeleteRule(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
rule, err := chain.GetRuleWithUserData([]byte("test"))
|
||||
if err != nil {
|
||||
if err == nftables.ErrRuleNotFound {
|
||||
t.Log("rule not found")
|
||||
return
|
||||
}
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = chain.DeleteRule(rule)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_Flush(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
err := chain.Flush()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
84
internal/firewalls/nftables/conn.go
Normal file
84
internal/firewalls/nftables/conn.go
Normal file
@@ -0,0 +1,84 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
nft "github.com/google/nftables"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
)
|
||||
|
||||
const MaxTableNameLength = 27
|
||||
|
||||
type Conn struct {
|
||||
rawConn *nft.Conn
|
||||
}
|
||||
|
||||
func NewConn() *Conn {
|
||||
return &Conn{
|
||||
rawConn: &nft.Conn{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Conn) Raw() *nft.Conn {
|
||||
return this.rawConn
|
||||
}
|
||||
|
||||
func (this *Conn) GetTable(name string, family TableFamily) (*Table, error) {
|
||||
rawTables, err := this.rawConn.ListTables()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, rawTable := range rawTables {
|
||||
if rawTable.Name == name && rawTable.Family == family {
|
||||
return NewTable(this, rawTable), nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrTableNotFound
|
||||
}
|
||||
|
||||
func (this *Conn) AddTable(name string, family TableFamily) (*Table, error) {
|
||||
if len(name) > MaxTableNameLength {
|
||||
return nil, errors.New("table name too long (max " + types.String(MaxTableNameLength) + ")")
|
||||
}
|
||||
|
||||
var rawTable = this.rawConn.AddTable(&nft.Table{
|
||||
Family: family,
|
||||
Name: name,
|
||||
})
|
||||
|
||||
err := this.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewTable(this, rawTable), nil
|
||||
}
|
||||
|
||||
func (this *Conn) AddIPv4Table(name string) (*Table, error) {
|
||||
return this.AddTable(name, TableFamilyIPv4)
|
||||
}
|
||||
|
||||
func (this *Conn) AddIPv6Table(name string) (*Table, error) {
|
||||
return this.AddTable(name, TableFamilyIPv6)
|
||||
}
|
||||
|
||||
func (this *Conn) DeleteTable(name string, family TableFamily) error {
|
||||
table, err := this.GetTable(name, family)
|
||||
if err != nil {
|
||||
if err == ErrTableNotFound {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
this.rawConn.DelTable(table.Raw())
|
||||
return this.Commit()
|
||||
}
|
||||
|
||||
func (this *Conn) Commit() error {
|
||||
return this.rawConn.Flush()
|
||||
}
|
||||
78
internal/firewalls/nftables/conn_test.go
Normal file
78
internal/firewalls/nftables/conn_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||
"os/exec"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConn_Test(t *testing.T) {
|
||||
_, err := exec.LookPath("nft")
|
||||
if err == nil {
|
||||
t.Log("ok")
|
||||
return
|
||||
}
|
||||
t.Log(err)
|
||||
}
|
||||
|
||||
func TestConn_GetTable_NotFound(t *testing.T) {
|
||||
var conn = nftables.NewConn()
|
||||
|
||||
table, err := conn.GetTable("a", nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
if err == nftables.ErrTableNotFound {
|
||||
t.Log("table not found")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Log("table:", table)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_GetTable(t *testing.T) {
|
||||
var conn = nftables.NewConn()
|
||||
|
||||
table, err := conn.GetTable("myFilter", nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
if err == nftables.ErrTableNotFound {
|
||||
t.Log("table not found")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Log("table:", table)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_AddTable(t *testing.T) {
|
||||
var conn = nftables.NewConn()
|
||||
|
||||
{
|
||||
table, err := conn.AddIPv4Table("test_ipv4")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(table.Name())
|
||||
}
|
||||
{
|
||||
table, err := conn.AddIPv6Table("test_ipv6")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(table.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_DeleteTable(t *testing.T) {
|
||||
var conn = nftables.NewConn()
|
||||
err := conn.DeleteTable("test_ipv4", nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
8
internal/firewalls/nftables/element.go
Normal file
8
internal/firewalls/nftables/element.go
Normal file
@@ -0,0 +1,8 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables
|
||||
|
||||
type Element struct {
|
||||
}
|
||||
19
internal/firewalls/nftables/errors.go
Normal file
19
internal/firewalls/nftables/errors.go
Normal file
@@ -0,0 +1,19 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import "errors"
|
||||
|
||||
var ErrTableNotFound = errors.New("table not found")
|
||||
var ErrChainNotFound = errors.New("chain not found")
|
||||
var ErrSetNotFound = errors.New("set not found")
|
||||
var ErrRuleNotFound = errors.New("rule not found")
|
||||
|
||||
func IsNotFound(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return err == ErrTableNotFound || err == ErrChainNotFound || err == ErrSetNotFound || err == ErrRuleNotFound
|
||||
}
|
||||
18
internal/firewalls/nftables/family.go
Normal file
18
internal/firewalls/nftables/family.go
Normal file
@@ -0,0 +1,18 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
nft "github.com/google/nftables"
|
||||
)
|
||||
|
||||
type TableFamily = nft.TableFamily
|
||||
|
||||
const (
|
||||
TableFamilyINet TableFamily = nft.TableFamilyINet
|
||||
TableFamilyIPv4 TableFamily = nft.TableFamilyIPv4
|
||||
TableFamilyIPv6 TableFamily = nft.TableFamilyIPv6
|
||||
TableFamilyARP TableFamily = nft.TableFamilyARP
|
||||
TableFamilyNetdev TableFamily = nft.TableFamilyNetdev
|
||||
TableFamilyBridge TableFamily = nft.TableFamilyBridge
|
||||
)
|
||||
51
internal/firewalls/nftables/rule.go
Normal file
51
internal/firewalls/nftables/rule.go
Normal file
@@ -0,0 +1,51 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
nft "github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
)
|
||||
|
||||
type Rule struct {
|
||||
rawRule *nft.Rule
|
||||
}
|
||||
|
||||
func NewRule(rawRule *nft.Rule) *Rule {
|
||||
return &Rule{
|
||||
rawRule: rawRule,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Rule) Raw() *nft.Rule {
|
||||
return this.rawRule
|
||||
}
|
||||
|
||||
func (this *Rule) LookupSetName() string {
|
||||
for _, e := range this.rawRule.Exprs {
|
||||
exp, ok := e.(*expr.Lookup)
|
||||
if ok {
|
||||
return exp.SetName
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (this *Rule) VerDict() expr.VerdictKind {
|
||||
for _, e := range this.rawRule.Exprs {
|
||||
exp, ok := e.(*expr.Verdict)
|
||||
if ok {
|
||||
return exp.Kind
|
||||
}
|
||||
}
|
||||
|
||||
return -100
|
||||
}
|
||||
|
||||
func (this *Rule) Handle() uint64 {
|
||||
return this.rawRule.Handle
|
||||
}
|
||||
|
||||
func (this *Rule) UserData() []byte {
|
||||
return this.rawRule.UserData
|
||||
}
|
||||
161
internal/firewalls/nftables/set.go
Normal file
161
internal/firewalls/nftables/set.go
Normal file
@@ -0,0 +1,161 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
nft "github.com/google/nftables"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const MaxSetNameLength = 15
|
||||
|
||||
type SetOptions struct {
|
||||
Id uint32
|
||||
HasTimeout bool
|
||||
Timeout time.Duration
|
||||
KeyType SetDataType
|
||||
DataType SetDataType
|
||||
Constant bool
|
||||
Interval bool
|
||||
Anonymous bool
|
||||
IsMap bool
|
||||
}
|
||||
|
||||
type ElementOptions struct {
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
type Set struct {
|
||||
conn *Conn
|
||||
rawSet *nft.Set
|
||||
batch *SetBatch
|
||||
}
|
||||
|
||||
func NewSet(conn *Conn, rawSet *nft.Set) *Set {
|
||||
return &Set{
|
||||
conn: conn,
|
||||
rawSet: rawSet,
|
||||
batch: &SetBatch{
|
||||
conn: conn,
|
||||
rawSet: rawSet,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Set) Raw() *nft.Set {
|
||||
return this.rawSet
|
||||
}
|
||||
|
||||
func (this *Set) Name() string {
|
||||
return this.rawSet.Name
|
||||
}
|
||||
|
||||
func (this *Set) AddElement(key []byte, options *ElementOptions) error {
|
||||
var rawElement = nft.SetElement{
|
||||
Key: key,
|
||||
}
|
||||
if options != nil {
|
||||
rawElement.Timeout = options.Timeout
|
||||
}
|
||||
err := this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
|
||||
rawElement,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = this.conn.Commit()
|
||||
if err != nil {
|
||||
// retry if exists
|
||||
if strings.Contains(err.Error(), "file exists") {
|
||||
deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
|
||||
{
|
||||
Key: key,
|
||||
},
|
||||
})
|
||||
if deleteErr == nil {
|
||||
err = this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
|
||||
rawElement,
|
||||
})
|
||||
if err == nil {
|
||||
err = this.conn.Commit()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *Set) AddIPElement(ip string, options *ElementOptions) error {
|
||||
var ipObj = net.ParseIP(ip)
|
||||
if ipObj == nil {
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
if utils.IsIPv4(ip) {
|
||||
return this.AddElement(ipObj.To4(), options)
|
||||
} else {
|
||||
return this.AddElement(ipObj.To16(), options)
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Set) DeleteElement(key []byte) error {
|
||||
err := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
|
||||
{
|
||||
Key: key,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = this.conn.Commit()
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "no such file or directory") {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *Set) DeleteIPElement(ip string) error {
|
||||
var ipObj = net.ParseIP(ip)
|
||||
if ipObj == nil {
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
if utils.IsIPv4(ip) {
|
||||
return this.DeleteElement(ipObj.To4())
|
||||
} else {
|
||||
return this.DeleteElement(ipObj.To16())
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Set) Batch() *SetBatch {
|
||||
return this.batch
|
||||
}
|
||||
|
||||
func (this *Set) GetIPElements() ([]string, error) {
|
||||
elements, err := this.conn.Raw().GetSetElements(this.rawSet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result = []string{}
|
||||
for _, element := range elements {
|
||||
result = append(result, net.IP(element.Key).String())
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// not work current time
|
||||
/**func (this *Set) Flush() error {
|
||||
this.conn.Raw().FlushSet(this.rawSet)
|
||||
return this.conn.Commit()
|
||||
}**/
|
||||
36
internal/firewalls/nftables/set_batch.go
Normal file
36
internal/firewalls/nftables/set_batch.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
nft "github.com/google/nftables"
|
||||
)
|
||||
|
||||
type SetBatch struct {
|
||||
conn *Conn
|
||||
rawSet *nft.Set
|
||||
}
|
||||
|
||||
func (this *SetBatch) AddElement(key []byte, options *ElementOptions) error {
|
||||
var rawElement = nft.SetElement{
|
||||
Key: key,
|
||||
}
|
||||
if options != nil {
|
||||
rawElement.Timeout = options.Timeout
|
||||
}
|
||||
return this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
|
||||
rawElement,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *SetBatch) DeleteElement(key []byte) error {
|
||||
return this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
|
||||
{
|
||||
Key: key,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (this *SetBatch) Commit() error {
|
||||
return this.conn.Commit()
|
||||
}
|
||||
57
internal/firewalls/nftables/set_data_type.go
Normal file
57
internal/firewalls/nftables/set_data_type.go
Normal file
@@ -0,0 +1,57 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nftables
|
||||
|
||||
import nft "github.com/google/nftables"
|
||||
|
||||
type SetDataType = nft.SetDatatype
|
||||
|
||||
var (
|
||||
TypeInvalid = nft.TypeInvalid
|
||||
TypeVerdict = nft.TypeVerdict
|
||||
TypeNFProto = nft.TypeNFProto
|
||||
TypeBitmask = nft.TypeBitmask
|
||||
TypeInteger = nft.TypeInteger
|
||||
TypeString = nft.TypeString
|
||||
TypeLLAddr = nft.TypeLLAddr
|
||||
TypeIPAddr = nft.TypeIPAddr
|
||||
TypeIP6Addr = nft.TypeIP6Addr
|
||||
TypeEtherAddr = nft.TypeEtherAddr
|
||||
TypeEtherType = nft.TypeEtherType
|
||||
TypeARPOp = nft.TypeARPOp
|
||||
TypeInetProto = nft.TypeInetProto
|
||||
TypeInetService = nft.TypeInetService
|
||||
TypeICMPType = nft.TypeICMPType
|
||||
TypeTCPFlag = nft.TypeTCPFlag
|
||||
TypeDCCPPktType = nft.TypeDCCPPktType
|
||||
TypeMHType = nft.TypeMHType
|
||||
TypeTime = nft.TypeTime
|
||||
TypeMark = nft.TypeMark
|
||||
TypeIFIndex = nft.TypeIFIndex
|
||||
TypeARPHRD = nft.TypeARPHRD
|
||||
TypeRealm = nft.TypeRealm
|
||||
TypeClassID = nft.TypeClassID
|
||||
TypeUID = nft.TypeUID
|
||||
TypeGID = nft.TypeGID
|
||||
TypeCTState = nft.TypeCTState
|
||||
TypeCTDir = nft.TypeCTDir
|
||||
TypeCTStatus = nft.TypeCTStatus
|
||||
TypeICMP6Type = nft.TypeICMP6Type
|
||||
TypeCTLabel = nft.TypeCTLabel
|
||||
TypePktType = nft.TypePktType
|
||||
TypeICMPCode = nft.TypeICMPCode
|
||||
TypeICMPV6Code = nft.TypeICMPV6Code
|
||||
TypeICMPXCode = nft.TypeICMPXCode
|
||||
TypeDevGroup = nft.TypeDevGroup
|
||||
TypeDSCP = nft.TypeDSCP
|
||||
TypeECN = nft.TypeECN
|
||||
TypeFIBAddr = nft.TypeFIBAddr
|
||||
TypeBoolean = nft.TypeBoolean
|
||||
TypeCTEventBit = nft.TypeCTEventBit
|
||||
TypeIFName = nft.TypeIFName
|
||||
TypeIGMPType = nft.TypeIGMPType
|
||||
TypeTimeDate = nft.TypeTimeDate
|
||||
TypeTimeHour = nft.TypeTimeHour
|
||||
TypeTimeDay = nft.TypeTimeDay
|
||||
TypeCGroupV2 = nft.TypeCGroupV2
|
||||
)
|
||||
110
internal/firewalls/nftables/set_test.go
Normal file
110
internal/firewalls/nftables/set_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nftables_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/mdlayher/netlink"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func getIPv4Set(t *testing.T) *nftables.Set {
|
||||
var table = getIPv4Table(t)
|
||||
set, err := table.GetSet("test_ipv4_set")
|
||||
if err != nil {
|
||||
if err == nftables.ErrSetNotFound {
|
||||
set, err = table.AddSet("test_ipv4_set", &nftables.SetOptions{
|
||||
KeyType: nftables.TypeIPAddr,
|
||||
HasTimeout: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
func TestSet_AddElement(t *testing.T) {
|
||||
var set = getIPv4Set(t)
|
||||
err := set.AddElement(net.ParseIP("192.168.2.31").To4(), &nftables.ElementOptions{Timeout: 86400 * time.Second})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestSet_DeleteElement(t *testing.T) {
|
||||
var set = getIPv4Set(t)
|
||||
err := set.DeleteElement(net.ParseIP("192.168.2.31").To4())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestSet_Batch(t *testing.T) {
|
||||
var batch = getIPv4Set(t).Batch()
|
||||
|
||||
for _, ip := range []string{"192.168.2.30", "192.168.2.31", "192.168.2.32", "192.168.2.33", "192.168.2.34"} {
|
||||
var ipData = net.ParseIP(ip).To4()
|
||||
//err := batch.DeleteElement(ipData)
|
||||
//if err != nil {
|
||||
// t.Fatal(err)
|
||||
//}
|
||||
err := batch.AddElement(ipData, &nftables.ElementOptions{Timeout: 10 * time.Second})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := batch.Commit()
|
||||
if err != nil {
|
||||
t.Logf("%#v", errors.Unwrap(err).(*netlink.OpError))
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestSet_Add_Many(t *testing.T) {
|
||||
var set = getIPv4Set(t)
|
||||
|
||||
for i := 0; i < 255; i++ {
|
||||
t.Log(i)
|
||||
for j := 0; j < 255; j++ {
|
||||
var ip = "192.167." + types.String(i) + "." + types.String(j)
|
||||
var ipData = net.ParseIP(ip).To4()
|
||||
err := set.Batch().AddElement(ipData, &nftables.ElementOptions{Timeout: 3600 * time.Second})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if j%10 == 0 {
|
||||
err = set.Batch().Commit()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
err := set.Batch().Commit()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
/**func TestSet_Flush(t *testing.T) {
|
||||
var set = getIPv4Set(t)
|
||||
err := set.Flush()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}**/
|
||||
157
internal/firewalls/nftables/table.go
Normal file
157
internal/firewalls/nftables/table.go
Normal file
@@ -0,0 +1,157 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
nft "github.com/google/nftables"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Table struct {
|
||||
conn *Conn
|
||||
rawTable *nft.Table
|
||||
}
|
||||
|
||||
func NewTable(conn *Conn, rawTable *nft.Table) *Table {
|
||||
return &Table{
|
||||
conn: conn,
|
||||
rawTable: rawTable,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Table) Raw() *nft.Table {
|
||||
return this.rawTable
|
||||
}
|
||||
|
||||
func (this *Table) Name() string {
|
||||
return this.rawTable.Name
|
||||
}
|
||||
|
||||
func (this *Table) Family() TableFamily {
|
||||
return this.rawTable.Family
|
||||
}
|
||||
|
||||
func (this *Table) GetChain(name string) (*Chain, error) {
|
||||
rawChains, err := this.conn.Raw().ListChains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, rawChain := range rawChains {
|
||||
// must compare table name
|
||||
if rawChain.Name == name && rawChain.Table.Name == this.rawTable.Name {
|
||||
return NewChain(this.conn, this.rawTable, rawChain), nil
|
||||
}
|
||||
}
|
||||
return nil, ErrChainNotFound
|
||||
}
|
||||
|
||||
func (this *Table) AddChain(name string, chainPolicy *ChainPolicy) (*Chain, error) {
|
||||
if len(name) > MaxChainNameLength {
|
||||
return nil, errors.New("chain name too long (max " + types.String(MaxChainNameLength) + ")")
|
||||
}
|
||||
|
||||
var rawChain = this.conn.Raw().AddChain(&nft.Chain{
|
||||
Name: name,
|
||||
Table: this.rawTable,
|
||||
Hooknum: nft.ChainHookInput,
|
||||
Priority: nft.ChainPriorityFilter,
|
||||
Type: nft.ChainTypeFilter,
|
||||
Policy: chainPolicy,
|
||||
})
|
||||
|
||||
err := this.conn.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewChain(this.conn, this.rawTable, rawChain), nil
|
||||
}
|
||||
|
||||
func (this *Table) AddAcceptChain(name string) (*Chain, error) {
|
||||
var policy = ChainPolicyAccept
|
||||
return this.AddChain(name, &policy)
|
||||
}
|
||||
|
||||
func (this *Table) AddDropChain(name string) (*Chain, error) {
|
||||
var policy = ChainPolicyDrop
|
||||
return this.AddChain(name, &policy)
|
||||
}
|
||||
|
||||
func (this *Table) DeleteChain(name string) error {
|
||||
chain, err := this.GetChain(name)
|
||||
if err != nil {
|
||||
if err == ErrChainNotFound {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
this.conn.Raw().DelChain(chain.Raw())
|
||||
return this.conn.Commit()
|
||||
}
|
||||
|
||||
func (this *Table) GetSet(name string) (*Set, error) {
|
||||
rawSet, err := this.conn.Raw().GetSetByName(this.rawTable, name)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "no such file or directory") {
|
||||
return nil, ErrSetNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewSet(this.conn, rawSet), nil
|
||||
}
|
||||
|
||||
func (this *Table) AddSet(name string, options *SetOptions) (*Set, error) {
|
||||
if len(name) > MaxSetNameLength {
|
||||
return nil, errors.New("set name too long (max " + types.String(MaxSetNameLength) + ")")
|
||||
}
|
||||
|
||||
if options == nil {
|
||||
options = &SetOptions{}
|
||||
}
|
||||
var rawSet = &nft.Set{
|
||||
Table: this.rawTable,
|
||||
ID: options.Id,
|
||||
Name: name,
|
||||
Anonymous: options.Anonymous,
|
||||
Constant: options.Constant,
|
||||
Interval: options.Interval,
|
||||
IsMap: options.IsMap,
|
||||
HasTimeout: options.HasTimeout,
|
||||
Timeout: options.Timeout,
|
||||
KeyType: options.KeyType,
|
||||
DataType: options.DataType,
|
||||
}
|
||||
err := this.conn.Raw().AddSet(rawSet, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = this.conn.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewSet(this.conn, rawSet), nil
|
||||
}
|
||||
|
||||
func (this *Table) DeleteSet(name string) error {
|
||||
set, err := this.GetSet(name)
|
||||
if err != nil {
|
||||
if err == ErrSetNotFound {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
this.conn.Raw().DelSet(set.Raw())
|
||||
return this.conn.Commit()
|
||||
}
|
||||
|
||||
func (this *Table) Flush() error {
|
||||
this.conn.Raw().FlushTable(this.rawTable)
|
||||
return this.conn.Commit()
|
||||
}
|
||||
140
internal/firewalls/nftables/table_test.go
Normal file
140
internal/firewalls/nftables/table_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func getIPv4Table(t *testing.T) *nftables.Table {
|
||||
var conn = nftables.NewConn()
|
||||
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
if err == nftables.ErrTableNotFound {
|
||||
table, err = conn.AddIPv4Table("test_ipv4")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return table
|
||||
}
|
||||
|
||||
func TestTable_AddChain(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
|
||||
{
|
||||
chain, err := table.AddChain("test_default_chain", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("created:", chain.Name())
|
||||
}
|
||||
|
||||
{
|
||||
chain, err := table.AddAcceptChain("test_accept_chain")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("created:", chain.Name())
|
||||
}
|
||||
|
||||
// Do not test drop chain before adding accept rule, you will drop yourself!!!!!!!
|
||||
/**{
|
||||
chain, err := table.AddDropChain("test_drop_chain")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("created:", chain.Name())
|
||||
}**/
|
||||
}
|
||||
|
||||
func TestTable_GetChain(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
for _, chainName := range []string{"not_found_chain", "test_default_chain"} {
|
||||
chain, err := table.GetChain(chainName)
|
||||
if err != nil {
|
||||
if err == nftables.ErrChainNotFound {
|
||||
t.Log(chainName, ":", "not found")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Log(chainName, ":", chain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTable_DeleteChain(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
err := table.DeleteChain("test_default_chain")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestTable_AddSet(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
{
|
||||
set, err := table.AddSet("ipv4_black_set", &nftables.SetOptions{
|
||||
HasTimeout: false,
|
||||
KeyType: nftables.TypeIPAddr,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(set.Name())
|
||||
}
|
||||
|
||||
{
|
||||
set, err := table.AddSet("ipv6_black_set", &nftables.SetOptions{
|
||||
HasTimeout: true,
|
||||
//Timeout: 3600 * time.Second,
|
||||
KeyType: nftables.TypeIP6Addr,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(set.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTable_GetSet(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
for _, setName := range []string{"not_found_set", "ipv4_black_set"} {
|
||||
set, err := table.GetSet(setName)
|
||||
if err != nil {
|
||||
if err == nftables.ErrSetNotFound {
|
||||
t.Log(setName, ": not found")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Log(setName, ":", set)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTable_DeleteSet(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
err := table.DeleteSet("ipv4_black_set")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestTable_Flush(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
err := table.Flush()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
@@ -48,7 +48,7 @@ func (this *IPListDB) init() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
remotelogs.Println("CACHE", "create cache dir '"+this.dir+"'")
|
||||
remotelogs.Println("IP_LIST_DB", "create data dir '"+this.dir+"'")
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", "file:"+this.dir+"/ip_list.db?cache=shared&mode=rwc&_journal_mode=WAL&_sync=OFF")
|
||||
|
||||
@@ -3,17 +3,29 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
)
|
||||
|
||||
// AllowIP 检查IP是否被允许访问
|
||||
// 如果一个IP不在任何名单中,则允许访问
|
||||
func AllowIP(ip string, serverId int64) (canGoNext bool, inAllowList bool) {
|
||||
// 放行lo
|
||||
if ip == "127.0.0.1" {
|
||||
return true, true
|
||||
}
|
||||
|
||||
var ipLong = utils.IP2Long(ip)
|
||||
if ipLong == 0 {
|
||||
return false, false
|
||||
}
|
||||
|
||||
// check node
|
||||
nodeConfig, err := nodeconfigs.SharedNodeConfig()
|
||||
if err == nil && nodeConfig.IPIsAutoAllowed(ip) {
|
||||
return true, true
|
||||
}
|
||||
|
||||
// check white lists
|
||||
if GlobalWhiteIPList.Contains(ipLong) {
|
||||
return true, true
|
||||
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const MaxQueueSize = 10240
|
||||
const MaxQueueSize = 2048
|
||||
|
||||
// Task 单个指标任务
|
||||
// 数据库存储:
|
||||
@@ -58,7 +58,7 @@ type Task struct {
|
||||
timeMap map[string]zero.Zero // time => bool
|
||||
serverIdMapLocker sync.Mutex
|
||||
|
||||
statsMap map[string]*Stat
|
||||
statsMap map[string]*Stat // 待写入队列,hash => *Stat
|
||||
statsLocker sync.Mutex
|
||||
statsTicker *utils.Ticker
|
||||
}
|
||||
|
||||
@@ -1,32 +1,30 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/messageconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/compressions"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -111,14 +109,12 @@ func (this *APIStream) loop() error {
|
||||
err = this.handleStatCache(message)
|
||||
case messageconfigs.MessageCodeCleanCache: // 清理缓存
|
||||
err = this.handleCleanCache(message)
|
||||
case messageconfigs.MessageCodePurgeCache: // 删除缓存
|
||||
err = this.handlePurgeCache(message)
|
||||
case messageconfigs.MessageCodePreheatCache: // 预热缓存
|
||||
err = this.handlePreheatCache(message)
|
||||
case messageconfigs.MessageCodeNewNodeTask: // 有新的任务
|
||||
err = this.handleNewNodeTask(message)
|
||||
case messageconfigs.MessageCodeCheckSystemdService: // 检查Systemd服务
|
||||
err = this.handleCheckSystemdService(message)
|
||||
case messageconfigs.MessageCodeCheckLocalFirewall: // 检查本地防火墙
|
||||
err = this.handleCheckLocalFirewall(message)
|
||||
case messageconfigs.MessageCodeChangeAPINode: // 修改API节点地址
|
||||
err = this.handleChangeAPINode(message)
|
||||
default:
|
||||
@@ -328,224 +324,6 @@ func (this *APIStream) handleCleanCache(message *pb.NodeStreamMessage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 删除缓存
|
||||
func (this *APIStream) handlePurgeCache(message *pb.NodeStreamMessage) error {
|
||||
msg := &messageconfigs.PurgeCacheMessage{}
|
||||
err := json.Unmarshal(message.DataJSON, msg)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if shouldStop {
|
||||
defer func() {
|
||||
storage.Stop()
|
||||
}()
|
||||
}
|
||||
|
||||
// WEBP缓存
|
||||
if msg.Type == "file" {
|
||||
var keys = msg.Keys
|
||||
for _, key := range keys {
|
||||
keys = append(keys,
|
||||
key+caches.SuffixMethod+"HEAD",
|
||||
key+caches.SuffixWebP,
|
||||
key+caches.SuffixPartial,
|
||||
)
|
||||
// TODO 根据实际缓存的内容进行组合
|
||||
for _, encoding := range compressions.AllEncodings() {
|
||||
keys = append(keys, key+caches.SuffixCompression+encoding)
|
||||
keys = append(keys, key+caches.SuffixWebP+caches.SuffixCompression+encoding)
|
||||
}
|
||||
}
|
||||
msg.Keys = keys
|
||||
}
|
||||
|
||||
err = storage.Purge(msg.Keys, msg.Type)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "purge keys failed: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
this.replyOk(message.RequestId, "ok")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 预热缓存
|
||||
func (this *APIStream) handlePreheatCache(message *pb.NodeStreamMessage) error {
|
||||
msg := &messageconfigs.PreheatCacheMessage{}
|
||||
err := json.Unmarshal(message.DataJSON, msg)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if shouldStop {
|
||||
defer func() {
|
||||
storage.Stop()
|
||||
}()
|
||||
}
|
||||
|
||||
if len(msg.Keys) == 0 {
|
||||
this.replyOk(message.RequestId, "ok")
|
||||
return nil
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(len(msg.Keys))
|
||||
client := &http.Client{
|
||||
Timeout: 30 * time.Second, // TODO 可以设置请求超时时间
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
_, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.Dial(network, "127.0.0.1:"+port)
|
||||
},
|
||||
MaxIdleConns: 4096,
|
||||
MaxIdleConnsPerHost: 32,
|
||||
MaxConnsPerHost: 32,
|
||||
IdleConnTimeout: 2 * time.Minute,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
TLSHandshakeTimeout: 0,
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
defer client.CloseIdleConnections()
|
||||
errorMessages := []string{}
|
||||
locker := sync.Mutex{}
|
||||
for _, key := range msg.Keys {
|
||||
go func(key string) {
|
||||
defer wg.Done()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, key, nil)
|
||||
if err != nil {
|
||||
locker.Lock()
|
||||
errorMessages = append(errorMessages, "invalid url: "+key+": "+err.Error())
|
||||
locker.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// TODO 可以在管理界面自定义Header
|
||||
req.Header.Set("X-Cache-Action", "preheat")
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/85.0.4183.121 Safari/537.36")
|
||||
req.Header.Set("Accept-Encoding", "gzip, deflate, br") // TODO 这里需要记录下缓存是否为gzip的
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
locker.Lock()
|
||||
errorMessages = append(errorMessages, "request failed: "+key+": "+err.Error())
|
||||
locker.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
locker.Lock()
|
||||
errorMessages = append(errorMessages, "request failed: "+key+": status code '"+strconv.Itoa(resp.StatusCode)+"'")
|
||||
locker.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
// 检查最大内容长度
|
||||
// TODO 需要解决Chunked Transfer Encoding的长度判断问题
|
||||
maxSize := storage.Policy().MaxSizeBytes()
|
||||
if maxSize > 0 && resp.ContentLength > maxSize {
|
||||
locker.Lock()
|
||||
errorMessages = append(errorMessages, "request failed: the content is too larger than policy setting")
|
||||
locker.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
expiredAt := time.Now().Unix() + 8600
|
||||
writer, err := storage.OpenWriter(key, expiredAt, 200, resp.ContentLength, -1, false) // TODO 可以设置缓存过期时间
|
||||
if err != nil {
|
||||
locker.Lock()
|
||||
errorMessages = append(errorMessages, "open cache writer failed: "+key+": "+err.Error())
|
||||
locker.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
buf := make([]byte, 16*1024)
|
||||
isClosed := false
|
||||
|
||||
// 写入Header
|
||||
for k, v := range resp.Header {
|
||||
for _, v1 := range v {
|
||||
_, err = writer.WriteHeader([]byte(k + ":" + v1 + "\n"))
|
||||
if err != nil {
|
||||
locker.Lock()
|
||||
errorMessages = append(errorMessages, "write failed: "+key+": "+err.Error())
|
||||
locker.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 写入Body
|
||||
for {
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
_, writerErr := writer.Write(buf[:n])
|
||||
if writerErr != nil {
|
||||
locker.Lock()
|
||||
errorMessages = append(errorMessages, "write failed: "+key+": "+writerErr.Error())
|
||||
locker.Unlock()
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
|
||||
err = writer.Close()
|
||||
if err == nil {
|
||||
storage.AddToList(&caches.Item{
|
||||
Type: writer.ItemType(),
|
||||
Key: key,
|
||||
ExpiredAt: expiredAt,
|
||||
})
|
||||
}
|
||||
isClosed = true
|
||||
} else {
|
||||
locker.Lock()
|
||||
errorMessages = append(errorMessages, "read url failed: "+key+": "+err.Error())
|
||||
locker.Unlock()
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isClosed {
|
||||
_ = writer.Close()
|
||||
}
|
||||
}(key)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if len(errorMessages) > 0 {
|
||||
this.replyFail(message.RequestId, strings.Join(errorMessages, ", "))
|
||||
return nil
|
||||
}
|
||||
|
||||
this.replyOk(message.RequestId, "ok")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理配置变化
|
||||
func (this *APIStream) handleNewNodeTask(message *pb.NodeStreamMessage) error {
|
||||
select {
|
||||
@@ -569,7 +347,7 @@ func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage)
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := utils.NewCommandExecutor()
|
||||
var cmd = utils.NewCommandExecutor()
|
||||
shortName := teaconst.SystemdServiceName
|
||||
cmd.Add(systemctl, "is-enabled", shortName)
|
||||
output, err := cmd.Run()
|
||||
@@ -585,6 +363,63 @@ func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查本地防火墙
|
||||
func (this *APIStream) handleCheckLocalFirewall(message *pb.NodeStreamMessage) error {
|
||||
var dataMessage = &messageconfigs.CheckLocalFirewallMessage{}
|
||||
err := json.Unmarshal(message.DataJSON, dataMessage)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
// nft
|
||||
if dataMessage.Name == "nftables" {
|
||||
if runtime.GOOS != "linux" {
|
||||
this.replyFail(message.RequestId, "not Linux system")
|
||||
return nil
|
||||
}
|
||||
|
||||
nft, err := exec.LookPath("nft")
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "'nft' not found: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
var cmd = exec.Command(nft, "--version")
|
||||
var output = &bytes.Buffer{}
|
||||
cmd.Stdout = output
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "get version failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
var outputString = output.String()
|
||||
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
|
||||
if len(versionMatches) <= 1 {
|
||||
this.replyFail(message.RequestId, "can not get nft version")
|
||||
return nil
|
||||
}
|
||||
var version = versionMatches[1]
|
||||
|
||||
var result = maps.Map{
|
||||
"version": version,
|
||||
}
|
||||
|
||||
var protectionConfig = sharedNodeConfig.DDOSProtection
|
||||
err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, dataMessage.Name+"was installed, but apply DDoS protection config failed: "+err.Error())
|
||||
} else {
|
||||
this.replyOk(message.RequestId, string(result.AsJSON()))
|
||||
}
|
||||
} else {
|
||||
this.replyFail(message.RequestId, "invalid firewall name '"+dataMessage.Name+"'")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 修改API地址
|
||||
func (this *APIStream) handleChangeAPINode(message *pb.NodeStreamMessage) error {
|
||||
config, err := configs.LoadAPIConfig()
|
||||
@@ -660,6 +495,11 @@ func (this *APIStream) replyOk(requestId int64, message string) {
|
||||
_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message})
|
||||
}
|
||||
|
||||
// 回复成功并包含数据
|
||||
func (this *APIStream) replyOkData(requestId int64, message string, dataJSON []byte) {
|
||||
_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message, DataJSON: dataJSON})
|
||||
}
|
||||
|
||||
// 获取缓存存取对象
|
||||
func (this *APIStream) cacheStorage(message *pb.NodeStreamMessage, cachePolicyJSON []byte) (storage caches.StorageInterface, shouldStop bool, err error) {
|
||||
cachePolicy := &serverconfigs.HTTPCachePolicy{}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
@@ -21,8 +20,7 @@ import (
|
||||
|
||||
// ClientConn 客户端连接
|
||||
type ClientConn struct {
|
||||
once sync.Once
|
||||
globalLimiter *ratelimit.Counter
|
||||
once sync.Once
|
||||
|
||||
isTLS bool
|
||||
hasDeadline bool
|
||||
@@ -33,7 +31,7 @@ type ClientConn struct {
|
||||
BaseClientConn
|
||||
}
|
||||
|
||||
func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ratelimit.Counter) net.Conn {
|
||||
func NewClientConn(conn net.Conn, isTLS bool, quickClose bool) net.Conn {
|
||||
if quickClose {
|
||||
// TCP
|
||||
tcpConn, ok := conn.(*net.TCPConn)
|
||||
@@ -43,7 +41,7 @@ func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ra
|
||||
}
|
||||
}
|
||||
|
||||
return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, isTLS: isTLS, globalLimiter: globalLimiter}
|
||||
return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, isTLS: isTLS}
|
||||
}
|
||||
|
||||
func (this *ClientConn) Read(b []byte) (n int, err error) {
|
||||
@@ -96,13 +94,6 @@ func (this *ClientConn) Close() error {
|
||||
|
||||
err := this.rawConn.Close()
|
||||
|
||||
// 全局并发数限制
|
||||
this.once.Do(func() {
|
||||
if this.globalLimiter != nil {
|
||||
this.globalLimiter.Release()
|
||||
}
|
||||
})
|
||||
|
||||
// 单个服务并发数限制
|
||||
sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String())
|
||||
|
||||
@@ -137,7 +128,7 @@ func (this *ClientConn) increaseSYNFlood(synFloodConfig *firewallconfigs.SYNFloo
|
||||
var ip = this.RawIP()
|
||||
if len(ip) > 0 && !iplibrary.IsInWhiteList(ip) && (!synFloodConfig.IgnoreLocal || !utils.IsLocalIP(ip)) {
|
||||
var timestamp = utils.NextMinuteUnixTime()
|
||||
var result = ttlcache.SharedCache.IncreaseInt64("SYN_FLOOD:"+ip, 1, timestamp)
|
||||
var result = ttlcache.SharedCache.IncreaseInt64("SYN_FLOOD:"+ip, 1, timestamp, true)
|
||||
var minAttempts = synFloodConfig.MinAttempts
|
||||
if minAttempts < 5 {
|
||||
minAttempts = 5
|
||||
|
||||
@@ -3,16 +3,13 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"net"
|
||||
)
|
||||
|
||||
var sharedConnectionsLimiter = ratelimit.NewCounter(nodeconfigs.DefaultTCPMaxConnections)
|
||||
|
||||
// ClientListener 客户端网络监听
|
||||
type ClientListener struct {
|
||||
rawListener net.Listener
|
||||
@@ -36,13 +33,8 @@ func (this *ClientListener) IsTLS() bool {
|
||||
}
|
||||
|
||||
func (this *ClientListener) Accept() (net.Conn, error) {
|
||||
// 限制并发连接数
|
||||
var limiter = sharedConnectionsLimiter
|
||||
limiter.Ack()
|
||||
|
||||
conn, err := this.rawListener.Accept()
|
||||
if err != nil {
|
||||
limiter.Release()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -50,22 +42,30 @@ func (this *ClientListener) Accept() (net.Conn, error) {
|
||||
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
if err == nil {
|
||||
canGoNext, _ := iplibrary.AllowIP(ip, 0)
|
||||
var beingDenied = !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) &&
|
||||
waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
|
||||
|
||||
if !canGoNext ||
|
||||
(!waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) &&
|
||||
waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)) {
|
||||
if !canGoNext || beingDenied {
|
||||
tcpConn, ok := conn.(*net.TCPConn)
|
||||
if ok {
|
||||
_ = tcpConn.SetLinger(0)
|
||||
}
|
||||
|
||||
_ = conn.Close()
|
||||
limiter.Release()
|
||||
|
||||
// 使用本地防火墙延长封禁
|
||||
if beingDenied {
|
||||
var fw = firewalls.Firewall()
|
||||
if fw != nil && !fw.IsMock() {
|
||||
_ = fw.DropSourceIP(ip, 60)
|
||||
}
|
||||
}
|
||||
|
||||
return this.Accept()
|
||||
}
|
||||
}
|
||||
|
||||
return NewClientConn(conn, this.isTLS, this.quickClose, limiter), nil
|
||||
return NewClientConn(conn, this.isTLS, this.quickClose), nil
|
||||
}
|
||||
|
||||
func (this *ClientListener) Close() error {
|
||||
|
||||
7
internal/nodes/conn_linger.go
Normal file
7
internal/nodes/conn_linger.go
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
type LingerConn interface {
|
||||
SetLinger(sec int) error
|
||||
}
|
||||
@@ -80,6 +80,13 @@ Loop:
|
||||
return nil
|
||||
}
|
||||
|
||||
// 发送到本地
|
||||
if sharedHTTPAccessLogViewer.HasConns() {
|
||||
for _, accessLog := range accessLogs {
|
||||
sharedHTTPAccessLogViewer.Send(accessLog)
|
||||
}
|
||||
}
|
||||
|
||||
// 发送到API
|
||||
if this.rpcClient == nil {
|
||||
client, err := rpc.SharedRPC()
|
||||
|
||||
116
internal/nodes/http_access_log_viewer.go
Normal file
116
internal/nodes/http_access_log_viewer.go
Normal file
@@ -0,0 +1,116 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var sharedHTTPAccessLogViewer = NewHTTPAccessLogViewer()
|
||||
|
||||
// HTTPAccessLogViewer 本地访问日志浏览器
|
||||
type HTTPAccessLogViewer struct {
|
||||
sockFile string
|
||||
|
||||
listener net.Listener
|
||||
connMap map[int64]net.Conn // connId => net.Conn
|
||||
connId int64
|
||||
locker sync.Mutex
|
||||
}
|
||||
|
||||
// NewHTTPAccessLogViewer 获取新对象
|
||||
func NewHTTPAccessLogViewer() *HTTPAccessLogViewer {
|
||||
return &HTTPAccessLogViewer{
|
||||
sockFile: os.TempDir() + "/" + teaconst.AccessLogSockName,
|
||||
connMap: map[int64]net.Conn{},
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动
|
||||
func (this *HTTPAccessLogViewer) Start() error {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
if this.listener == nil {
|
||||
// remove if exists
|
||||
_ = os.Remove(this.sockFile)
|
||||
|
||||
// start listening
|
||||
listener, err := net.Listen("unix", this.sockFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.listener = listener
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := this.listener.Accept()
|
||||
if err != nil {
|
||||
remotelogs.Error("ACCESS_LOG", "start local reading failed: "+err.Error())
|
||||
break
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
var connId = this.nextConnId()
|
||||
this.connMap[connId] = conn
|
||||
go func() {
|
||||
this.startReading(conn, connId)
|
||||
}()
|
||||
this.locker.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasConns 检查是否有连接
|
||||
func (this *HTTPAccessLogViewer) HasConns() bool {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
return len(this.connMap) > 0
|
||||
}
|
||||
|
||||
// Send 发送日志
|
||||
func (this *HTTPAccessLogViewer) Send(accessLog *pb.HTTPAccessLog) {
|
||||
var conns = []net.Conn{}
|
||||
this.locker.Lock()
|
||||
for _, conn := range this.connMap {
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
this.locker.Unlock()
|
||||
|
||||
if len(conns) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, conn := range conns {
|
||||
// ignore error
|
||||
_, _ = conn.Write([]byte(accessLog.RemoteAddr + " [" + accessLog.TimeLocal + "] \"" + accessLog.RequestMethod + " " + accessLog.Scheme + "://" + accessLog.Host + accessLog.RequestURI + " " + accessLog.Proto + "\" " + types.String(accessLog.Status) + " - " + fmt.Sprintf("%.2fms", accessLog.RequestTime*1000) + "\n"))
|
||||
}
|
||||
}
|
||||
|
||||
func (this *HTTPAccessLogViewer) nextConnId() int64 {
|
||||
return atomic.AddInt64(&this.connId, 1)
|
||||
}
|
||||
|
||||
func (this *HTTPAccessLogViewer) startReading(conn net.Conn, connId int64) {
|
||||
var buf = make([]byte, 1024)
|
||||
for {
|
||||
_, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
this.locker.Lock()
|
||||
delete(this.connMap, connId)
|
||||
this.locker.Unlock()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
247
internal/nodes/http_cache_task_manager.go
Normal file
247
internal/nodes/http_cache_task_manager.go
Normal file
@@ -0,0 +1,247 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/compressions"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
events.On(events.EventStart, func() {
|
||||
goman.New(func() {
|
||||
SharedHTTPCacheTaskManager.Start()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
var SharedHTTPCacheTaskManager = NewHTTPCacheTaskManager()
|
||||
|
||||
// HTTPCacheTaskManager 缓存任务管理
|
||||
type HTTPCacheTaskManager struct {
|
||||
ticker *time.Ticker
|
||||
httpClient *http.Client
|
||||
protocolReg *regexp.Regexp
|
||||
|
||||
taskQueue chan *pb.PurgeServerCacheRequest
|
||||
}
|
||||
|
||||
func NewHTTPCacheTaskManager() *HTTPCacheTaskManager {
|
||||
var duration = 30 * time.Second
|
||||
if Tea.IsTesting() {
|
||||
duration = 10 * time.Second
|
||||
}
|
||||
return &HTTPCacheTaskManager{
|
||||
ticker: time.NewTicker(duration),
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Minute, // TODO 可以设置请求超时时间
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
_, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.Dial(network, "127.0.0.1:"+port)
|
||||
},
|
||||
MaxIdleConns: 128,
|
||||
MaxIdleConnsPerHost: 32,
|
||||
MaxConnsPerHost: 32,
|
||||
IdleConnTimeout: 2 * time.Minute,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
TLSHandshakeTimeout: 0,
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
protocolReg: regexp.MustCompile(`^(?i)(http|https)://`),
|
||||
taskQueue: make(chan *pb.PurgeServerCacheRequest, 1024),
|
||||
}
|
||||
}
|
||||
|
||||
func (this *HTTPCacheTaskManager) Start() {
|
||||
// task queue
|
||||
goman.New(func() {
|
||||
rpcClient, _ := rpc.SharedRPC()
|
||||
|
||||
if rpcClient != nil {
|
||||
for taskReq := range this.taskQueue {
|
||||
_, err := rpcClient.ServerRPC().PurgeServerCache(rpcClient.Context(), taskReq)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_CACHE_TASK_MANAGER", "create purge task failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Loop
|
||||
for range this.ticker.C {
|
||||
err := this.Loop()
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_CACHE_TASK_MANAGER", "execute task failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *HTTPCacheTaskManager) Loop() error {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := rpcClient.HTTPCacheTaskKeyRPC().FindDoingHTTPCacheTaskKeys(rpcClient.Context(), &pb.FindDoingHTTPCacheTaskKeysRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var keys = resp.HttpCacheTaskKeys
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var pbResults = []*pb.UpdateHTTPCacheTaskKeysStatusRequest_KeyResult{}
|
||||
|
||||
for _, key := range keys {
|
||||
err = this.processKey(key)
|
||||
|
||||
var pbResult = &pb.UpdateHTTPCacheTaskKeysStatusRequest_KeyResult{
|
||||
Id: key.Id,
|
||||
NodeClusterId: key.NodeClusterId,
|
||||
Error: "",
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
pbResult.Error = err.Error()
|
||||
}
|
||||
pbResults = append(pbResults, pbResult)
|
||||
}
|
||||
|
||||
_, err = rpcClient.HTTPCacheTaskKeyRPC().UpdateHTTPCacheTaskKeysStatus(rpcClient.Context(), &pb.UpdateHTTPCacheTaskKeysStatusRequest{KeyResults: pbResults})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *HTTPCacheTaskManager) PushTaskKeys(keys []string) {
|
||||
select {
|
||||
case this.taskQueue <- &pb.PurgeServerCacheRequest{
|
||||
Keys: keys,
|
||||
Prefixes: nil,
|
||||
}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (this *HTTPCacheTaskManager) processKey(key *pb.HTTPCacheTaskKey) error {
|
||||
switch key.Type {
|
||||
case "purge":
|
||||
var storages = caches.SharedManager.FindAllStorages()
|
||||
for _, storage := range storages {
|
||||
switch key.KeyType {
|
||||
case "key":
|
||||
var cacheKeys = []string{key.Key}
|
||||
if strings.HasPrefix(key.Key, "http://") {
|
||||
cacheKeys = append(cacheKeys, strings.Replace(key.Key, "http://", "https://", 1))
|
||||
} else if strings.HasPrefix(key.Key, "https://") {
|
||||
cacheKeys = append(cacheKeys, strings.Replace(key.Key, "https://", "http://", 1))
|
||||
}
|
||||
|
||||
// TODO 提升效率
|
||||
for _, cacheKey := range cacheKeys {
|
||||
var subKeys = []string{
|
||||
cacheKey,
|
||||
cacheKey + caches.SuffixMethod + "HEAD",
|
||||
cacheKey + caches.SuffixWebP,
|
||||
cacheKey + caches.SuffixPartial,
|
||||
}
|
||||
// TODO 根据实际缓存的内容进行组合
|
||||
for _, encoding := range compressions.AllEncodings() {
|
||||
subKeys = append(subKeys, cacheKey+caches.SuffixCompression+encoding)
|
||||
subKeys = append(subKeys, cacheKey+caches.SuffixWebP+caches.SuffixCompression+encoding)
|
||||
}
|
||||
|
||||
err := storage.Purge(subKeys, "file")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case "prefix":
|
||||
var prefixes = []string{key.Key}
|
||||
if strings.HasPrefix(key.Key, "http://") {
|
||||
prefixes = append(prefixes, strings.Replace(key.Key, "http://", "https://", 1))
|
||||
} else if strings.HasPrefix(key.Key, "https://") {
|
||||
prefixes = append(prefixes, strings.Replace(key.Key, "https://", "http://", 1))
|
||||
}
|
||||
|
||||
err := storage.Purge(prefixes, "dir")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
case "fetch":
|
||||
err := this.fetchKey(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return errors.New("invalid operation type '" + key.Type + "'")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO 增加失败重试
|
||||
// TODO 使用并发操作
|
||||
func (this *HTTPCacheTaskManager) fetchKey(key *pb.HTTPCacheTaskKey) error {
|
||||
var fullKey = key.Key
|
||||
if !this.protocolReg.MatchString(fullKey) {
|
||||
fullKey = "https://" + fullKey
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, fullKey, nil)
|
||||
if err != nil {
|
||||
return errors.New("invalid url: " + fullKey + ": " + err.Error())
|
||||
}
|
||||
|
||||
// TODO 可以在管理界面自定义Header
|
||||
req.Header.Set("X-Edge-Cache-Action", "fetch")
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/85.0.4183.121 Safari/537.36") // TODO 可以定义
|
||||
req.Header.Set("Accept-Encoding", "gzip, deflate, br")
|
||||
resp, err := this.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return errors.New("request failed: " + fullKey + ": " + err.Error())
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
// 读取内容,以便于生成缓存
|
||||
_, _ = io.Copy(ioutil.Discard, resp.Body)
|
||||
|
||||
// 处理502
|
||||
if resp.StatusCode == http.StatusBadGateway {
|
||||
return errors.New("read origin site timeout")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
25
internal/nodes/http_cache_task_manager_test.go
Normal file
25
internal/nodes/http_cache_task_manager_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/nodes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHTTPCacheTaskManager_Loop(t *testing.T) {
|
||||
// initialize cache policies
|
||||
config, err := nodeconfigs.SharedNodeConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
caches.SharedManager.UpdatePolicies(config.HTTPCachePolicies)
|
||||
|
||||
var manager = nodes.NewHTTPCacheTaskManager()
|
||||
err = manager.Loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -104,39 +104,41 @@ func (this *HTTPClientPool) Client(req *HTTPRequest,
|
||||
}
|
||||
}
|
||||
|
||||
var transport = &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// 支持TOA的连接
|
||||
conn, err := this.handleTOA(req, ctx, network, originAddr, connectionTimeout)
|
||||
if conn != nil || err != nil {
|
||||
return conn, err
|
||||
}
|
||||
var transport = &HTTPClientTransport{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// 支持TOA的连接
|
||||
conn, err := this.handleTOA(req, ctx, network, originAddr, connectionTimeout)
|
||||
if conn != nil || err != nil {
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// 普通的连接
|
||||
conn, err = (&net.Dialer{
|
||||
Timeout: connectionTimeout,
|
||||
KeepAlive: 1 * time.Minute,
|
||||
}).DialContext(ctx, network, originAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 普通的连接
|
||||
conn, err = (&net.Dialer{
|
||||
Timeout: connectionTimeout,
|
||||
KeepAlive: 1 * time.Minute,
|
||||
}).DialContext(ctx, network, originAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 处理PROXY protocol
|
||||
err = this.handlePROXYProtocol(conn, req, proxyProtocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 处理PROXY protocol
|
||||
err = this.handlePROXYProtocol(conn, req, proxyProtocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
return conn, nil
|
||||
},
|
||||
MaxIdleConns: 0,
|
||||
MaxIdleConnsPerHost: idleConns,
|
||||
MaxConnsPerHost: maxConnections,
|
||||
IdleConnTimeout: idleTimeout,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
TLSHandshakeTimeout: 3 * time.Second,
|
||||
TLSClientConfig: tlsConfig,
|
||||
Proxy: nil,
|
||||
},
|
||||
MaxIdleConns: 0,
|
||||
MaxIdleConnsPerHost: idleConns,
|
||||
MaxConnsPerHost: maxConnections,
|
||||
IdleConnTimeout: idleTimeout,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
TLSHandshakeTimeout: 3 * time.Second,
|
||||
TLSClientConfig: tlsConfig,
|
||||
Proxy: nil,
|
||||
}
|
||||
|
||||
rawClient = &http.Client{
|
||||
|
||||
26
internal/nodes/http_client_transport.go
Normal file
26
internal/nodes/http_client_transport.go
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const emptyHTTPLocation = "/$EmptyHTTPLocation$"
|
||||
|
||||
type HTTPClientTransport struct {
|
||||
*http.Transport
|
||||
}
|
||||
|
||||
func (this *HTTPClientTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
resp, err := this.Transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// 检查在跳转相关状态中Location是否存在
|
||||
if httpStatusIsRedirect(resp.StatusCode) && len(resp.Header.Get("Location")) == 0 {
|
||||
resp.Header.Set("Location", emptyHTTPLocation)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -170,9 +170,10 @@ func (this *HTTPRequest) Do() {
|
||||
// ACME
|
||||
// TODO 需要配置是否启用ACME检测
|
||||
if strings.HasPrefix(this.rawURI, "/.well-known/acme-challenge/") {
|
||||
this.doACME()
|
||||
this.doEnd()
|
||||
return
|
||||
if this.doACME() {
|
||||
this.doEnd()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -265,10 +266,19 @@ func (this *HTTPRequest) doBegin() {
|
||||
}
|
||||
|
||||
// UAM
|
||||
if !isHealthCheck && this.ReqServer.UAM != nil && this.ReqServer.UAM.IsOn {
|
||||
if this.doUAM() {
|
||||
this.doEnd()
|
||||
return
|
||||
if !isHealthCheck {
|
||||
if this.web.UAM != nil {
|
||||
if this.web.UAM.IsOn {
|
||||
if this.doUAM() {
|
||||
this.doEnd()
|
||||
return
|
||||
}
|
||||
}
|
||||
} else if this.ReqServer.UAM != nil && this.ReqServer.UAM.IsOn {
|
||||
if this.doUAM() {
|
||||
this.doEnd()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -521,6 +531,11 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
|
||||
}
|
||||
}
|
||||
|
||||
// UAM
|
||||
if web.UAM != nil && (web.UAM.IsPrior || isTop) {
|
||||
this.web.UAM = web.UAM
|
||||
}
|
||||
|
||||
// 重写规则
|
||||
if len(web.RewriteRefs) > 0 {
|
||||
for index, ref := range web.RewriteRefs {
|
||||
@@ -1033,7 +1048,7 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string {
|
||||
}
|
||||
|
||||
// X-Forwarded-For
|
||||
forwardedFor := this.RawReq.Header.Get("X-Forwarded-For")
|
||||
var forwardedFor = this.RawReq.Header.Get("X-Forwarded-For")
|
||||
if len(forwardedFor) > 0 {
|
||||
commaIndex := strings.Index(forwardedFor, ",")
|
||||
if commaIndex > 0 {
|
||||
@@ -1447,7 +1462,12 @@ func (this *HTTPRequest) setForwardHeaders(header http.Header) {
|
||||
header["X-Forwarded-For"] = []string{strings.Join(forwardedFor, ", ") + ", " + remoteAddr}
|
||||
}
|
||||
} else {
|
||||
header["X-Forwarded-For"] = []string{remoteAddr}
|
||||
var clientRemoteAddr = this.requestRemoteAddr(true)
|
||||
if len(clientRemoteAddr) > 0 && clientRemoteAddr != remoteAddr {
|
||||
header["X-Forwarded-For"] = []string{clientRemoteAddr + ", " + remoteAddr}
|
||||
} else {
|
||||
header["X-Forwarded-For"] = []string{remoteAddr}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,34 +4,36 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func (this *HTTPRequest) doACME() {
|
||||
func (this *HTTPRequest) doACME() (shouldStop bool) {
|
||||
// TODO 对请求进行校验,防止恶意攻击
|
||||
|
||||
token := filepath.Base(this.RawReq.URL.Path)
|
||||
var token = filepath.Base(this.RawReq.URL.Path)
|
||||
if token == "acme-challenge" || len(token) <= 32 {
|
||||
this.writer.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
remotelogs.Error("RPC", "[ACME]rpc failed: "+err.Error())
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
keyResp, err := rpcClient.ACMEAuthenticationRPC().FindACMEAuthenticationKeyWithToken(rpcClient.Context(), &pb.FindACMEAuthenticationKeyWithTokenRequest{Token: token})
|
||||
if err != nil {
|
||||
remotelogs.Error("RPC", "[ACME]read key for token failed: "+err.Error())
|
||||
return
|
||||
return false
|
||||
}
|
||||
if len(keyResp.Key) == 0 {
|
||||
this.writer.WriteHeader(http.StatusNotFound)
|
||||
} else {
|
||||
this.writer.Header().Set("Content-Type", "text/plain")
|
||||
_, _ = this.writer.WriteString(keyResp.Key)
|
||||
return false
|
||||
}
|
||||
|
||||
this.tags = append(this.tags, "ACME")
|
||||
|
||||
this.writer.Header().Set("Content-Type", "text/plain")
|
||||
_, _ = this.writer.WriteString(keyResp.Key)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -3,12 +3,9 @@ package nodes
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/compressions"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
@@ -33,11 +30,6 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
return
|
||||
}
|
||||
|
||||
// 判断是否在预热
|
||||
if (strings.HasPrefix(this.RawReq.RemoteAddr, "127.") || strings.HasPrefix(this.RawReq.RemoteAddr, "[::1]")) && this.RawReq.Header.Get("X-Cache-Action") == "preheat" {
|
||||
return
|
||||
}
|
||||
|
||||
// 添加 X-Cache Header
|
||||
var addStatusHeader = this.web.Cache.AddStatusHeader
|
||||
if addStatusHeader {
|
||||
@@ -89,6 +81,12 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
return
|
||||
}
|
||||
|
||||
// 是否正在Purge
|
||||
var isPurging = this.web.Cache.PurgeIsOn && strings.ToUpper(this.RawReq.Method) == "PURGE" && this.RawReq.Header.Get("X-Edge-Purge-Key") == this.web.Cache.PurgeKey
|
||||
if isPurging {
|
||||
this.RawReq.Method = http.MethodGet
|
||||
}
|
||||
|
||||
// 校验请求
|
||||
if !this.cacheRef.MatchRequest(this.RawReq) {
|
||||
this.cacheRef = nil
|
||||
@@ -136,8 +134,13 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
}
|
||||
this.writer.cacheStorage = storage
|
||||
|
||||
// 如果正在预热,则不读取缓存,等待下一个步骤重新生成
|
||||
if (strings.HasPrefix(this.RawReq.RemoteAddr, "127.") || strings.HasPrefix(this.RawReq.RemoteAddr, "[::1]")) && this.RawReq.Header.Get("X-Edge-Cache-Action") == "fetch" {
|
||||
return
|
||||
}
|
||||
|
||||
// 判断是否在Purge
|
||||
if this.web.Cache.PurgeIsOn && strings.ToUpper(this.RawReq.Method) == "PURGE" && this.RawReq.Header.Get("X-Edge-Purge-Key") == this.web.Cache.PurgeKey {
|
||||
if isPurging {
|
||||
this.varMapping["cache.status"] = "PURGE"
|
||||
|
||||
var subKeys = []string{
|
||||
@@ -159,22 +162,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
}
|
||||
|
||||
// 通过API节点清除别节点上的的Key
|
||||
// TODO 改为队列,不需要每个请求都使用goroutine
|
||||
goman.New(func() {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err == nil {
|
||||
for _, rpcServerService := range rpcClient.ServerRPCList() {
|
||||
_, err = rpcServerService.PurgeServerCache(rpcClient.Context(), &pb.PurgeServerCacheRequest{
|
||||
Domains: []string{this.ReqHost},
|
||||
Keys: []string{key},
|
||||
Prefixes: nil,
|
||||
})
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST_CACHE", "purge failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
SharedHTTPCacheTaskManager.PushTaskKeys([]string{key})
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -248,6 +236,11 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
if reader == nil {
|
||||
reader, err = storage.OpenReader(key, useStale, false)
|
||||
if err != nil && this.cacheRef.AllowPartialContent {
|
||||
// 尝试读取分片的缓存内容
|
||||
if len(rangeHeader) == 0 {
|
||||
// 默认读取开头
|
||||
rangeHeader = "bytes=0-"
|
||||
}
|
||||
pReader, ranges := this.tryPartialReader(storage, key, useStale, rangeHeader)
|
||||
if pReader != nil {
|
||||
isPartialCache = true
|
||||
|
||||
@@ -21,8 +21,13 @@ func (this *HTTPRequest) doHealthCheck(key string, isHealthCheck *bool) (stop bo
|
||||
}
|
||||
*isHealthCheck = true
|
||||
|
||||
if !data.GetBool("accessLogIsOn") {
|
||||
this.disableLog = true
|
||||
}
|
||||
|
||||
if data.GetBool("onlyBasicRequest") {
|
||||
return true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,9 +2,18 @@
|
||||
|
||||
package nodes
|
||||
|
||||
import "net/http"
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (this *HTTPRequest) doRequestLimit() (shouldStop bool) {
|
||||
// 是否在全局名单中
|
||||
_, isInAllowedList := iplibrary.AllowIP(this.RemoteAddr(), this.ReqServer.Id)
|
||||
if isInAllowedList {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查请求Body尺寸
|
||||
// TODO 处理分片提交的内容
|
||||
if this.web.RequestLimit.MaxBodyBytes() > 0 &&
|
||||
|
||||
@@ -282,16 +282,32 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
// 替换Location中的源站地址
|
||||
var locationHeader = resp.Header.Get("Location")
|
||||
if len(locationHeader) > 0 {
|
||||
locationURL, err := url.Parse(locationHeader)
|
||||
if err == nil && (locationURL.Host == originAddr || strings.HasPrefix(originAddr, locationURL.Host+":")) {
|
||||
locationURL.Host = this.ReqHost
|
||||
if this.IsHTTP {
|
||||
locationURL.Scheme = "http"
|
||||
} else if this.IsHTTPS {
|
||||
locationURL.Scheme = "https"
|
||||
}
|
||||
// 空Location处理
|
||||
if locationHeader == emptyHTTPLocation {
|
||||
resp.Header.Del("Location")
|
||||
} else {
|
||||
// 自动修正Location中的源站地址
|
||||
locationURL, err := url.Parse(locationHeader)
|
||||
if err == nil && locationURL.Host != this.ReqHost && (locationURL.Host == originAddr || strings.HasPrefix(originAddr, locationURL.Host+":")) {
|
||||
locationURL.Host = this.ReqHost
|
||||
|
||||
resp.Header.Set("Location", locationURL.String())
|
||||
var oldScheme = locationURL.Scheme
|
||||
|
||||
// 尝试和当前Scheme一致
|
||||
if this.IsHTTP {
|
||||
locationURL.Scheme = "http"
|
||||
} else if this.IsHTTPS {
|
||||
locationURL.Scheme = "https"
|
||||
}
|
||||
|
||||
// 如果和当前URL一样,则可能是http -> https,防止无限循环
|
||||
if locationURL.String() == this.URL() {
|
||||
locationURL.Scheme = oldScheme
|
||||
resp.Header.Set("Location", locationURL.String())
|
||||
} else {
|
||||
resp.Header.Set("Location", locationURL.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -312,7 +328,12 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
|
||||
// 是否有内容
|
||||
if resp.ContentLength == 0 && len(resp.TransferEncoding) == 0 {
|
||||
// 即使内容为0,也需要读取一次,以便于触发相关事件
|
||||
var buf = utils.BytePool4k.Get()
|
||||
_, _ = io.CopyBuffer(this.writer, resp.Body, buf)
|
||||
utils.BytePool4k.Put(buf)
|
||||
_ = resp.Body.Close()
|
||||
|
||||
this.writer.SetOk()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
rootDir := this.web.Root.Dir
|
||||
var rootDir = this.web.Root.Dir
|
||||
if this.web.Root.HasVariables() {
|
||||
rootDir = this.Format(rootDir)
|
||||
}
|
||||
@@ -59,9 +59,9 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
rootDir = Tea.Root + Tea.DS + rootDir
|
||||
}
|
||||
|
||||
requestPath := this.uri
|
||||
var requestPath = this.uri
|
||||
|
||||
questionMarkIndex := strings.Index(this.uri, "?")
|
||||
var questionMarkIndex = strings.Index(this.uri, "?")
|
||||
if questionMarkIndex > -1 {
|
||||
requestPath = this.uri[:questionMarkIndex]
|
||||
}
|
||||
@@ -75,7 +75,9 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
if err == nil {
|
||||
requestPath = p
|
||||
} else {
|
||||
logs.Error(err)
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,8 +94,8 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
}
|
||||
}
|
||||
|
||||
filename := strings.Replace(requestPath, "/", Tea.DS, -1)
|
||||
filePath := ""
|
||||
var filename = strings.Replace(requestPath, "/", Tea.DS, -1)
|
||||
var filePath = ""
|
||||
if len(filename) > 0 && filename[0:1] == Tea.DS {
|
||||
filePath = rootDir + filename
|
||||
} else {
|
||||
@@ -113,7 +115,9 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
return
|
||||
} else {
|
||||
this.write50x(err, http.StatusInternalServerError, true)
|
||||
logs.Error(err)
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -142,7 +146,9 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
return
|
||||
} else {
|
||||
this.write50x(err, http.StatusInternalServerError, true)
|
||||
logs.Error(err)
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -152,24 +158,24 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
}
|
||||
|
||||
// 响应header
|
||||
respHeader := this.writer.Header()
|
||||
var respHeader = this.writer.Header()
|
||||
|
||||
// mime type
|
||||
contentType := ""
|
||||
var contentType = ""
|
||||
if this.web.ResponseHeaderPolicy == nil || !this.web.ResponseHeaderPolicy.IsOn || !this.web.ResponseHeaderPolicy.ContainsHeader("CONTENT-TYPE") {
|
||||
ext := filepath.Ext(filePath)
|
||||
var ext = filepath.Ext(filePath)
|
||||
if len(ext) > 0 {
|
||||
mimeType := mime.TypeByExtension(ext)
|
||||
if len(mimeType) > 0 {
|
||||
semicolonIndex := strings.Index(mimeType, ";")
|
||||
mimeTypeKey := mimeType
|
||||
var semicolonIndex = strings.Index(mimeType, ";")
|
||||
var mimeTypeKey = mimeType
|
||||
if semicolonIndex > 0 {
|
||||
mimeTypeKey = mimeType[:semicolonIndex]
|
||||
}
|
||||
|
||||
if _, found := textMimeMap[mimeTypeKey]; found {
|
||||
if this.web.Charset != nil && this.web.Charset.IsOn && len(this.web.Charset.Charset) > 0 {
|
||||
charset := this.web.Charset.Charset
|
||||
var charset = this.web.Charset.Charset
|
||||
if this.web.Charset.IsUpper {
|
||||
charset = strings.ToUpper(charset)
|
||||
}
|
||||
@@ -197,7 +203,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
}
|
||||
|
||||
// 支持 ETag
|
||||
eTag := "\"e" + fmt.Sprintf("%0x", xxhash.Sum64String(filename+strconv.FormatInt(stat.ModTime().UnixNano(), 10)+strconv.FormatInt(stat.Size(), 10))) + "\""
|
||||
var eTag = "\"e" + fmt.Sprintf("%0x", xxhash.Sum64String(filename+strconv.FormatInt(stat.ModTime().UnixNano(), 10)+strconv.FormatInt(stat.Size(), 10))) + "\""
|
||||
if len(respHeader.Get("ETag")) == 0 {
|
||||
respHeader.Set("ETag", eTag)
|
||||
}
|
||||
@@ -227,7 +233,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
// 支持Range
|
||||
respHeader.Set("Accept-Ranges", "bytes")
|
||||
ifRangeHeaders, ok := this.RawReq.Header["If-Range"]
|
||||
supportRange := true
|
||||
var supportRange = true
|
||||
if ok {
|
||||
supportRange = false
|
||||
for _, v := range ifRangeHeaders {
|
||||
@@ -244,7 +250,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
// 支持Range
|
||||
var ranges = []rangeutils.Range{}
|
||||
if supportRange {
|
||||
contentRange := this.RawReq.Header.Get("Range")
|
||||
var contentRange = this.RawReq.Header.Get("Range")
|
||||
if len(contentRange) > 0 {
|
||||
if fileSize == 0 {
|
||||
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
|
||||
@@ -277,7 +283,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
respHeader.Set("Content-Length", strconv.FormatInt(fileSize, 10))
|
||||
}
|
||||
|
||||
reader, err := os.OpenFile(filePath, os.O_RDONLY, 0444)
|
||||
fileReader, err := os.OpenFile(filePath, os.O_RDONLY, 0444)
|
||||
if err != nil {
|
||||
this.write50x(err, http.StatusInternalServerError, true)
|
||||
return true
|
||||
@@ -291,12 +297,16 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
this.cacheRef = nil // 不支持缓存
|
||||
}
|
||||
|
||||
this.writer.Prepare(nil, fileSize, http.StatusOK, true)
|
||||
var resp = &http.Response{
|
||||
ContentLength: fileSize,
|
||||
Body: fileReader,
|
||||
StatusCode: http.StatusOK,
|
||||
}
|
||||
this.writer.Prepare(resp, fileSize, http.StatusOK, true)
|
||||
|
||||
pool := this.bytePool(fileSize)
|
||||
buf := pool.Get()
|
||||
var pool = this.bytePool(fileSize)
|
||||
var buf = pool.Get()
|
||||
defer func() {
|
||||
_ = reader.Close()
|
||||
pool.Put(buf)
|
||||
}()
|
||||
|
||||
@@ -304,12 +314,14 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
respHeader.Set("Content-Range", ranges[0].ComposeContentRangeHeader(types.String(fileSize)))
|
||||
this.writer.WriteHeader(http.StatusPartialContent)
|
||||
|
||||
ok, err := httpRequestReadRange(reader, buf, ranges[0].Start(), ranges[0].End(), func(buf []byte, n int) error {
|
||||
ok, err := httpRequestReadRange(resp.Body, buf, ranges[0].Start(), ranges[0].End(), func(buf []byte, n int) error {
|
||||
_, err := this.writer.Write(buf[:n])
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if !ok {
|
||||
@@ -318,7 +330,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
return true
|
||||
}
|
||||
} else if len(ranges) > 1 {
|
||||
boundary := httpRequestGenBoundary()
|
||||
var boundary = httpRequestGenBoundary()
|
||||
respHeader.Set("Content-Type", "multipart/byteranges; boundary="+boundary)
|
||||
|
||||
this.writer.WriteHeader(http.StatusPartialContent)
|
||||
@@ -330,30 +342,38 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
_, err = this.writer.WriteString("\r\n--" + boundary + "\r\n")
|
||||
}
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
_, err = this.writer.WriteString("Content-Range: " + r.ComposeContentRangeHeader(types.String(fileSize)) + "\r\n")
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
if len(contentType) > 0 {
|
||||
_, err = this.writer.WriteString("Content-Type: " + contentType + "\r\n\r\n")
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
ok, err := httpRequestReadRange(reader, buf, r.Start(), r.End(), func(buf []byte, n int) error {
|
||||
ok, err := httpRequestReadRange(resp.Body, buf, r.Start(), r.End(), func(buf []byte, n int) error {
|
||||
_, err := this.writer.Write(buf[:n])
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if !ok {
|
||||
@@ -365,14 +385,17 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
|
||||
_, err = this.writer.WriteString("\r\n--" + boundary + "--\r\n")
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
_, err = io.CopyBuffer(this.writer, reader, buf)
|
||||
|
||||
_, err = io.CopyBuffer(this.writer, resp.Body, buf)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -400,7 +423,9 @@ func (this *HTTPRequest) findIndexFile(dir string) (filename string, stat os.Fil
|
||||
if strings.Contains(index, "*") {
|
||||
indexFiles, err := filepath.Glob(dir + Tea.DS + index)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
this.addError(err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -194,7 +194,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
|
||||
}
|
||||
|
||||
// 规则测试
|
||||
w := sharedWAFManager.FindWAF(firewallPolicy.Id)
|
||||
w := waf.SharedWAFManager.FindWAF(firewallPolicy.Id)
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
@@ -261,7 +261,7 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi
|
||||
return
|
||||
}
|
||||
|
||||
w := sharedWAFManager.FindWAF(firewallPolicy.Id)
|
||||
w := waf.SharedWAFManager.FindWAF(firewallPolicy.Id)
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package nodes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -20,7 +19,7 @@ func (this *HTTPRequest) doWebsocket() {
|
||||
// TODO 实现handshakeTimeout
|
||||
|
||||
// 校验来源
|
||||
requestOrigin := this.RawReq.Header.Get("Origin")
|
||||
var requestOrigin = this.RawReq.Header.Get("Origin")
|
||||
if len(requestOrigin) > 0 {
|
||||
u, err := url.Parse(requestOrigin)
|
||||
if err == nil {
|
||||
@@ -34,7 +33,7 @@ func (this *HTTPRequest) doWebsocket() {
|
||||
|
||||
// 设置指定的来源域
|
||||
if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
|
||||
newRequestOrigin := this.web.Websocket.RequestOrigin
|
||||
var newRequestOrigin = this.web.Websocket.RequestOrigin
|
||||
if this.web.Websocket.RequestOriginHasVariables() {
|
||||
newRequestOrigin = this.Format(newRequestOrigin)
|
||||
}
|
||||
@@ -45,8 +44,21 @@ func (this *HTTPRequest) doWebsocket() {
|
||||
originConn, err := OriginConnect(this.origin, this.RawReq.RemoteAddr)
|
||||
if err != nil {
|
||||
this.write50x(err, http.StatusBadGateway, false)
|
||||
|
||||
// 增加失败次数
|
||||
SharedOriginStateManager.Fail(this.origin, this.reverseProxy, func() {
|
||||
this.reverseProxy.ResetScheduling()
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !this.origin.IsOk {
|
||||
SharedOriginStateManager.Success(this.origin, func() {
|
||||
this.reverseProxy.ResetScheduling()
|
||||
})
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = originConn.Close()
|
||||
}()
|
||||
@@ -66,7 +78,7 @@ func (this *HTTPRequest) doWebsocket() {
|
||||
_ = clientConn.Close()
|
||||
}()
|
||||
|
||||
goman.New(func() {
|
||||
go func() {
|
||||
var buf = utils.BytePool4k.Get()
|
||||
defer utils.BytePool4k.Put(buf)
|
||||
for {
|
||||
@@ -84,6 +96,6 @@ func (this *HTTPRequest) doWebsocket() {
|
||||
}
|
||||
_ = clientConn.Close()
|
||||
_ = originConn.Close()
|
||||
})
|
||||
}()
|
||||
_, _ = io.Copy(originConn, clientConn)
|
||||
}
|
||||
|
||||
@@ -312,9 +312,10 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
|
||||
|
||||
if !caches.CanIgnoreErr(err) {
|
||||
remotelogs.Error("HTTP_WRITER", "write cache failed: "+err.Error())
|
||||
this.Header().Set("X-Cache", "BYPASS, write cache failed")
|
||||
} else {
|
||||
this.Header().Set("X-Cache", "BYPASS, "+err.Error())
|
||||
}
|
||||
|
||||
this.Header().Set("X-Cache", "BYPASS, too many requests")
|
||||
return
|
||||
}
|
||||
this.cacheWriter = cacheWriter
|
||||
@@ -448,7 +449,9 @@ func (this *HTTPWriter) PrepareCache(resp *http.Response, size int64) {
|
||||
this.rawReader = cacheReader
|
||||
|
||||
cacheReader.OnFail(func(err error) {
|
||||
_ = this.cacheWriter.Discard()
|
||||
if this.cacheWriter != nil {
|
||||
_ = this.cacheWriter.Discard()
|
||||
}
|
||||
this.cacheWriter = nil
|
||||
})
|
||||
cacheReader.OnEOF(func() {
|
||||
@@ -836,7 +839,7 @@ func (this *HTTPWriter) HeaderData() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
resp := &http.Response{}
|
||||
var resp = &http.Response{}
|
||||
resp.Header = this.Header()
|
||||
if this.statusCode == 0 {
|
||||
this.statusCode = http.StatusOK
|
||||
@@ -859,6 +862,70 @@ func (this *HTTPWriter) SetOk() {
|
||||
|
||||
// Close 关闭
|
||||
func (this *HTTPWriter) Close() {
|
||||
this.finishWebP()
|
||||
this.finishRequest()
|
||||
this.finishCache()
|
||||
this.finishCompression()
|
||||
|
||||
// 统计
|
||||
if this.sentBodyBytes == 0 {
|
||||
this.sentBodyBytes = this.counterWriter.TotalBytes()
|
||||
}
|
||||
}
|
||||
|
||||
// Hijack Hijack
|
||||
func (this *HTTPWriter) Hijack() (conn net.Conn, buf *bufio.ReadWriter, err error) {
|
||||
hijack, ok := this.rawWriter.(http.Hijacker)
|
||||
if ok {
|
||||
return hijack.Hijack()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Flush Flush
|
||||
func (this *HTTPWriter) Flush() {
|
||||
flusher, ok := this.rawWriter.(http.Flusher)
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// DelayRead 是否延迟读取Reader
|
||||
func (this *HTTPWriter) DelayRead() bool {
|
||||
return this.delayRead
|
||||
}
|
||||
|
||||
// 计算stale时长
|
||||
func (this *HTTPWriter) calculateStaleLife() int {
|
||||
var staleLife = 600 // TODO 可以在缓存策略里设置此时间
|
||||
var staleConfig = this.req.web.Cache.Stale
|
||||
if staleConfig != nil && staleConfig.IsOn {
|
||||
// 从Header中读取stale-if-error
|
||||
var isDefinedInHeader = false
|
||||
if staleConfig.SupportStaleIfErrorHeader {
|
||||
var cacheControl = this.GetHeader("Cache-Control")
|
||||
var pieces = strings.Split(cacheControl, ",")
|
||||
for _, piece := range pieces {
|
||||
var eqIndex = strings.Index(piece, "=")
|
||||
if eqIndex > 0 && strings.TrimSpace(piece[:eqIndex]) == "stale-if-error" {
|
||||
// 这里预示着如果stale-if-error=0,可以关闭stale功能
|
||||
staleLife = types.Int(strings.TrimSpace(piece[eqIndex+1:]))
|
||||
isDefinedInHeader = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 自定义
|
||||
if !isDefinedInHeader && staleConfig.Life != nil {
|
||||
staleLife = types.Int(staleConfig.Life.Duration().Seconds())
|
||||
}
|
||||
}
|
||||
return staleLife
|
||||
}
|
||||
|
||||
// 结束WebP
|
||||
func (this *HTTPWriter) finishWebP() {
|
||||
// 处理WebP
|
||||
if this.webpIsEncoding {
|
||||
var webpCacheWriter caches.Writer
|
||||
@@ -919,6 +986,7 @@ func (this *HTTPWriter) Close() {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// 发生了错误终止处理
|
||||
return
|
||||
}
|
||||
|
||||
@@ -948,7 +1016,7 @@ func (this *HTTPWriter) Close() {
|
||||
//webpConfig.SetLossless(1)
|
||||
webpConfig.SetQuality(f)
|
||||
|
||||
timeline := 0
|
||||
var timeline = 0
|
||||
|
||||
for i, img := range gifImage.Image {
|
||||
err = anim.AddFrame(img, timeline, webpConfig)
|
||||
@@ -988,15 +1056,10 @@ func (this *HTTPWriter) Close() {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if this.writer != nil {
|
||||
_ = this.writer.Close()
|
||||
}
|
||||
|
||||
if this.rawReader != nil {
|
||||
_ = this.rawReader.Close()
|
||||
}
|
||||
|
||||
// 结束缓存相关处理
|
||||
func (this *HTTPWriter) finishCache() {
|
||||
// 缓存
|
||||
if this.cacheWriter != nil {
|
||||
if this.isOk && this.cacheIsFinished {
|
||||
@@ -1054,7 +1117,10 @@ func (this *HTTPWriter) Close() {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 结束压缩相关处理
|
||||
func (this *HTTPWriter) finishCompression() {
|
||||
if this.compressionCacheWriter != nil {
|
||||
if this.isOk {
|
||||
err := this.compressionCacheWriter.Close()
|
||||
@@ -1075,59 +1141,15 @@ func (this *HTTPWriter) Close() {
|
||||
_ = this.compressionCacheWriter.Discard()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if this.sentBodyBytes == 0 {
|
||||
this.sentBodyBytes = this.counterWriter.TotalBytes()
|
||||
// 最终关闭
|
||||
func (this *HTTPWriter) finishRequest() {
|
||||
if this.writer != nil {
|
||||
_ = this.writer.Close()
|
||||
}
|
||||
|
||||
if this.rawReader != nil {
|
||||
_ = this.rawReader.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Hijack Hijack
|
||||
func (this *HTTPWriter) Hijack() (conn net.Conn, buf *bufio.ReadWriter, err error) {
|
||||
hijack, ok := this.rawWriter.(http.Hijacker)
|
||||
if ok {
|
||||
return hijack.Hijack()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Flush Flush
|
||||
func (this *HTTPWriter) Flush() {
|
||||
flusher, ok := this.rawWriter.(http.Flusher)
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// DelayRead 是否延迟读取Reader
|
||||
func (this *HTTPWriter) DelayRead() bool {
|
||||
return this.delayRead
|
||||
}
|
||||
|
||||
// 计算stale时长
|
||||
func (this *HTTPWriter) calculateStaleLife() int {
|
||||
var staleLife = 600 // TODO 可以在缓存策略里设置此时间
|
||||
var staleConfig = this.req.web.Cache.Stale
|
||||
if staleConfig != nil && staleConfig.IsOn {
|
||||
// 从Header中读取stale-if-error
|
||||
var isDefinedInHeader = false
|
||||
if staleConfig.SupportStaleIfErrorHeader {
|
||||
var cacheControl = this.GetHeader("Cache-Control")
|
||||
var pieces = strings.Split(cacheControl, ",")
|
||||
for _, piece := range pieces {
|
||||
var eqIndex = strings.Index(piece, "=")
|
||||
if eqIndex > 0 && strings.TrimSpace(piece[:eqIndex]) == "stale-if-error" {
|
||||
// 这里预示着如果stale-if-error=0,可以关闭stale功能
|
||||
staleLife = types.Int(strings.TrimSpace(piece[eqIndex+1:]))
|
||||
isDefinedInHeader = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 自定义
|
||||
if !isDefinedInHeader && staleConfig.Life != nil {
|
||||
staleLife = types.Int(staleConfig.Life.Duration().Seconds())
|
||||
}
|
||||
}
|
||||
return staleLife
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package nodes
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeutils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/zero"
|
||||
@@ -170,34 +171,8 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
|
||||
|
||||
server, serverName := this.findNamedServer(domain)
|
||||
if server == nil {
|
||||
server = this.findServerWithCNAME(domain)
|
||||
if server == nil {
|
||||
// 严格匹配域名模式下,我们拒绝用户访问
|
||||
if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {
|
||||
httpAllConfig := sharedNodeConfig.GlobalConfig.HTTPAll
|
||||
mismatchAction := httpAllConfig.DomainMismatchAction
|
||||
if mismatchAction != nil && mismatchAction.Code == "page" {
|
||||
if mismatchAction.Options != nil {
|
||||
rawWriter.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
rawWriter.WriteHeader(mismatchAction.Options.GetInt("statusCode"))
|
||||
_, _ = rawWriter.Write([]byte(mismatchAction.Options.GetString("contentHTML")))
|
||||
} else {
|
||||
http.Error(rawWriter, "404 page not found: '"+rawReq.URL.String()+"'", http.StatusNotFound)
|
||||
}
|
||||
return
|
||||
} else {
|
||||
hijacker, ok := rawWriter.(http.Hijacker)
|
||||
if ok {
|
||||
conn, _, _ := hijacker.Hijack()
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
http.Error(rawWriter, "404 page not found: '"+rawReq.URL.String()+"'", http.StatusNotFound)
|
||||
this.handleMismatch(rawReq, rawWriter)
|
||||
return
|
||||
} else {
|
||||
serverName = domain
|
||||
@@ -205,7 +180,7 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
|
||||
}
|
||||
|
||||
// 包装新请求对象
|
||||
req := &HTTPRequest{
|
||||
var req = &HTTPRequest{
|
||||
RawReq: rawReq,
|
||||
RawWriter: rawWriter,
|
||||
ReqServer: server,
|
||||
@@ -220,6 +195,48 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
|
||||
req.Do()
|
||||
}
|
||||
|
||||
// 处理域名不匹配的情况
|
||||
func (this *HTTPListener) handleMismatch(rawReq *http.Request, rawWriter http.ResponseWriter) {
|
||||
// TODO 需要记录访问记录和防止CC
|
||||
|
||||
// 是否为健康检查
|
||||
var healthCheckKey = rawReq.Header.Get(serverconfigs.HealthCheckHeaderName)
|
||||
if len(healthCheckKey) > 0 {
|
||||
_, err := nodeutils.Base64DecodeMap(healthCheckKey)
|
||||
if err == nil {
|
||||
rawWriter.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 严格匹配域名模式下,我们拒绝用户访问
|
||||
if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {
|
||||
var httpAllConfig = sharedNodeConfig.GlobalConfig.HTTPAll
|
||||
var mismatchAction = httpAllConfig.DomainMismatchAction
|
||||
if mismatchAction != nil && mismatchAction.Code == "page" {
|
||||
if mismatchAction.Options != nil {
|
||||
rawWriter.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
rawWriter.WriteHeader(mismatchAction.Options.GetInt("statusCode"))
|
||||
_, _ = rawWriter.Write([]byte(mismatchAction.Options.GetString("contentHTML")))
|
||||
} else {
|
||||
http.Error(rawWriter, "404 page not found: '"+rawReq.URL.String()+"'", http.StatusNotFound)
|
||||
}
|
||||
return
|
||||
} else {
|
||||
hijacker, ok := rawWriter.(http.Hijacker)
|
||||
if ok {
|
||||
conn, _, _ := hijacker.Hijack()
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
http.Error(rawWriter, "404 page not found: '"+rawReq.URL.String()+"'", http.StatusNotFound)
|
||||
}
|
||||
|
||||
func (this *HTTPListener) isIP(host string) bool {
|
||||
// IPv6
|
||||
if strings.Index(host, "[") > -1 {
|
||||
|
||||
@@ -257,6 +257,12 @@ func (this *ListenerManager) addToFirewalld(groupAddrs []string) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查状态
|
||||
err = exec.Command(firewallCmd, "--state").Run()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
remotelogs.Println("FIREWALLD", "open ports automatically")
|
||||
for _, port := range ports {
|
||||
{
|
||||
|
||||
@@ -21,7 +21,7 @@ type TCPListener struct {
|
||||
}
|
||||
|
||||
func (this *TCPListener) Serve() error {
|
||||
listener := this.Listener
|
||||
var listener = this.Listener
|
||||
if this.Group.IsTLS() {
|
||||
listener = tls.NewListener(listener, this.buildTLSConfig())
|
||||
}
|
||||
@@ -52,14 +52,29 @@ func (this *TCPListener) Reload(group *serverconfigs.ServerAddressGroup) {
|
||||
}
|
||||
|
||||
func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
firstServer := this.Group.FirstServer()
|
||||
if firstServer == nil {
|
||||
var server = this.Group.FirstServer()
|
||||
if server == nil {
|
||||
return errors.New("no server available")
|
||||
}
|
||||
if firstServer.ReverseProxy == nil {
|
||||
if server.ReverseProxy == nil {
|
||||
return errors.New("no ReverseProxy configured for the server")
|
||||
}
|
||||
|
||||
// 是否已达到流量限制
|
||||
if this.reachedTrafficLimit() {
|
||||
// 关闭连接
|
||||
tcpConn, ok := conn.(LingerConn)
|
||||
if ok {
|
||||
_ = tcpConn.SetLinger(0)
|
||||
}
|
||||
_ = conn.Close()
|
||||
|
||||
// TODO 使用系统防火墙drop当前端口的数据包一段时间(1分钟)
|
||||
// 不能使用阻止IP的方法,因为边缘节点只上有可能还有别的代理服务
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 记录域名排行
|
||||
tlsConn, ok := conn.(*tls.Conn)
|
||||
var recordStat = false
|
||||
@@ -67,17 +82,17 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
var serverName = tlsConn.ConnectionState().ServerName
|
||||
if len(serverName) > 0 {
|
||||
// 统计
|
||||
stats.SharedTrafficStatManager.Add(firstServer.Id, serverName, 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId())
|
||||
stats.SharedTrafficStatManager.Add(server.Id, serverName, 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
recordStat = true
|
||||
}
|
||||
}
|
||||
|
||||
// 统计
|
||||
if !recordStat {
|
||||
stats.SharedTrafficStatManager.Add(firstServer.Id, "", 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId())
|
||||
stats.SharedTrafficStatManager.Add(server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
}
|
||||
|
||||
originConn, err := this.connectOrigin(firstServer.Id, firstServer.ReverseProxy, conn.RemoteAddr().String())
|
||||
originConn, err := this.connectOrigin(server.Id, server.ReverseProxy, conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -88,17 +103,17 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
}
|
||||
|
||||
// PROXY Protocol
|
||||
if firstServer.ReverseProxy != nil &&
|
||||
firstServer.ReverseProxy.ProxyProtocol != nil &&
|
||||
firstServer.ReverseProxy.ProxyProtocol.IsOn &&
|
||||
(firstServer.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || firstServer.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) {
|
||||
if server.ReverseProxy != nil &&
|
||||
server.ReverseProxy.ProxyProtocol != nil &&
|
||||
server.ReverseProxy.ProxyProtocol.IsOn &&
|
||||
(server.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || server.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) {
|
||||
var remoteAddr = conn.RemoteAddr()
|
||||
var transportProtocol = proxyproto.TCPv4
|
||||
if strings.Contains(remoteAddr.String(), "[") {
|
||||
transportProtocol = proxyproto.TCPv6
|
||||
}
|
||||
header := proxyproto.Header{
|
||||
Version: byte(firstServer.ReverseProxy.ProxyProtocol.Version),
|
||||
var header = proxyproto.Header{
|
||||
Version: byte(server.ReverseProxy.ProxyProtocol.Version),
|
||||
Command: proxyproto.PROXY,
|
||||
TransportProtocol: transportProtocol,
|
||||
SourceAddr: remoteAddr,
|
||||
@@ -113,7 +128,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
|
||||
// 从源站读取
|
||||
goman.New(func() {
|
||||
originBuffer := utils.BytePool16k.Get()
|
||||
var originBuffer = utils.BytePool16k.Get()
|
||||
defer func() {
|
||||
utils.BytePool16k.Put(originBuffer)
|
||||
}()
|
||||
@@ -127,8 +142,8 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
}
|
||||
|
||||
// 记录流量
|
||||
if firstServer != nil {
|
||||
stats.SharedTrafficStatManager.Add(firstServer.Id, "", int64(n), 0, 0, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId())
|
||||
if server != nil {
|
||||
stats.SharedTrafficStatManager.Add(server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
@@ -139,11 +154,17 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
})
|
||||
|
||||
// 从客户端读取
|
||||
clientBuffer := utils.BytePool16k.Get()
|
||||
var clientBuffer = utils.BytePool16k.Get()
|
||||
defer func() {
|
||||
utils.BytePool16k.Put(clientBuffer)
|
||||
}()
|
||||
for {
|
||||
// 是否已达到流量限制
|
||||
if this.reachedTrafficLimit() {
|
||||
closer()
|
||||
return nil
|
||||
}
|
||||
|
||||
n, err := conn.Read(clientBuffer)
|
||||
if n > 0 {
|
||||
_, err = originConn.Write(clientBuffer[:n])
|
||||
@@ -188,3 +209,12 @@ func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfi
|
||||
err = errors.New("no origin can be used")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否已经达到流量限制
|
||||
func (this *TCPListener) reachedTrafficLimit() bool {
|
||||
var server = this.Group.FirstServer()
|
||||
if server == nil {
|
||||
return true
|
||||
}
|
||||
return server.TrafficLimitStatus != nil && server.TrafficLimitStatus.IsValid()
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
@@ -15,12 +16,12 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/metrics"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/trackers"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
@@ -115,6 +116,7 @@ func (this *Node) Start() {
|
||||
this.checkDisk()
|
||||
|
||||
// 读取API配置
|
||||
remotelogs.Println("NODE", "init config ...")
|
||||
err = this.syncConfig(0)
|
||||
if err != nil {
|
||||
_, err := nodeconfigs.SharedNodeConfig()
|
||||
@@ -368,6 +370,38 @@ func (this *Node) loop() error {
|
||||
}
|
||||
sharedNodeConfig.ParentNodes = parentNodes
|
||||
|
||||
// 修改为已同步
|
||||
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
|
||||
NodeTaskId: task.Id,
|
||||
IsOk: true,
|
||||
Error: "",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case "ddosProtectionChanged":
|
||||
resp, err := rpcClient.NodeRPC().FindNodeDDoSProtection(nodeCtx, &pb.FindNodeDDoSProtectionRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resp.DdosProtectionJSON) == 0 {
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.DDOSProtection = nil
|
||||
}
|
||||
} else {
|
||||
var ddosProtectionConfig = &ddosconfigs.ProtectionConfig{}
|
||||
err = json.Unmarshal(resp.DdosProtectionJSON, ddosProtectionConfig)
|
||||
if err != nil {
|
||||
return errors.New("decode DDoS protection config failed: " + err.Error())
|
||||
}
|
||||
|
||||
err = firewalls.SharedDDoSProtectionManager.Apply(ddosProtectionConfig)
|
||||
if err != nil {
|
||||
// 不阻塞
|
||||
remotelogs.Error("NODE", "apply DDoS protection failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 修改为已同步
|
||||
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
|
||||
NodeTaskId: task.Id,
|
||||
@@ -396,7 +430,7 @@ func (this *Node) syncConfig(taskVersion int64) error {
|
||||
clusterErr := this.checkClusterConfig()
|
||||
if clusterErr != nil {
|
||||
if os.IsNotExist(clusterErr) {
|
||||
return err
|
||||
return errors.New("can not find config file 'configs/api.yaml'")
|
||||
}
|
||||
return errors.New("check cluster config failed: " + clusterErr.Error())
|
||||
}
|
||||
@@ -426,7 +460,7 @@ func (this *Node) syncConfig(taskVersion int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
configJSON := configResp.NodeJSON
|
||||
var configJSON = configResp.NodeJSON
|
||||
if configResp.IsCompressed {
|
||||
var reader = brotli.NewReader(bytes.NewReader(configJSON))
|
||||
var configBuf = &bytes.Buffer{}
|
||||
@@ -445,7 +479,7 @@ func (this *Node) syncConfig(taskVersion int64) error {
|
||||
|
||||
nodeConfigUpdatedAt = time.Now().Unix()
|
||||
|
||||
nodeConfig := &nodeconfigs.NodeConfig{}
|
||||
var nodeConfig = &nodeconfigs.NodeConfig{}
|
||||
err = json.Unmarshal(configJSON, nodeConfig)
|
||||
if err != nil {
|
||||
return errors.New("decode config failed: " + err.Error())
|
||||
@@ -453,6 +487,15 @@ func (this *Node) syncConfig(taskVersion int64) error {
|
||||
teaconst.NodeId = nodeConfig.Id
|
||||
teaconst.NodeIdString = types.String(teaconst.NodeId)
|
||||
|
||||
// 检查时间是否一致
|
||||
// 这个需要在 teaconst.NodeId 设置之后,因为上报到API节点的时候需要节点ID
|
||||
if configResp.Timestamp > 0 {
|
||||
var timestampDelta = configResp.Timestamp - time.Now().Unix()
|
||||
if timestampDelta > 60 || timestampDelta < -60 {
|
||||
remotelogs.Error("NODE", "node timestamp ('"+types.String(time.Now().Unix())+"') is not same as api node ('"+types.String(configResp.Timestamp)+"'), please sync the time")
|
||||
}
|
||||
}
|
||||
|
||||
// 写入到文件中
|
||||
err = nodeConfig.Save()
|
||||
if err != nil {
|
||||
@@ -721,7 +764,6 @@ func (this *Node) listenSock() error {
|
||||
"ipConns": ipConns,
|
||||
"serverConns": serverConns,
|
||||
"total": sharedListenerManager.TotalActiveConnections(),
|
||||
"limiter": sharedConnectionsLimiter.Len(),
|
||||
},
|
||||
})
|
||||
case "dropIP":
|
||||
@@ -780,6 +822,18 @@ func (this *Node) listenSock() error {
|
||||
} else {
|
||||
_ = cmd.ReplyOk()
|
||||
}
|
||||
case "accesslog":
|
||||
err := sharedHTTPAccessLogViewer.Start()
|
||||
if err != nil {
|
||||
_ = cmd.Reply(&gosock.Command{
|
||||
Code: "error",
|
||||
Params: map[string]interface{}{
|
||||
"message": "start failed: " + err.Error(),
|
||||
},
|
||||
})
|
||||
} else {
|
||||
_ = cmd.ReplyOk()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -813,7 +867,7 @@ func (this *Node) onReload(config *nodeconfigs.NodeConfig) {
|
||||
}
|
||||
|
||||
// WAF策略
|
||||
sharedWAFManager.UpdatePolicies(config.FindAllFirewallPolicies())
|
||||
waf.SharedWAFManager.UpdatePolicies(config.FindAllFirewallPolicies())
|
||||
iplibrary.SharedActionManager.UpdateActions(config.FirewallActions)
|
||||
|
||||
// 统计指标
|
||||
@@ -845,17 +899,6 @@ func (this *Node) onReload(config *nodeconfigs.NodeConfig) {
|
||||
this.maxThreads = config.MaxThreads
|
||||
}
|
||||
|
||||
// max tcp connections
|
||||
if config.TCPMaxConnections <= 0 {
|
||||
config.TCPMaxConnections = nodeconfigs.DefaultTCPMaxConnections
|
||||
}
|
||||
if config.TCPMaxConnections != sharedConnectionsLimiter.Count() {
|
||||
remotelogs.Println("NODE", "[TCP]changed tcp max connections to '"+types.String(config.TCPMaxConnections)+"'")
|
||||
|
||||
sharedConnectionsLimiter.Close()
|
||||
sharedConnectionsLimiter = ratelimit.NewCounter(config.TCPMaxConnections)
|
||||
}
|
||||
|
||||
// timezone
|
||||
var timeZone = config.TimeZone
|
||||
if len(timeZone) == 0 {
|
||||
@@ -878,6 +921,29 @@ func (this *Node) onReload(config *nodeconfigs.NodeConfig) {
|
||||
if config.ProductConfig != nil {
|
||||
teaconst.GlobalProductName = config.ProductConfig.Name
|
||||
}
|
||||
|
||||
// DNS resolver
|
||||
if config.DNSResolver != nil {
|
||||
var err error
|
||||
switch config.DNSResolver.Type {
|
||||
case nodeconfigs.DNSResolverTypeGoNative:
|
||||
err = os.Setenv("GODEBUG", "netdns=go")
|
||||
case nodeconfigs.DNSResolverTypeCGO:
|
||||
err = os.Setenv("GODEBUG", "netdns=cgo")
|
||||
default:
|
||||
// 默认使用go原生
|
||||
err = os.Setenv("GODEBUG", "netdns=go")
|
||||
}
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "[DNS_RESOLVER]set env failed: "+err.Error())
|
||||
}
|
||||
} else {
|
||||
// 默认使用go原生
|
||||
err := os.Setenv("GODEBUG", "netdns=go")
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "[DNS_RESOLVER]set env failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reload server config
|
||||
|
||||
@@ -109,6 +109,7 @@ func (this *NodeStatusExecutor) update() {
|
||||
cacheSpaceTR.End()
|
||||
|
||||
status.UpdatedAt = time.Now().Unix()
|
||||
status.Timestamp = status.UpdatedAt
|
||||
|
||||
// 发送数据
|
||||
jsonData, err := json.Marshal(status)
|
||||
|
||||
@@ -45,6 +45,8 @@ func NewOriginStateManager() *OriginStateManager {
|
||||
// Start 启动
|
||||
func (this *OriginStateManager) Start() {
|
||||
events.OnKey(events.EventReload, this, func() {
|
||||
// TODO 检查源站是否有变化
|
||||
|
||||
this.locker.Lock()
|
||||
this.stateMap = map[int64]*OriginState{}
|
||||
this.locker.Unlock()
|
||||
@@ -143,7 +145,7 @@ func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, reverse
|
||||
state.UpdatedAt = timestamp
|
||||
|
||||
if origin.IsOk {
|
||||
origin.IsOk = state.CountFails > 5 // 超过 N 次之后认为是异常
|
||||
origin.IsOk = state.CountFails < 5 // 超过 N 次之后认为是异常
|
||||
|
||||
if !origin.IsOk {
|
||||
if callback != nil {
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -64,7 +63,7 @@ func (this *UpgradeManager) Start() {
|
||||
goman.New(func() {
|
||||
err = this.restart()
|
||||
if err != nil {
|
||||
logs.Println("UPGRADE_MANAGER", err.Error())
|
||||
remotelogs.Error("UPGRADE_MANAGER", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -67,6 +67,10 @@ func (this *RPCClient) HTTPAccessLogRPC() pb.HTTPAccessLogServiceClient {
|
||||
return pb.NewHTTPAccessLogServiceClient(this.pickConn())
|
||||
}
|
||||
|
||||
func (this *RPCClient) HTTPCacheTaskKeyRPC() pb.HTTPCacheTaskKeyServiceClient {
|
||||
return pb.NewHTTPCacheTaskKeyServiceClient(this.pickConn())
|
||||
}
|
||||
|
||||
func (this *RPCClient) APINodeRPC() pb.APINodeServiceClient {
|
||||
return pb.NewAPINodeServiceClient(this.pickConn())
|
||||
}
|
||||
@@ -115,18 +119,6 @@ func (this *RPCClient) ServerRPC() pb.ServerServiceClient {
|
||||
return pb.NewServerServiceClient(this.pickConn())
|
||||
}
|
||||
|
||||
func (this *RPCClient) ServerRPCList() []pb.ServerServiceClient {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
var clients = []pb.ServerServiceClient{}
|
||||
for _, conn := range this.conns {
|
||||
clients = append(clients, pb.NewServerServiceClient(conn))
|
||||
}
|
||||
|
||||
return clients
|
||||
}
|
||||
|
||||
func (this *RPCClient) ServerDailyStatRPC() pb.ServerDailyStatServiceClient {
|
||||
return pb.NewServerDailyStatServiceClient(this.pickConn())
|
||||
}
|
||||
|
||||
@@ -80,7 +80,7 @@ func (this *TrafficStatManager) Start(configFunc func() *nodeconfigs.NodeConfig)
|
||||
remotelogs.Println("TRAFFIC_STAT_MANAGER", "quit")
|
||||
ticker.Stop()
|
||||
})
|
||||
remotelogs.Println("TRAFFIC_STA_MANAGER", "start ...")
|
||||
remotelogs.Println("TRAFFIC_STAT_MANAGER", "start ...")
|
||||
for range ticker.C {
|
||||
err := this.Upload()
|
||||
if err != nil {
|
||||
|
||||
@@ -91,7 +91,7 @@ func (this *Cache) Write(key string, value interface{}, expiredAt int64) (ok boo
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Cache) IncreaseInt64(key string, delta int64, expiredAt int64) int64 {
|
||||
func (this *Cache) IncreaseInt64(key string, delta int64, expiredAt int64, extend bool) int64 {
|
||||
if this.isDestroyed {
|
||||
return 0
|
||||
}
|
||||
@@ -107,7 +107,7 @@ func (this *Cache) IncreaseInt64(key string, delta int64, expiredAt int64) int64
|
||||
}
|
||||
uint64Key := HashKey([]byte(key))
|
||||
pieceIndex := uint64Key % this.countPieces
|
||||
return this.pieces[pieceIndex].IncreaseInt64(uint64Key, delta, expiredAt)
|
||||
return this.pieces[pieceIndex].IncreaseInt64(uint64Key, delta, expiredAt, extend)
|
||||
}
|
||||
|
||||
func (this *Cache) Read(key string) (item *Item) {
|
||||
|
||||
@@ -65,14 +65,14 @@ func TestCache_IncreaseInt64(t *testing.T) {
|
||||
var unixTime = time.Now().Unix()
|
||||
|
||||
{
|
||||
cache.IncreaseInt64("a", 1, unixTime+3600)
|
||||
cache.IncreaseInt64("a", 1, unixTime+3600, false)
|
||||
var item = cache.Read("a")
|
||||
t.Log(item)
|
||||
a.IsTrue(item.Value == int64(1))
|
||||
a.IsTrue(item.expiredAt == unixTime+3600)
|
||||
}
|
||||
{
|
||||
cache.IncreaseInt64("a", 1, unixTime+3600+1)
|
||||
cache.IncreaseInt64("a", 1, unixTime+3600+1, true)
|
||||
var item = cache.Read("a")
|
||||
t.Log(item)
|
||||
a.IsTrue(item.Value == int64(2))
|
||||
@@ -83,7 +83,7 @@ func TestCache_IncreaseInt64(t *testing.T) {
|
||||
t.Log(cache.Read("b"))
|
||||
}
|
||||
{
|
||||
cache.IncreaseInt64("b", 1, time.Now().Unix()+3600+3)
|
||||
cache.IncreaseInt64("b", 1, time.Now().Unix()+3600+3, false)
|
||||
t.Log(cache.Read("b"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,13 +39,15 @@ func (this *Piece) Add(key uint64, item *Item) (ok bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *Piece) IncreaseInt64(key uint64, delta int64, expiredAt int64) (result int64) {
|
||||
func (this *Piece) IncreaseInt64(key uint64, delta int64, expiredAt int64, extend bool) (result int64) {
|
||||
this.locker.Lock()
|
||||
item, ok := this.m[key]
|
||||
if ok && item.expiredAt > time.Now().Unix() {
|
||||
result = types.Int64(item.Value) + delta
|
||||
item.Value = result
|
||||
item.expiredAt = expiredAt
|
||||
if extend {
|
||||
item.expiredAt = expiredAt
|
||||
}
|
||||
this.expiresList.Add(key, expiredAt)
|
||||
} else {
|
||||
if len(this.m) < this.maxItems {
|
||||
|
||||
@@ -5,9 +5,12 @@ import (
|
||||
"github.com/cespare/xxhash"
|
||||
"math"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ipv4Reg = regexp.MustCompile(`\d+\.`)
|
||||
|
||||
// IP2Long 将IP转换为整型
|
||||
// 注意IPv6没有顺序
|
||||
func IP2Long(ip string) uint64 {
|
||||
@@ -54,3 +57,24 @@ func IsLocalIP(ipString string) bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsIPv4 是否为IPv4
|
||||
func IsIPv4(ip string) bool {
|
||||
var data = net.ParseIP(ip)
|
||||
if data == nil {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(ip, ":") {
|
||||
return false
|
||||
}
|
||||
return data.To4() != nil
|
||||
}
|
||||
|
||||
// IsIPv6 是否为IPv6
|
||||
func IsIPv6(ip string) bool {
|
||||
var data = net.ParseIP(ip)
|
||||
if data == nil {
|
||||
return false
|
||||
}
|
||||
return !IsIPv4(ip)
|
||||
}
|
||||
|
||||
@@ -26,3 +26,26 @@ func TestIsLocalIP(t *testing.T) {
|
||||
a.IsFalse(IsLocalIP("::1:2:3"))
|
||||
a.IsFalse(IsLocalIP("8.8.8.8"))
|
||||
}
|
||||
|
||||
func TestIsIPv4(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
a.IsTrue(IsIPv4("192.168.1.1"))
|
||||
a.IsTrue(IsIPv4("0.0.0.0"))
|
||||
a.IsFalse(IsIPv4("192.168.1.256"))
|
||||
a.IsFalse(IsIPv4("192.168.1"))
|
||||
a.IsFalse(IsIPv4("::1"))
|
||||
a.IsFalse(IsIPv4("2001:0db8:85a3:0000:0000:8a2e:0370:7334"))
|
||||
a.IsFalse(IsIPv4("::ffff:192.168.0.1"))
|
||||
}
|
||||
|
||||
func TestIsIPv6(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
a.IsFalse(IsIPv6("192.168.1.1"))
|
||||
a.IsFloat32(IsIPv6("0.0.0.0"))
|
||||
a.IsFalse(IsIPv6("192.168.1.256"))
|
||||
a.IsFalse(IsIPv6("192.168.1"))
|
||||
a.IsTrue(IsIPv6("::1"))
|
||||
a.IsTrue(IsIPv6("2001:0db8:85a3:0000:0000:8a2e:0370:7334"))
|
||||
a.IsTrue(IsIPv4("::ffff:192.168.0.1"))
|
||||
a.IsTrue(IsIPv6("::ffff:192.168.0.1"))
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
|
||||
package readers
|
||||
|
||||
import "io"
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
type TeeReaderCloser struct {
|
||||
r io.Reader
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
@@ -36,6 +37,22 @@ func FormatAddressList(addrList []string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// ToValidUTF8string 去除字符串中的非UTF-8字符
|
||||
func ToValidUTF8string(v string) string {
|
||||
return strings.ToValidUTF8(v, "")
|
||||
}
|
||||
|
||||
// ContainsSameStrings 检查两个字符串slice内容是否一致
|
||||
func ContainsSameStrings(s1 []string, s2 []string) bool {
|
||||
if len(s1) != len(s2) {
|
||||
return false
|
||||
}
|
||||
sort.Strings(s1)
|
||||
sort.Strings(s2)
|
||||
for index, v1 := range s1 {
|
||||
if v1 != s2[index] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1,56 +1,67 @@
|
||||
package utils
|
||||
package utils_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBytesToString(t *testing.T) {
|
||||
t.Log(UnsafeBytesToString([]byte("Hello,World")))
|
||||
t.Log(utils.UnsafeBytesToString([]byte("Hello,World")))
|
||||
}
|
||||
|
||||
func TestStringToBytes(t *testing.T) {
|
||||
t.Log(string(UnsafeStringToBytes("Hello,World")))
|
||||
t.Log(string(utils.UnsafeStringToBytes("Hello,World")))
|
||||
}
|
||||
|
||||
func BenchmarkBytesToString(b *testing.B) {
|
||||
data := []byte("Hello,World")
|
||||
var data = []byte("Hello,World")
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = UnsafeBytesToString(data)
|
||||
_ = utils.UnsafeBytesToString(data)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBytesToString2(b *testing.B) {
|
||||
data := []byte("Hello,World")
|
||||
var data = []byte("Hello,World")
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = string(data)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStringToBytes(b *testing.B) {
|
||||
s := strings.Repeat("Hello,World", 1024)
|
||||
var s = strings.Repeat("Hello,World", 1024)
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = UnsafeStringToBytes(s)
|
||||
_ = utils.UnsafeStringToBytes(s)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStringToBytes2(b *testing.B) {
|
||||
s := strings.Repeat("Hello,World", 1024)
|
||||
var s = strings.Repeat("Hello,World", 1024)
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = []byte(s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatAddress(t *testing.T) {
|
||||
t.Log(FormatAddress("127.0.0.1:1234"))
|
||||
t.Log(FormatAddress("127.0.0.1 : 1234"))
|
||||
t.Log(FormatAddress("127.0.0.1:1234"))
|
||||
t.Log(utils.FormatAddress("127.0.0.1:1234"))
|
||||
t.Log(utils.FormatAddress("127.0.0.1 : 1234"))
|
||||
t.Log(utils.FormatAddress("127.0.0.1:1234"))
|
||||
}
|
||||
|
||||
func TestFormatAddressList(t *testing.T) {
|
||||
t.Log(FormatAddressList([]string{
|
||||
t.Log(utils.FormatAddressList([]string{
|
||||
"127.0.0.1:1234",
|
||||
"127.0.0.1 : 1234",
|
||||
"127.0.0.1:1234",
|
||||
}))
|
||||
}
|
||||
|
||||
func TestContainsSameStrings(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
a.IsFalse(utils.ContainsSameStrings([]string{"a"}, []string{"b"}))
|
||||
a.IsFalse(utils.ContainsSameStrings([]string{"a", "b"}, []string{"b"}))
|
||||
a.IsFalse(utils.ContainsSameStrings([]string{"a", "b"}, []string{"a", "b", "c"}))
|
||||
a.IsTrue(utils.ContainsSameStrings([]string{"a", "b"}, []string{"a", "b"}))
|
||||
a.IsTrue(utils.ContainsSameStrings([]string{"a", "b"}, []string{"b", "a"}))
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
)
|
||||
|
||||
type AllowAction struct {
|
||||
BaseAction
|
||||
}
|
||||
|
||||
func (this *AllowAction) Init(waf *WAF) error {
|
||||
|
||||
@@ -7,6 +7,17 @@ import (
|
||||
)
|
||||
|
||||
type BaseAction struct {
|
||||
currentActionId int64
|
||||
}
|
||||
|
||||
// ActionId 读取ActionId
|
||||
func (this *BaseAction) ActionId() int64 {
|
||||
return this.currentActionId
|
||||
}
|
||||
|
||||
// SetActionId 设置Id
|
||||
func (this *BaseAction) SetActionId(actionId int64) {
|
||||
this.currentActionId = actionId
|
||||
}
|
||||
|
||||
// CloseConn 关闭连接
|
||||
|
||||
@@ -20,6 +20,8 @@ var urlPrefixReg = regexp.MustCompile("^(?i)(http|https)://")
|
||||
var httpClient = utils.SharedHttpClient(5 * time.Second)
|
||||
|
||||
type BlockAction struct {
|
||||
BaseAction
|
||||
|
||||
StatusCode int `yaml:"statusCode" json:"statusCode"`
|
||||
Body string `yaml:"body" json:"body"` // supports HTML
|
||||
URL string `yaml:"url" json:"url"`
|
||||
|
||||
@@ -18,16 +18,71 @@ const (
|
||||
)
|
||||
|
||||
type CaptchaAction struct {
|
||||
Life int32 `yaml:"life" json:"life"`
|
||||
MaxFails int `yaml:"maxFails" json:"maxFails"` // 最大失败次数
|
||||
FailBlockTimeout int `yaml:"failBlockTimeout" json:"failBlockTimeout"` // 失败拦截时间
|
||||
BaseAction
|
||||
|
||||
Language string `yaml:"language" json:"language"` // 语言,zh-CN, en-US ...
|
||||
Life int32 `yaml:"life" json:"life"`
|
||||
MaxFails int `yaml:"maxFails" json:"maxFails"` // 最大失败次数
|
||||
FailBlockTimeout int `yaml:"failBlockTimeout" json:"failBlockTimeout"` // 失败拦截时间
|
||||
FailBlockScopeAll bool `yaml:"failBlockScopeAll" json:"failBlockScopeAll"` // 是否全局有效
|
||||
|
||||
CountLetters int8 `yaml:"countLetters" json:"countLetters"`
|
||||
|
||||
UIIsOn bool `yaml:"uiIsOn" json:"uiIsOn"` // 是否使用自定义UI
|
||||
UITitle string `yaml:"uiTitle" json:"uiTitle"` // 消息标题
|
||||
UIPrompt string `yaml:"uiPrompt" json:"uiPrompt"` // 消息提示
|
||||
UIButtonTitle string `yaml:"uiButtonTitle" json:"uiButtonTitle"` // 按钮标题
|
||||
UIShowRequestId bool `yaml:"uiShowRequestId" json:"uiShowRequestId"` // 是否显示请求ID
|
||||
UICss string `yaml:"uiCss" json:"uiCss"` // CSS样式
|
||||
UIFooter string `yaml:"uiFooter" json:"uiFooter"` // 页脚
|
||||
UIBody string `yaml:"uiBody" json:"uiBody"` // 内容轮廓
|
||||
|
||||
Lang string `yaml:"lang" json:"lang"` // 语言,zh-CN, en-US ...
|
||||
AddToWhiteList bool `yaml:"addToWhiteList" json:"addToWhiteList"` // 是否加入到白名单
|
||||
Scope string `yaml:"scope" json:"scope"`
|
||||
}
|
||||
|
||||
func (this *CaptchaAction) Init(waf *WAF) error {
|
||||
if waf.DefaultCaptchaAction != nil {
|
||||
if this.Life <= 0 {
|
||||
this.Life = waf.DefaultCaptchaAction.Life
|
||||
}
|
||||
if this.MaxFails <= 0 {
|
||||
this.MaxFails = waf.DefaultCaptchaAction.MaxFails
|
||||
}
|
||||
if this.FailBlockTimeout <= 0 {
|
||||
this.FailBlockTimeout = waf.DefaultCaptchaAction.FailBlockTimeout
|
||||
}
|
||||
this.FailBlockScopeAll = waf.DefaultCaptchaAction.FailBlockScopeAll
|
||||
|
||||
if this.CountLetters <= 0 {
|
||||
this.CountLetters = waf.DefaultCaptchaAction.CountLetters
|
||||
}
|
||||
|
||||
this.UIIsOn = waf.DefaultCaptchaAction.UIIsOn
|
||||
if len(this.UITitle) == 0 {
|
||||
this.UITitle = waf.DefaultCaptchaAction.UITitle
|
||||
}
|
||||
if len(this.UIPrompt) == 0 {
|
||||
this.UIPrompt = waf.DefaultCaptchaAction.UIPrompt
|
||||
}
|
||||
if len(this.UIButtonTitle) == 0 {
|
||||
this.UIButtonTitle = waf.DefaultCaptchaAction.UIButtonTitle
|
||||
}
|
||||
this.UIShowRequestId = waf.DefaultCaptchaAction.UIShowRequestId
|
||||
if len(this.UICss) == 0 {
|
||||
this.UICss = waf.DefaultCaptchaAction.UICss
|
||||
}
|
||||
if len(this.UIFooter) == 0 {
|
||||
this.UIFooter = waf.DefaultCaptchaAction.UIFooter
|
||||
}
|
||||
if len(this.UIBody) == 0 {
|
||||
this.UIBody = waf.DefaultCaptchaAction.UIBody
|
||||
}
|
||||
if len(this.Lang) == 0 {
|
||||
this.Lang = waf.DefaultCaptchaAction.Lang
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -43,17 +98,17 @@ func (this *CaptchaAction) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
// 是否在白名单中
|
||||
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
|
||||
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, req.WAFServerId(), req.WAFRemoteIP()) {
|
||||
return true
|
||||
}
|
||||
|
||||
refURL := request.WAFRaw().URL.String()
|
||||
var refURL = req.WAFRaw().URL.String()
|
||||
|
||||
// 覆盖配置
|
||||
if strings.HasPrefix(refURL, CaptchaPath) {
|
||||
info := request.WAFRaw().URL.Query().Get("info")
|
||||
info := req.WAFRaw().URL.Query().Get("info")
|
||||
if len(info) > 0 {
|
||||
m, err := utils.SimpleDecryptMap(info)
|
||||
if err == nil && m != nil {
|
||||
@@ -63,14 +118,12 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
|
||||
}
|
||||
|
||||
var captchaConfig = maps.Map{
|
||||
"action": this,
|
||||
"timestamp": time.Now().Unix(),
|
||||
"maxFails": this.MaxFails,
|
||||
"failBlockTimeout": this.FailBlockTimeout,
|
||||
"url": refURL,
|
||||
"policyId": waf.Id,
|
||||
"groupId": group.Id,
|
||||
"setId": set.Id,
|
||||
"actionId": this.ActionId(),
|
||||
"timestamp": time.Now().Unix(),
|
||||
"url": refURL,
|
||||
"policyId": waf.Id,
|
||||
"groupId": group.Id,
|
||||
"setId": set.Id,
|
||||
}
|
||||
info, err := utils.SimpleEncryptMap(captchaConfig)
|
||||
if err != nil {
|
||||
@@ -78,7 +131,10 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
|
||||
return true
|
||||
}
|
||||
|
||||
http.Redirect(writer, request.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect)
|
||||
// 占用一次失败次数
|
||||
CaptchaIncreaseFails(req, this, waf.Id, group.Id, set.Id, CaptchaPageCodeInit)
|
||||
|
||||
http.Redirect(writer, req.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
)
|
||||
|
||||
type GoGroupAction struct {
|
||||
BaseAction
|
||||
|
||||
GroupId string `yaml:"groupId" json:"groupId"`
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
)
|
||||
|
||||
type GoSetAction struct {
|
||||
BaseAction
|
||||
|
||||
GroupId string `yaml:"groupId" json:"groupId"`
|
||||
SetId string `yaml:"setId" json:"setId"`
|
||||
}
|
||||
|
||||
@@ -11,6 +11,12 @@ type ActionInterface interface {
|
||||
// Init 初始化
|
||||
Init(waf *WAF) error
|
||||
|
||||
// ActionId 读取ActionId
|
||||
ActionId() int64
|
||||
|
||||
// SetActionId 设置ID
|
||||
SetActionId(id int64)
|
||||
|
||||
// Code 代号
|
||||
Code() string
|
||||
|
||||
@@ -20,6 +26,6 @@ type ActionInterface interface {
|
||||
// WillChange determine if the action will change the request
|
||||
WillChange() bool
|
||||
|
||||
// Perform perform the action
|
||||
// Perform the action
|
||||
Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
)
|
||||
|
||||
type LogAction struct {
|
||||
BaseAction
|
||||
}
|
||||
|
||||
func (this *LogAction) Init(waf *WAF) error {
|
||||
|
||||
@@ -50,6 +50,7 @@ func init() {
|
||||
}
|
||||
|
||||
type NotifyAction struct {
|
||||
BaseAction
|
||||
}
|
||||
|
||||
func (this *NotifyAction) Init(waf *WAF) error {
|
||||
@@ -69,7 +70,7 @@ func (this *NotifyAction) WillChange() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Perform perform the action
|
||||
// Perform the action
|
||||
func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
select {
|
||||
case notifyChan <- ¬ifyTask{
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
)
|
||||
|
||||
type TagAction struct {
|
||||
BaseAction
|
||||
|
||||
Tags []string `yaml:"tags" json:"tags"`
|
||||
}
|
||||
|
||||
|
||||
@@ -5,15 +5,19 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"reflect"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var seedActionId int64 = 1
|
||||
|
||||
func FindActionInstance(action ActionString, options maps.Map) ActionInterface {
|
||||
for _, def := range AllActions {
|
||||
if def.Code == action {
|
||||
if def.Type != nil {
|
||||
// create new instance
|
||||
ptrValue := reflect.New(def.Type)
|
||||
instance := ptrValue.Interface().(ActionInterface)
|
||||
var ptrValue = reflect.New(def.Type)
|
||||
var instance = ptrValue.Interface().(ActionInterface)
|
||||
instance.SetActionId(atomic.AddInt64(&seedActionId, 1))
|
||||
|
||||
if len(options) > 0 {
|
||||
optionsJSON, err := json.Marshal(options)
|
||||
|
||||
54
internal/waf/captcha_counter.go
Normal file
54
internal/waf/captcha_counter.go
Normal file
@@ -0,0 +1,54 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CaptchaPageCode = string
|
||||
|
||||
const (
|
||||
CaptchaPageCodeInit CaptchaPageCode = "init"
|
||||
CaptchaPageCodeShow CaptchaPageCode = "show"
|
||||
CaptchaPageCodeSubmit CaptchaPageCode = "submit"
|
||||
)
|
||||
|
||||
// CaptchaIncreaseFails 增加Captcha失败次数,以便后续操作
|
||||
func CaptchaIncreaseFails(req requests.Request, actionConfig *CaptchaAction, policyId int64, groupId int64, setId int64, pageCode CaptchaPageCode) (goNext bool) {
|
||||
var maxFails = actionConfig.MaxFails
|
||||
var failBlockTimeout = actionConfig.FailBlockTimeout
|
||||
if maxFails > 0 && failBlockTimeout > 0 {
|
||||
if maxFails <= 3 {
|
||||
maxFails = 3 // 不能小于3,防止意外刷新出现
|
||||
}
|
||||
var countFails = ttlcache.SharedCache.IncreaseInt64(CaptchaCacheKey(req, pageCode), 1, time.Now().Unix()+300, true)
|
||||
if int(countFails) >= maxFails {
|
||||
var useLocalFirewall = false
|
||||
|
||||
if actionConfig.FailBlockScopeAll {
|
||||
useLocalFirewall = true
|
||||
}
|
||||
|
||||
SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, useLocalFirewall, groupId, setId, "CAPTCHA验证连续失败")
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// CaptchaDeleteCacheKey 清除计数
|
||||
func CaptchaDeleteCacheKey(req requests.Request) {
|
||||
ttlcache.SharedCache.Delete(CaptchaCacheKey(req, CaptchaPageCodeInit))
|
||||
ttlcache.SharedCache.Delete(CaptchaCacheKey(req, CaptchaPageCodeShow))
|
||||
ttlcache.SharedCache.Delete(CaptchaCacheKey(req, CaptchaPageCodeSubmit))
|
||||
}
|
||||
|
||||
// CaptchaCacheKey 获取Captcha缓存Key
|
||||
func CaptchaCacheKey(req requests.Request, pageCode CaptchaPageCode) string {
|
||||
return "CAPTCHA:FAILS:" + pageCode + ":" + req.WAFRemoteIP() + ":" + types.String(req.WAFServerId())
|
||||
}
|
||||
@@ -3,10 +3,7 @@ package waf
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/jsonutils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/dchest/captcha"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
@@ -26,8 +23,8 @@ func NewCaptchaValidator() *CaptchaValidator {
|
||||
return &CaptchaValidator{}
|
||||
}
|
||||
|
||||
func (this *CaptchaValidator) Run(request requests.Request, writer http.ResponseWriter) {
|
||||
var info = request.WAFRaw().URL.Query().Get("info")
|
||||
func (this *CaptchaValidator) Run(req requests.Request, writer http.ResponseWriter) {
|
||||
var info = req.WAFRaw().URL.Query().Get("info")
|
||||
if len(info) == 0 {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = writer.Write([]byte("invalid request"))
|
||||
@@ -39,42 +36,60 @@ func (this *CaptchaValidator) Run(request requests.Request, writer http.Response
|
||||
return
|
||||
}
|
||||
|
||||
timestamp := m.GetInt64("timestamp")
|
||||
var timestamp = m.GetInt64("timestamp")
|
||||
if timestamp < time.Now().Unix()-600 { // 10分钟之后信息过期
|
||||
http.Redirect(writer, request.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
var actionConfig = &CaptchaAction{}
|
||||
err = jsonutils.MapToObject(m.GetMap("action"), actionConfig)
|
||||
if err != nil {
|
||||
http.Redirect(writer, request.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect)
|
||||
http.Redirect(writer, req.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
var actionId = m.GetInt64("actionId")
|
||||
var setId = m.GetInt64("setId")
|
||||
var originURL = m.GetString("url")
|
||||
var policyId = m.GetInt64("policyId")
|
||||
var groupId = m.GetInt64("groupId")
|
||||
|
||||
if request.WAFRaw().Method == http.MethodPost && len(request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")) > 0 {
|
||||
this.validate(actionConfig, m.GetInt("maxFails"), m.GetInt("failBlockTimeout"), m.GetInt64("policyId"), m.GetInt64("groupId"), setId, originURL, request, writer)
|
||||
var waf = SharedWAFManager.FindWAF(policyId)
|
||||
if waf == nil {
|
||||
http.Redirect(writer, req.WAFRaw(), originURL, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
var actionConfig = waf.FindAction(actionId)
|
||||
if actionConfig == nil {
|
||||
http.Redirect(writer, req.WAFRaw(), originURL, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
captchaActionConfig, ok := actionConfig.(*CaptchaAction)
|
||||
if !ok {
|
||||
http.Redirect(writer, req.WAFRaw(), originURL, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
if req.WAFRaw().Method == http.MethodPost && len(req.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")) > 0 {
|
||||
this.validate(captchaActionConfig, policyId, groupId, setId, originURL, req, writer)
|
||||
} else {
|
||||
this.show(actionConfig, request, writer)
|
||||
// 增加计数
|
||||
CaptchaIncreaseFails(req, captchaActionConfig, policyId, groupId, setId, CaptchaPageCodeShow)
|
||||
this.show(captchaActionConfig, req, writer)
|
||||
}
|
||||
}
|
||||
|
||||
func (this *CaptchaValidator) show(actionConfig *CaptchaAction, request requests.Request, writer http.ResponseWriter) {
|
||||
func (this *CaptchaValidator) show(actionConfig *CaptchaAction, req requests.Request, writer http.ResponseWriter) {
|
||||
// show captcha
|
||||
captchaId := captcha.NewLen(6)
|
||||
buf := bytes.NewBuffer([]byte{})
|
||||
var countLetters = 6
|
||||
if actionConfig.CountLetters > 0 && actionConfig.CountLetters <= 10 {
|
||||
countLetters = int(actionConfig.CountLetters)
|
||||
}
|
||||
var captchaId = captcha.NewLen(countLetters)
|
||||
var buf = bytes.NewBuffer([]byte{})
|
||||
err := captcha.WriteImage(buf, captchaId, 200, 100)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
var lang = actionConfig.Language
|
||||
var lang = actionConfig.Lang
|
||||
if len(lang) == 0 {
|
||||
acceptLanguage := request.WAFRaw().Header.Get("Accept-Language")
|
||||
var acceptLanguage = req.WAFRaw().Header.Get("Accept-Language")
|
||||
if len(acceptLanguage) > 0 {
|
||||
langIndex := strings.Index(acceptLanguage, ",")
|
||||
if langIndex > 0 {
|
||||
@@ -109,12 +124,67 @@ func (this *CaptchaValidator) show(actionConfig *CaptchaAction, request requests
|
||||
msgRequestId = "Request ID"
|
||||
}
|
||||
|
||||
var msgCss = ""
|
||||
var requestIdBox = `<address>` + msgRequestId + `: ` + req.Format("${requestId}") + `</address>`
|
||||
var msgFooter = ""
|
||||
|
||||
// 默认设置
|
||||
if actionConfig.UIIsOn {
|
||||
if len(actionConfig.UIPrompt) > 0 {
|
||||
msgPrompt = actionConfig.UIPrompt
|
||||
}
|
||||
if len(actionConfig.UIButtonTitle) > 0 {
|
||||
msgButtonTitle = actionConfig.UIButtonTitle
|
||||
}
|
||||
if len(actionConfig.UITitle) > 0 {
|
||||
msgTitle = actionConfig.UITitle
|
||||
}
|
||||
if len(actionConfig.UICss) > 0 {
|
||||
msgCss = actionConfig.UICss
|
||||
}
|
||||
if !actionConfig.UIShowRequestId {
|
||||
requestIdBox = ""
|
||||
}
|
||||
if len(actionConfig.UIFooter) > 0 {
|
||||
msgFooter = actionConfig.UIFooter
|
||||
}
|
||||
}
|
||||
|
||||
var body = `<form method="POST">
|
||||
<input type="hidden" name="GOEDGE_WAF_CAPTCHA_ID" value="` + captchaId + `"/>
|
||||
<div class="ui-image">
|
||||
<img src="data:image/png;base64, ` + base64.StdEncoding.EncodeToString(buf.Bytes()) + `"/>` + `
|
||||
</div>
|
||||
<div class="ui-input">
|
||||
<p>` + msgPrompt + `</p>
|
||||
<input type="text" name="GOEDGE_WAF_CAPTCHA_CODE" id="GOEDGE_WAF_CAPTCHA_CODE" maxlength="6" autocomplete="off" z-index="1" class="input"/>
|
||||
</div>
|
||||
<div class="ui-button">
|
||||
<button type="submit" style="line-height:24px;margin-top:10px">` + msgButtonTitle + `</button>
|
||||
</div>
|
||||
</form>
|
||||
` + requestIdBox + `
|
||||
` + msgFooter + ``
|
||||
|
||||
// Body
|
||||
if actionConfig.UIIsOn {
|
||||
if len(actionConfig.UIBody) > 0 {
|
||||
var index = strings.Index(actionConfig.UIBody, "${body}")
|
||||
if index < 0 {
|
||||
body = actionConfig.UIBody + body
|
||||
} else {
|
||||
body = actionConfig.UIBody[:index] + body + actionConfig.UIBody[index+7:] // 7是"${body}"的长度
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
_, _ = writer.Write([]byte(`<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>` + msgTitle + `</title>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=0">
|
||||
<meta charset="UTF-8"/>
|
||||
<script type="text/javascript">
|
||||
if (window.addEventListener != null) {
|
||||
window.addEventListener("load", function () {
|
||||
@@ -126,32 +196,22 @@ func (this *CaptchaValidator) show(actionConfig *CaptchaAction, request requests
|
||||
form { width: 20em; margin: 0 auto; text-align: center; }
|
||||
.input { font-size:16px;line-height:24px; letter-spacing: 15px; padding-left: 10px; width: 140px; }
|
||||
address { margin-top: 1em; padding-top: 0.5em; border-top: 1px #ccc solid; text-align: center; }
|
||||
` + msgCss + `
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<form method="POST">
|
||||
<input type="hidden" name="GOEDGE_WAF_CAPTCHA_ID" value="` + captchaId + `"/>
|
||||
<img src="data:image/png;base64, ` + base64.StdEncoding.EncodeToString(buf.Bytes()) + `"/>` + `
|
||||
<div>
|
||||
<p>` + msgPrompt + `</p>
|
||||
<input type="text" name="GOEDGE_WAF_CAPTCHA_CODE" id="GOEDGE_WAF_CAPTCHA_CODE" maxlength="6" autocomplete="off" z-index="1" class="input"/>
|
||||
</div>
|
||||
<div>
|
||||
<button type="submit" style="line-height:24px;margin-top:10px">` + msgButtonTitle + `</button>
|
||||
</div>
|
||||
</form>
|
||||
<address>` + msgRequestId + `: ` + request.Format("${requestId}") + `</address>
|
||||
<body>` + body + `
|
||||
</body>
|
||||
</html>`))
|
||||
}
|
||||
|
||||
func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, maxFails int, failBlockTimeout int, policyId int64, groupId int64, setId int64, originURL string, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
captchaId := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")
|
||||
func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, policyId int64, groupId int64, setId int64, originURL string, req requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
|
||||
var captchaId = req.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")
|
||||
if len(captchaId) > 0 {
|
||||
captchaCode := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_CODE")
|
||||
var captchaCode = req.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_CODE")
|
||||
if captcha.VerifyString(captchaId, captchaCode) {
|
||||
// 删除计数
|
||||
ttlcache.SharedCache.Delete("CAPTCHA:FAILS:" + request.WAFRemoteIP())
|
||||
// 清除计数
|
||||
CaptchaDeleteCacheKey(req)
|
||||
|
||||
var life = CaptchaSeconds
|
||||
if actionConfig.Life > 0 {
|
||||
@@ -159,22 +219,18 @@ func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, maxFails int
|
||||
}
|
||||
|
||||
// 加入到白名单
|
||||
SharedIPWhiteList.RecordIP("set:"+strconv.FormatInt(setId, 10), actionConfig.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(life), policyId, false, groupId, setId, "")
|
||||
SharedIPWhiteList.RecordIP("set:"+strconv.FormatInt(setId, 10), actionConfig.Scope, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(life), policyId, false, groupId, setId, "")
|
||||
|
||||
http.Redirect(writer, request.WAFRaw(), originURL, http.StatusSeeOther)
|
||||
http.Redirect(writer, req.WAFRaw(), originURL, http.StatusSeeOther)
|
||||
|
||||
return false
|
||||
} else {
|
||||
// 增加计数
|
||||
if maxFails > 0 && failBlockTimeout > 0 {
|
||||
var countFails = ttlcache.SharedCache.IncreaseInt64("CAPTCHA:FAILS:"+request.WAFRemoteIP(), 1, time.Now().Unix()+300)
|
||||
if int(countFails) >= maxFails {
|
||||
SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, false, groupId, setId, "CAPTCHA验证连续失败")
|
||||
return false
|
||||
}
|
||||
if !CaptchaIncreaseFails(req, actionConfig, policyId, groupId, setId, CaptchaPageCodeSubmit) {
|
||||
return false
|
||||
}
|
||||
|
||||
http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusSeeOther)
|
||||
http.Redirect(writer, req.WAFRaw(), req.WAFRaw().URL.String(), http.StatusSeeOther)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -114,7 +114,7 @@ func (this *CCCheckpoint) RequestValue(req requests.Request, param string, optio
|
||||
if len(key) == 0 {
|
||||
key = req.WAFRemoteIP()
|
||||
}
|
||||
value = this.cache.IncreaseInt64(key, int64(1), time.Now().Unix()+period)
|
||||
value = this.cache.IncreaseInt64(key, int64(1), time.Now().Unix()+period, false)
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
@@ -38,7 +38,7 @@ func (this *CC2Checkpoint) RequestValue(req requests.Request, param string, opti
|
||||
threshold = 1000
|
||||
}
|
||||
|
||||
value = ccCache.IncreaseInt64("WAF-CC-"+strings.Join(keyValues, "@"), 1, time.Now().Unix()+period)
|
||||
value = ccCache.IncreaseInt64("WAF-CC-"+strings.Join(keyValues, "@"), 1, time.Now().Unix()+period, false)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
21
internal/waf/checkpoints/request_url.go
Normal file
21
internal/waf/checkpoints/request_url.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package checkpoints
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
)
|
||||
|
||||
type RequestURLCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestURLCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
return req.Format("${requestURL}"), nil, nil
|
||||
}
|
||||
|
||||
func (this *RequestURLCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -40,17 +40,24 @@ var AllCheckpoints = []*CheckpointDefinition{
|
||||
{
|
||||
Name: "请求URI",
|
||||
Prefix: "requestURI",
|
||||
Description: "包含URL参数的请求URI,比如/hello/world?lang=go",
|
||||
Description: "包含URL参数的请求URI,类似于 /hello/world?lang=go",
|
||||
HasParams: false,
|
||||
Instance: new(RequestURICheckpoint),
|
||||
},
|
||||
{
|
||||
Name: "请求路径",
|
||||
Prefix: "requestPath",
|
||||
Description: "不包含URL参数的请求路径,比如/hello/world",
|
||||
Description: "不包含URL参数的请求路径,类似于 /hello/world",
|
||||
HasParams: false,
|
||||
Instance: new(RequestPathCheckpoint),
|
||||
},
|
||||
{
|
||||
Name: "请求URL",
|
||||
Prefix: "requestURL",
|
||||
Description: "完整的请求URL,包含协议、域名、请求路径、参数等,类似于 https://example.com/hello?name=lily",
|
||||
HasParams: false,
|
||||
Instance: new(RequestURLCheckpoint),
|
||||
},
|
||||
{
|
||||
Name: "请求内容长度",
|
||||
Prefix: "requestLength",
|
||||
|
||||
@@ -130,8 +130,8 @@ func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope,
|
||||
}
|
||||
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
_, ok := this.ipMap[ip]
|
||||
this.locker.RUnlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ func (this *RuleSet) Init(waf *WAF) error {
|
||||
// action instances
|
||||
this.actionInstances = []ActionInterface{}
|
||||
for _, action := range this.Actions {
|
||||
instance := FindActionInstance(action.Code, action.Options)
|
||||
var instance = FindActionInstance(action.Code, action.Options)
|
||||
if instance == nil {
|
||||
remotelogs.Error("WAF_RULE_SET", "can not find instance for action '"+action.Code+"'")
|
||||
continue
|
||||
@@ -79,6 +79,7 @@ func (this *RuleSet) Init(waf *WAF) error {
|
||||
}
|
||||
|
||||
this.actionInstances = append(this.actionInstances, instance)
|
||||
waf.AddAction(instance)
|
||||
}
|
||||
|
||||
// sort actions
|
||||
|
||||
@@ -26,12 +26,14 @@ type WAF struct {
|
||||
UseLocalFirewall bool `yaml:"useLocalFirewall" json:"useLocalFirewall"`
|
||||
SYNFlood *firewallconfigs.SYNFloodConfig `yaml:"synFlood" json:"synFlood"`
|
||||
|
||||
DefaultBlockAction *BlockAction
|
||||
DefaultBlockAction *BlockAction
|
||||
DefaultCaptchaAction *CaptchaAction
|
||||
|
||||
hasInboundRules bool
|
||||
hasOutboundRules bool
|
||||
|
||||
checkpointsMap map[string]checkpoints.CheckpointInterface // prefix => checkpoint
|
||||
actionMap map[int64]ActionInterface // actionId => ActionInterface
|
||||
}
|
||||
|
||||
func NewWAF() *WAF {
|
||||
@@ -74,6 +76,9 @@ func (this *WAF) Init() (resultErrors []error) {
|
||||
this.checkpointsMap[def.Prefix] = instance
|
||||
}
|
||||
|
||||
// action map
|
||||
this.actionMap = map[int64]ActionInterface{}
|
||||
|
||||
// rules
|
||||
this.hasInboundRules = len(this.Inbound) > 0
|
||||
this.hasOutboundRules = len(this.Outbound) > 0
|
||||
@@ -324,8 +329,16 @@ func (this *WAF) ContainsGroupCode(code string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *WAF) AddAction(action ActionInterface) {
|
||||
this.actionMap[action.ActionId()] = action
|
||||
}
|
||||
|
||||
func (this *WAF) FindAction(actionId int64) ActionInterface {
|
||||
return this.actionMap[actionId]
|
||||
}
|
||||
|
||||
func (this *WAF) Copy() *WAF {
|
||||
waf := &WAF{
|
||||
var waf = &WAF{
|
||||
Id: this.Id,
|
||||
IsOn: this.IsOn,
|
||||
Name: this.Name,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user