Compare commits

...

15 Commits

Author SHA1 Message Date
刘祥超
ecef94b700 增加简化版的创建TCP网站API 2023-06-18 17:14:24 +08:00
刘祥超
a493bbb280 增加简化版的创建HTTP网站API 2023-06-18 16:20:00 +08:00
刘祥超
eee902abec 优化错误提示 2023-06-16 08:17:00 +08:00
刘祥超
2aceb4fb4d 版本号改为1.2.0 2023-06-12 14:42:26 +08:00
刘祥超
c1bbcc8dab 已停用的节点不计算在离线节点里 2023-06-12 14:10:18 +08:00
刘祥超
262f8a5594 已经停用的节点不提示需要升级 2023-06-12 14:04:50 +08:00
刘祥超
a85b49a377 智能DNS实现DoH功能 2023-06-11 17:57:31 +08:00
刘祥超
75e353db0e 初步实现对象存储源站 2023-06-07 17:25:20 +08:00
刘祥超
ccbb14836e 修复因serverId传入0而可能删除WAF策略的问题 2023-06-06 15:03:18 +08:00
刘祥超
7fbc61aa21 改进DNS域名解析相关函数 2023-06-05 12:36:29 +08:00
刘祥超
8b804cb500 修复一个测试用例 2023-06-04 09:38:13 +08:00
刘祥超
3ddb95731a Update sql.json 2023-06-03 09:08:44 +08:00
刘祥超
beeb46ab7f 修复节点IP为IPv6时无法健康检查的问题 2023-06-02 14:46:38 +08:00
刘祥超
a65255e4e5 优化代码 2023-06-01 18:08:45 +08:00
刘祥超
b7768ea0c0 初步实现HTTP3 2023-06-01 17:46:10 +08:00
37 changed files with 31571 additions and 101 deletions

View File

@@ -12,5 +12,5 @@ dbs:
fields:
bool: [ "uamIsOn", "followPort", "requestHostExcludingPort", "autoRemoteStart", "autoInstallNftables", "enableIPLists", "detectAgents", "checkingPorts", "enableRecordHealthCheck", "offlineIsNotified" ]
bool: [ "uamIsOn", "followPort", "requestHostExcludingPort", "autoRemoteStart", "autoInstallNftables", "enableIPLists", "detectAgents", "checkingPorts", "enableRecordHealthCheck", "offlineIsNotified", "http2Enabled", "http3Enabled" ]

9
go.mod
View File

@@ -10,7 +10,7 @@ require (
github.com/andybalholm/brotli v1.0.4
github.com/cespare/xxhash v1.1.0
github.com/cespare/xxhash/v2 v2.1.1
github.com/go-acme/lego/v4 v4.9.0
github.com/go-acme/lego/v4 v4.10.2
github.com/go-sql-driver/mysql v1.7.0
github.com/go-telegram-bot-api/telegram-bot-api v4.6.4+incompatible
github.com/iwind/TeaGo v0.0.0-20230304012706-c1f4a4e27470
@@ -20,15 +20,16 @@ require (
github.com/pkg/sftp v1.12.0
github.com/shirou/gopsutil/v3 v3.22.2
github.com/smartwalle/alipay/v3 v3.1.7
golang.org/x/crypto v0.1.0
golang.org/x/crypto v0.5.0
golang.org/x/net v0.8.0
golang.org/x/sys v0.6.0
golang.org/x/sys v0.8.0
google.golang.org/grpc v1.45.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/cenkalti/backoff/v4 v4.1.3 // indirect
github.com/cenkalti/backoff/v4 v4.2.0 // indirect
github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect

13
go.sum
View File

@@ -12,8 +12,8 @@ github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/benbjohnson/clock v1.0.3/go.mod h1:bGMdMPoPVvcYyt1gHDf4J2KE153Yf9BuiUKYMaxlTDM=
github.com/cenkalti/backoff/v4 v4.1.3 h1:cFAlzYUlVYDysBEH2T5hyJZMh3+5+WCBvSnK6Q8UtC4=
github.com/cenkalti/backoff/v4 v4.1.3/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw=
github.com/cenkalti/backoff/v4 v4.2.0 h1:HN5dHm3WBOgndBH6E8V0q2jIYIR3s9yglV8k/+MN3u4=
github.com/cenkalti/backoff/v4 v4.2.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
@@ -38,6 +38,8 @@ github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
@@ -162,8 +164,8 @@ golang.org/x/crypto v0.0.0-20190506204251-e1dfcc566284/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU=
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE=
golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20200513190911-00229845015e/go.mod h1:4M0jN8W1tt0AVLNr8HDosyJCDCDuyL9N9+3m7wDWgKw=
@@ -222,8 +224,11 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210816074244-15123e1e1f71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@@ -1,7 +1,7 @@
package teaconst
const (
Version = "1.1.0"
Version = "1.2.0"
ProductName = "Edge API"
ProcessName = "edge-api"
@@ -18,7 +18,7 @@ const (
// 其他节点版本号,用来检测是否有需要升级的节点
NodeVersion = "1.1.0"
NodeVersion = "1.2.0"
// SQLVersion SQL版本号
SQLVersion = "11"

View File

@@ -335,6 +335,7 @@ func (this *APINodeDAO) UpdateAPINodeStatus(tx *dbs.Tx, apiNodeId int64, statusJ
func (this *APINodeDAO) CountAllLowerVersionNodes(tx *dbs.Tx, version string) (int64, error) {
return this.Query(tx).
State(APINodeStateEnabled).
Attr("isOn", true).
Where("status IS NOT NULL").
Where("(JSON_EXTRACT(status, '$.buildVersionCode') IS NULL OR JSON_EXTRACT(status, '$.buildVersionCode')<:version)").
Param("version", utils.VersionToLong(version)).

View File

@@ -210,6 +210,7 @@ func (this *AuthorityNodeDAO) UpdateNodeStatus(tx *dbs.Tx, nodeId int64, nodeSta
func (this *AuthorityNodeDAO) CountAllLowerVersionNodes(tx *dbs.Tx, version string) (int64, error) {
return this.Query(tx).
State(AuthorityNodeStateEnabled).
Attr("isOn", true).
Where("status IS NOT NULL").
Where("(JSON_EXTRACT(status, '$.buildVersionCode') IS NULL OR JSON_EXTRACT(status, '$.buildVersionCode')<:version)").
Param("version", utils.VersionToLong(version)).

View File

@@ -132,7 +132,7 @@ func (this *HTTPFirewallPolicyDAO) CreateFirewallPolicy(tx *dbs.Tx, userId int64
op.Outbound = outboundJSON
}
if userId <= 0 && serverGroupId <=0 && serverId <= 0 {
if userId <= 0 && serverGroupId <= 0 && serverId <= 0 {
// synFlood
var synFloodConfig = firewallconfigs.DefaultSYNFloodConfig()
synFloodJSON, err := json.Marshal(synFloodConfig)
@@ -611,6 +611,10 @@ func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicyServerId(tx *dbs.Tx, poli
// FindFirewallPolicyIdsWithServerId 查找服务独立关联的策略IDs
func (this *HTTPFirewallPolicyDAO) FindFirewallPolicyIdsWithServerId(tx *dbs.Tx, serverId int64) ([]int64, error) {
if serverId <= 0 {
return nil, nil
}
var result = []int64{}
ones, err := this.Query(tx).
Attr("serverId", serverId).

View File

@@ -207,6 +207,7 @@ func (this *MonitorNodeDAO) UpdateNodeStatus(tx *dbs.Tx, nodeId int64, statusJSO
func (this *MonitorNodeDAO) CountAllLowerVersionNodes(tx *dbs.Tx, version string) (int64, error) {
return this.Query(tx).
State(MonitorNodeStateEnabled).
Attr("isOn", true).
Where("status IS NOT NULL").
Where("(JSON_EXTRACT(status, '$.buildVersionCode') IS NULL OR JSON_EXTRACT(status, '$.buildVersionCode')<:version)").
Param("version", utils.VersionToLong(version)).

View File

@@ -996,7 +996,7 @@ func (this *NodeClusterDAO) FindClusterBasicInfo(tx *dbs.Tx, clusterId int64, ca
cluster, err := this.Query(tx).
Pk(clusterId).
State(NodeClusterStateEnabled).
Result("id", "name", "timeZone", "nodeMaxThreads", "cachePolicyId", "httpFirewallPolicyId", "autoOpenPorts", "webp", "uam", "cc", "httpPages", "isOn", "ddosProtection", "clock", "globalServerConfig", "autoInstallNftables").
Result("id", "name", "timeZone", "nodeMaxThreads", "cachePolicyId", "httpFirewallPolicyId", "autoOpenPorts", "webp", "uam", "cc", "httpPages", "http3", "isOn", "ddosProtection", "clock", "globalServerConfig", "autoInstallNftables").
Find()
if err != nil || cluster == nil {
return nil, err
@@ -1184,6 +1184,65 @@ func (this *NodeClusterDAO) FindClusterHTTPCCPolicy(tx *dbs.Tx, clusterId int64,
return policy, nil
}
// UpdateClusterHTTP3Policy 修改HTTP3策略设置
func (this *NodeClusterDAO) UpdateClusterHTTP3Policy(tx *dbs.Tx, clusterId int64, http3Policy *nodeconfigs.HTTP3Policy) error {
if http3Policy == nil {
err := this.Query(tx).
Pk(clusterId).
Set("http3", dbs.SQL("null")).
UpdateQuickly()
if err != nil {
return err
}
return this.NotifyHTTP3Update(tx, clusterId)
}
http3PolicyJSON, err := json.Marshal(http3Policy)
if err != nil {
return err
}
err = this.Query(tx).
Pk(clusterId).
Set("http3", http3PolicyJSON).
UpdateQuickly()
if err != nil {
return err
}
return this.NotifyHTTP3Update(tx, clusterId)
}
// FindClusterHTTP3Policy 查询HTTP3策略设置
func (this *NodeClusterDAO) FindClusterHTTP3Policy(tx *dbs.Tx, clusterId int64, cacheMap *utils.CacheMap) (*nodeconfigs.HTTP3Policy, error) {
var cacheKey = this.Table + ":FindClusterHTTP3Policy:" + types.String(clusterId)
if cacheMap != nil {
cache, ok := cacheMap.Get(cacheKey)
if ok {
return cache.(*nodeconfigs.HTTP3Policy), nil
}
}
http3PolicyJSON, err := this.Query(tx).
Pk(clusterId).
Result("http3").
FindJSONCol()
if err != nil {
return nil, err
}
if IsNull(http3PolicyJSON) {
return nodeconfigs.NewHTTP3Policy(), nil
}
var policy = nodeconfigs.NewHTTP3Policy()
err = json.Unmarshal(http3PolicyJSON, policy)
if err != nil {
return nil, err
}
return policy, nil
}
// UpdateClusterHTTPPagesPolicy 修改自定义页面设置
func (this *NodeClusterDAO) UpdateClusterHTTPPagesPolicy(tx *dbs.Tx, clusterId int64, httpPagesPolicy *nodeconfigs.HTTPPagesPolicy) error {
if httpPagesPolicy == nil {
@@ -1362,6 +1421,11 @@ func (this *NodeClusterDAO) NotifyHTTPCCUpdate(tx *dbs.Tx, clusterId int64) erro
return SharedNodeTaskDAO.CreateClusterTask(tx, nodeconfigs.NodeRoleNode, clusterId, 0, 0, NodeTaskTypeHTTPCCPolicyChanged)
}
// NotifyHTTP3Update 通知HTTP3更新
func (this *NodeClusterDAO) NotifyHTTP3Update(tx *dbs.Tx, clusterId int64) error {
return SharedNodeTaskDAO.CreateClusterTask(tx, nodeconfigs.NodeRoleNode, clusterId, 0, 0, NodeTaskTypeHTTP3PolicyChanged)
}
// NotifyHTTPPagesPolicyUpdate 通知HTTP Pages更新
func (this *NodeClusterDAO) NotifyHTTPPagesPolicyUpdate(tx *dbs.Tx, clusterId int64) error {
return SharedNodeTaskDAO.CreateClusterTask(tx, nodeconfigs.NodeRoleNode, clusterId, 0, 0, NodeTaskTypeHTTPPagesPolicyChanged)

View File

@@ -43,6 +43,7 @@ type NodeCluster struct {
IsAD bool `field:"isAD"` // 是否为高防集群
HttpPages dbs.JSON `field:"httpPages"` // 自定义页面设置
Cc dbs.JSON `field:"cc"` // CC设置
Http3 dbs.JSON `field:"http3"` // HTTP3设置
}
type NodeClusterOperator struct {
@@ -85,6 +86,7 @@ type NodeClusterOperator struct {
IsAD any // 是否为高防集群
HttpPages any // 自定义页面设置
Cc any // CC设置
Http3 any // HTTP3设置
}
func NewNodeClusterOperator() *NodeClusterOperator {

View File

@@ -295,6 +295,7 @@ func (this *NodeDAO) CountAllEnabledNodes(tx *dbs.Tx) (int64, error) {
func (this *NodeDAO) CountAllEnabledOfflineNodes(tx *dbs.Tx) (int64, error) {
return this.Query(tx).
State(NodeStateEnabled).
Attr("isOn", true).
Where("clusterId IN (SELECT id FROM "+SharedNodeClusterDAO.Table+" WHERE state=:clusterState)").
Param("clusterState", NodeClusterStateEnabled).
Where("(status IS NULL OR NOT JSON_EXTRACT(status, '$.isActive') OR UNIX_TIMESTAMP()-JSON_EXTRACT(status, '$.updatedAt')>60)").
@@ -1086,6 +1087,7 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64, dataMap *shared
config.WebPImagePolicies = map[int64]*nodeconfigs.WebPImagePolicy{}
config.UAMPolicies = map[int64]*nodeconfigs.UAMPolicy{}
config.HTTPCCPolicies = map[int64]*nodeconfigs.HTTPCCPolicy{}
config.HTTP3Policies = map[int64]*nodeconfigs.HTTP3Policy{}
config.HTTPPagesPolicies = map[int64]*nodeconfigs.HTTPPagesPolicy{}
var allowIPMaps = map[string]bool{}
for _, clusterId := range clusterIds {
@@ -1189,7 +1191,7 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64, dataMap *shared
}
// 集成默认设置
for i := 0; i < len(serverconfigs.DefaultHTTPCCThresholds); i ++ {
for i := 0; i < len(serverconfigs.DefaultHTTPCCThresholds); i++ {
if i < len(ccPolicy.Thresholds) {
ccPolicy.Thresholds[i].MergeIfEmpty(serverconfigs.DefaultHTTPCCThresholds[i])
}
@@ -1198,6 +1200,16 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64, dataMap *shared
config.HTTPCCPolicies[clusterId] = ccPolicy
}
// HTTP3 Policy
if IsNotNull(nodeCluster.Http3) {
var http3Policy = nodeconfigs.NewHTTP3Policy()
err = json.Unmarshal(nodeCluster.Http3, http3Policy)
if err != nil {
return nil, err
}
config.HTTP3Policies[clusterId] = http3Policy
}
// HTTP Pages Policy
if IsNotNull(nodeCluster.HttpPages) {
var httpPagesPolicy = nodeconfigs.NewHTTPPagesPolicy()
@@ -1473,6 +1485,7 @@ func (this *NodeDAO) FindAllNotInstalledNodesWithClusterId(tx *dbs.Tx, clusterId
func (this *NodeDAO) CountAllLowerVersionNodesWithClusterId(tx *dbs.Tx, clusterId int64, os string, arch string, version string) (int64, error) {
return this.Query(tx).
State(NodeStateEnabled).
Attr("isOn", true).
Attr("clusterId", clusterId).
Where("status IS NOT NULL").
Where("JSON_EXTRACT(status, '$.os')=:os").
@@ -1506,6 +1519,7 @@ func (this *NodeDAO) FindAllLowerVersionNodesWithClusterId(tx *dbs.Tx, clusterId
func (this *NodeDAO) CountAllLowerVersionNodes(tx *dbs.Tx, version string) (int64, error) {
return this.Query(tx).
State(NodeStateEnabled).
Attr("isOn", true).
Where("clusterId IN (SELECT id FROM "+SharedNodeClusterDAO.Table+" WHERE state=1)").
Where("status IS NOT NULL").
Where("(JSON_EXTRACT(status, '$.buildVersionCode') IS NULL OR JSON_EXTRACT(status, '$.buildVersionCode')<:version)").

View File

@@ -27,6 +27,7 @@ const (
NodeTaskTypeUAMPolicyChanged NodeTaskType = "uamPolicyChanged" // UAM策略变化
NodeTaskTypeHTTPPagesPolicyChanged NodeTaskType = "httpPagesPolicyChanged" // 自定义页面变化
NodeTaskTypeHTTPCCPolicyChanged NodeTaskType = "httpCCPolicyChanged" // CC策略变化
NodeTaskTypeHTTP3PolicyChanged NodeTaskType = "http3PolicyChanged" // HTTP3策略变化
NodeTaskTypeUpdatingServers NodeTaskType = "updatingServers" // 更新一组服务
// NS相关

View File

@@ -15,6 +15,7 @@ type NSCluster struct {
Tcp dbs.JSON `field:"tcp"` // TCP设置
Tls dbs.JSON `field:"tls"` // TLS设置
Udp dbs.JSON `field:"udp"` // UDP设置
Doh dbs.JSON `field:"doh"` // DoH设置
DdosProtection dbs.JSON `field:"ddosProtection"` // DDoS防护设置
Hosts dbs.JSON `field:"hosts"` // DNS主机地址
Soa dbs.JSON `field:"soa"` // SOA配置
@@ -39,6 +40,7 @@ type NSClusterOperator struct {
Tcp any // TCP设置
Tls any // TLS设置
Udp any // UDP设置
Doh any // DoH设置
DdosProtection any // DDoS防护设置
Hosts any // DNS主机地址
Soa any // SOA配置

View File

@@ -93,6 +93,7 @@ func (this *NSNodeDAO) CountAllLowerVersionNodesWithClusterId(tx *dbs.Tx, cluste
return this.Query(tx).
State(NSNodeStateEnabled).
Attr("clusterId", clusterId).
Attr("isOn", true).
Where("status IS NOT NULL").
Where("JSON_EXTRACT(status, '$.os')=:os").
Where("JSON_EXTRACT(status, '$.arch')=:arch").
@@ -161,6 +162,7 @@ func (this *NSNodeDAO) UpdateNodeStatus(tx *dbs.Tx, nodeId int64, nodeStatus *no
func (this *NSNodeDAO) CountAllLowerVersionNodes(tx *dbs.Tx, version string) (int64, error) {
return this.Query(tx).
State(NSNodeStateEnabled).
Attr("isOn", true).
Where("clusterId IN (SELECT id FROM "+SharedNSClusterDAO.Table+" WHERE state=1)").
Where("status IS NOT NULL").
Where("(JSON_EXTRACT(status, '$.buildVersionCode') IS NULL OR JSON_EXTRACT(status, '$.buildVersionCode')<:version)").

View File

@@ -5,6 +5,7 @@ import (
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ossconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
_ "github.com/go-sql-driver/mysql"
@@ -91,7 +92,8 @@ func (this *OriginDAO) CreateOrigin(tx *dbs.Tx,
adminId int64,
userId int64,
name string,
addrJSON string,
addrJSON []byte,
ossConfig *ossconfigs.OSSConfig,
description string,
weight int32, isOn bool,
connTimeout *shared.TimeDuration,
@@ -141,7 +143,18 @@ func (this *OriginDAO) CreateOrigin(tx *dbs.Tx,
op.MaxIdleConns = 0
}
op.Addr = addrJSON
if len(addrJSON) > 0 {
op.Addr = addrJSON
}
if ossConfig != nil {
ossConfigJSON, err := json.Marshal(ossConfig)
if err != nil {
return 0, err
}
op.Oss = ossConfigJSON
}
op.Description = description
if weight < 0 {
weight = 0
@@ -182,7 +195,8 @@ func (this *OriginDAO) CreateOrigin(tx *dbs.Tx,
func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx,
originId int64,
name string,
addrJSON string,
addrJSON []byte,
ossConfig *ossconfigs.OSSConfig,
description string,
weight int32,
isOn bool,
@@ -201,7 +215,17 @@ func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx,
var op = NewOriginOperator()
op.Id = originId
op.Name = name
op.Addr = addrJSON
if ossConfig != nil {
ossConfigJSON, err := json.Marshal(ossConfig)
if err != nil {
return err
}
op.Oss = ossConfigJSON
}
op.Description = description
if weight < 0 {
weight = 0
@@ -369,6 +393,7 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, dataMap *
FollowPort: origin.FollowPort,
}
// addr
if IsNotNull(origin.Addr) {
var addr = &serverconfigs.NetworkAddressConfig{}
err = json.Unmarshal(origin.Addr, addr)
@@ -378,6 +403,16 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, dataMap *
config.Addr = addr
}
// oss
if IsNotNull(origin.Oss) {
var ossConfig = ossconfigs.NewOSSConfig()
err = json.Unmarshal(origin.Oss, ossConfig)
if err != nil {
return nil, err
}
config.OSS = ossConfig
}
if IsNotNull(origin.ConnTimeout) {
var connTimeout = &shared.TimeDuration{}
err = json.Unmarshal(origin.ConnTimeout, &connTimeout)

View File

@@ -11,6 +11,7 @@ type Origin struct {
Name string `field:"name"` // 名称
Version uint32 `field:"version"` // 版本
Addr dbs.JSON `field:"addr"` // 地址
Oss dbs.JSON `field:"oss"` // OSS配置
Description string `field:"description"` // 描述
Code string `field:"code"` // 代号
Weight uint32 `field:"weight"` // 权重
@@ -34,33 +35,34 @@ type Origin struct {
}
type OriginOperator struct {
Id interface{} // ID
AdminId interface{} // 管理员ID
UserId interface{} // 用户ID
IsOn interface{} // 是否启用
Name interface{} // 名称
Version interface{} // 版本
Addr interface{} // 地址
Description interface{} // 描述
Code interface{} // 代号
Weight interface{} // 权重
ConnTimeout interface{} // 连接超时
ReadTimeout interface{} // 超时
IdleTimeout interface{} // 空闲连接超时
MaxFails interface{} // 最多失败次数
MaxConns interface{} // 最大并发连接
MaxIdleConns interface{} // 最多空闲连接数
HttpRequestURI interface{} // 转发后的请求URI
HttpRequestHeader interface{} // 请求Header配置
HttpResponseHeader interface{} // 响应Header配置
Host interface{} // 自定义主机名
HealthCheck interface{} // 健康检查设置
Cert interface{} // 证书设置
Ftp interface{} // FTP相关设置
CreatedAt interface{} // 创建时间
Domains interface{} // 所属域名
FollowPort interface{} // 端口跟随
State interface{} // 状态
Id any // ID
AdminId any // 管理员ID
UserId any // 用户ID
IsOn any // 是否启用
Name any // 名称
Version any // 版本
Addr any // 地址
Oss any // OSS配置
Description any // 描述
Code any // 代号
Weight any // 权重
ConnTimeout any // 连接超时
ReadTimeout any // 超时
IdleTimeout any // 空闲连接超时
MaxFails any // 最多失败次
MaxConns any // 最大并发连接数
MaxIdleConns any // 最多空闲连接数
HttpRequestURI any // 转发后的请求URI
HttpRequestHeader any // 请求Header配置
HttpResponseHeader any // 响应Header配置
Host any // 自定义主机名
HealthCheck any // 健康检查设置
Cert any // 证书设置
Ftp any // FTP相关设置
CreatedAt any // 创建时间
Domains any // 所属域名
FollowPort any // 端口跟随
State any // 状态
}
func NewOriginOperator() *OriginOperator {

View File

@@ -354,6 +354,9 @@ func (this *ServerDAO) UpdateServerBasic(tx *dbs.Tx, serverId int64, name string
// UpdateServerGroupIds 修改服务所在分组
func (this *ServerDAO) UpdateServerGroupIds(tx *dbs.Tx, serverId int64, groupIds []int64) error {
if serverId <= 0 {
return errors.New("serverId should not be smaller than 0")
}
if groupIds == nil {
groupIds = []int64{}
}
@@ -390,6 +393,10 @@ func (this *ServerDAO) UpdateUserServerBasic(tx *dbs.Tx, serverId int64, name st
// UpdateServerIsOn 修复服务是否启用
func (this *ServerDAO) UpdateServerIsOn(tx *dbs.Tx, serverId int64, isOn bool) error {
if serverId <= 0 {
return errors.New("serverId should not be smaller than 0")
}
_, err := this.Query(tx).
Pk(serverId).
Set("isOn", isOn).
@@ -2153,6 +2160,10 @@ func (this *ServerDAO) FindFirstHTTPOrHTTPSPortWithClusterId(tx *dbs.Tx, cluster
// NotifyServerPortsUpdate 通知服务端口变化
func (this *ServerDAO) NotifyServerPortsUpdate(tx *dbs.Tx, serverId int64) error {
if serverId <= 0 {
return nil
}
one, err := this.Query(tx).
Pk(serverId).
Result("tcp", "tls", "udp", "http", "https").
@@ -2480,6 +2491,10 @@ func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitCo
// IncreaseServerTotalTraffic 增加服务的总流量
func (this *ServerDAO) IncreaseServerTotalTraffic(tx *dbs.Tx, serverId int64, bytes int64) error {
if serverId <= 0 {
return errors.New("serverId should not be smaller than 0")
}
var gb = float64(bytes) / (1 << 30)
var day = timeutil.Format("Ymd")
var month = timeutil.Format("Ym")
@@ -2539,6 +2554,10 @@ func (this *ServerDAO) UpdateServersClusterIdWithPlanId(tx *dbs.Tx, planId int64
// UpdateServerUserPlanId 设置服务所属套餐
func (this *ServerDAO) UpdateServerUserPlanId(tx *dbs.Tx, serverId int64, userPlanId int64) error {
if serverId <= 0 {
return errors.New("serverId should not be smaller than 0")
}
oldClusterId, err := this.Query(tx).
Pk(serverId).
Result("clusterId").
@@ -2646,6 +2665,10 @@ func (this *ServerDAO) UpdateServerUserPlanId(tx *dbs.Tx, serverId int64, userPl
// FindServerLastUserPlanIdAndUserId 查找最后使用的套餐
func (this *ServerDAO) FindServerLastUserPlanIdAndUserId(tx *dbs.Tx, serverId int64) (userPlanId int64, userId int64, err error) {
if serverId <= 0 {
return 0, 0, errors.New("serverId should not be smaller than 0")
}
one, err := this.Query(tx).
Pk(serverId).
Result("lastUserPlanId", "userId").
@@ -2659,6 +2682,10 @@ func (this *ServerDAO) FindServerLastUserPlanIdAndUserId(tx *dbs.Tx, serverId in
// UpdateServerUAM 开启UAM
func (this *ServerDAO) UpdateServerUAM(tx *dbs.Tx, serverId int64, uamConfig *serverconfigs.UAMConfig) error {
if serverId <= 0 {
return errors.New("serverId should not be smaller than 0")
}
if uamConfig == nil {
return nil
}
@@ -2832,6 +2859,10 @@ func (this *ServerDAO) FindEnabledServersWithIds(tx *dbs.Tx, serverIds []int64)
// NotifyUpdate 同步服务所在的集群
func (this *ServerDAO) NotifyUpdate(tx *dbs.Tx, serverId int64) error {
if serverId <= 0 {
return nil
}
// 创建任务
clusterId, err := this.FindServerClusterId(tx, serverId)
if err != nil {
@@ -2845,6 +2876,9 @@ func (this *ServerDAO) NotifyUpdate(tx *dbs.Tx, serverId int64) error {
// NotifyClusterUpdate 同步指定的集群
func (this *ServerDAO) NotifyClusterUpdate(tx *dbs.Tx, clusterId, serverId int64) error {
if serverId <= 0 {
return nil
}
if clusterId <= 0 {
return nil
}
@@ -2853,6 +2887,10 @@ func (this *ServerDAO) NotifyClusterUpdate(tx *dbs.Tx, clusterId, serverId int64
// NotifyDNSUpdate 通知当前集群DNS更新
func (this *ServerDAO) NotifyDNSUpdate(tx *dbs.Tx, serverId int64) error {
if serverId <= 0 {
return nil
}
clusterId, err := this.Query(tx).
Pk(serverId).
Result("clusterId").
@@ -2878,6 +2916,10 @@ func (this *ServerDAO) NotifyDNSUpdate(tx *dbs.Tx, serverId int64) error {
// NotifyClusterDNSUpdate 通知某个集群中的DNS更新
func (this *ServerDAO) NotifyClusterDNSUpdate(tx *dbs.Tx, clusterId int64, serverId int64) error {
if serverId <= 0 {
return nil
}
dnsInfo, err := SharedNodeClusterDAO.FindClusterDNSInfo(tx, clusterId, nil)
if err != nil {
return err
@@ -2893,6 +2935,10 @@ func (this *ServerDAO) NotifyClusterDNSUpdate(tx *dbs.Tx, clusterId int64, serve
// NotifyDisable 通知禁用
func (this *ServerDAO) NotifyDisable(tx *dbs.Tx, serverId int64) error {
if serverId <= 0 {
return nil
}
// 禁用缓存策略相关的内容
policyIds, err := SharedHTTPFirewallPolicyDAO.FindFirewallPolicyIdsWithServerId(tx, serverId)
if err != nil {

View File

@@ -229,7 +229,7 @@ func TestServerDAO_FindEnabledServerWithDomain(t *testing.T) {
for _, domain := range []string{"a", "a.com", "teaos.cn", "www.teaos.cn", "cdn.teaos.cn", "google.com"} {
var before = time.Now()
server, err := dao.FindEnabledServerWithDomain(tx, domain)
server, err := dao.FindEnabledServerWithDomain(tx, 0, domain)
var costMs = time.Since(before).Seconds() * 1000
if err != nil {
t.Fatal(err)

View File

@@ -100,7 +100,8 @@ func (this *SSLPolicyDAO) ComposePolicyConfig(tx *dbs.Tx, policyId int64, ignore
config.Id = int64(policy.Id)
config.IsOn = policy.IsOn
config.ClientAuthType = int(policy.ClientAuthType)
config.HTTP2Enabled = policy.Http2Enabled == 1
config.HTTP2Enabled = policy.Http2Enabled
config.HTTP3Enabled = policy.Http3Enabled
config.MinVersion = policy.MinVersion
// certs
@@ -200,7 +201,7 @@ func (this *SSLPolicyDAO) FindAllEnabledPolicyIdsWithCertId(tx *dbs.Tx, certId i
}
// CreatePolicy 创建Policy
func (this *SSLPolicyDAO) CreatePolicy(tx *dbs.Tx, adminId int64, userId int64, http2Enabled bool, minVersion string, certsJSON []byte, hstsJSON []byte, ocspIsOn bool, clientAuthType int32, clientCACertsJSON []byte, cipherSuitesIsOn bool, cipherSuites []string) (int64, error) {
func (this *SSLPolicyDAO) CreatePolicy(tx *dbs.Tx, adminId int64, userId int64, http2Enabled bool, http3Enabled bool, minVersion string, certsJSON []byte, hstsJSON []byte, ocspIsOn bool, clientAuthType int32, clientCACertsJSON []byte, cipherSuitesIsOn bool, cipherSuites []string) (int64, error) {
var op = NewSSLPolicyOperator()
op.State = SSLPolicyStateEnabled
op.IsOn = true
@@ -208,6 +209,7 @@ func (this *SSLPolicyDAO) CreatePolicy(tx *dbs.Tx, adminId int64, userId int64,
op.UserId = userId
op.Http2Enabled = http2Enabled
op.Http3Enabled = http3Enabled
op.MinVersion = minVersion
if len(certsJSON) > 0 {
@@ -240,7 +242,7 @@ func (this *SSLPolicyDAO) CreatePolicy(tx *dbs.Tx, adminId int64, userId int64,
}
// UpdatePolicy 修改Policy
func (this *SSLPolicyDAO) UpdatePolicy(tx *dbs.Tx, policyId int64, http2Enabled bool, minVersion string, certsJSON []byte, hstsJSON []byte, ocspIsOn bool, clientAuthType int32, clientCACertsJSON []byte, cipherSuitesIsOn bool, cipherSuites []string) error {
func (this *SSLPolicyDAO) UpdatePolicy(tx *dbs.Tx, policyId int64, http2Enabled bool, http3Enabled bool, minVersion string, certsJSON []byte, hstsJSON []byte, ocspIsOn bool, clientAuthType int32, clientCACertsJSON []byte, cipherSuitesIsOn bool, cipherSuites []string) error {
if policyId <= 0 {
return errors.New("invalid policyId")
}
@@ -248,6 +250,7 @@ func (this *SSLPolicyDAO) UpdatePolicy(tx *dbs.Tx, policyId int64, http2Enabled
var op = NewSSLPolicyOperator()
op.Id = policyId
op.Http2Enabled = http2Enabled
op.Http3Enabled = http3Enabled
op.MinVersion = minVersion
if len(certsJSON) > 0 {

View File

@@ -7,7 +7,7 @@ type SSLPolicy struct {
Id uint32 `field:"id"` // ID
AdminId uint32 `field:"adminId"` // 管理员ID
UserId uint32 `field:"userId"` // 用户ID
IsOn bool `field:"isOn"` // 是否启用
IsOn bool `field:"isOn"` // 是否启用
Certs dbs.JSON `field:"certs"` // 证书列表
ClientCACerts dbs.JSON `field:"clientCACerts"` // 客户端证书
ClientAuthType uint32 `field:"clientAuthType"` // 客户端认证类型
@@ -15,28 +15,30 @@ type SSLPolicy struct {
CipherSuitesIsOn uint8 `field:"cipherSuitesIsOn"` // 是否自定义加密算法套件
CipherSuites dbs.JSON `field:"cipherSuites"` // 加密算法套件
Hsts dbs.JSON `field:"hsts"` // HSTS设置
Http2Enabled uint8 `field:"http2Enabled"` // 是否启用HTTP/2
Http2Enabled bool `field:"http2Enabled"` // 是否启用HTTP/2
Http3Enabled bool `field:"http3Enabled"` // 是否启用HTTP/3
OcspIsOn uint8 `field:"ocspIsOn"` // 是否启用OCSP
State uint8 `field:"state"` // 状态
CreatedAt uint64 `field:"createdAt"` // 创建时间
}
type SSLPolicyOperator struct {
Id interface{} // ID
AdminId interface{} // 管理员ID
UserId interface{} // 用户ID
IsOn interface{} // 是否启用
Certs interface{} // 证书列表
ClientCACerts interface{} // 客户端证书
ClientAuthType interface{} // 客户端认证类型
MinVersion interface{} // 支持的SSL最小版本
CipherSuitesIsOn interface{} // 是否自定义加密算法套件
CipherSuites interface{} // 加密算法套件
Hsts interface{} // HSTS设置
Http2Enabled interface{} // 是否启用HTTP/2
OcspIsOn interface{} // 是否启用OCSP
State interface{} // 状态
CreatedAt interface{} // 创建时间
Id any // ID
AdminId any // 管理员ID
UserId any // 用户ID
IsOn any // 是否启用
Certs any // 证书列表
ClientCACerts any // 客户端证书
ClientAuthType any // 客户端认证类型
MinVersion any // 支持的SSL最小版本
CipherSuitesIsOn any // 是否自定义加密算法套件
CipherSuites any // 加密算法套件
Hsts any // HSTS设置
Http2Enabled any // 是否启用HTTP/2
Http3Enabled any // 是否启用HTTP/3
OcspIsOn any // 是否启用OCSP
State any // 状态
CreatedAt any // 创建时间
}
func NewSSLPolicyOperator() *SSLPolicyOperator {

View File

@@ -280,6 +280,7 @@ func (this *UserNodeDAO) UpdateNodeStatus(tx *dbs.Tx, nodeId int64, nodeStatus *
func (this *UserNodeDAO) CountAllLowerVersionNodes(tx *dbs.Tx, version string) (int64, error) {
return this.Query(tx).
State(UserNodeStateEnabled).
Attr("isOn", true).
Where("status IS NOT NULL").
Where("(JSON_EXTRACT(status, '$.buildVersionCode') IS NULL OR JSON_EXTRACT(status, '$.buildVersionCode')<:version)").
Param("version", utils.VersionToLong(version)).

View File

@@ -28,6 +28,7 @@ import (
"github.com/iwind/gosock/pkg/gosock"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/status"
"gopkg.in/yaml.v3"
"log"
"net"
@@ -832,7 +833,12 @@ func (this *APINode) unaryInterceptor(ctx context.Context, req any, info *grpc.U
}
result, err := handler(ctx, req)
if err != nil {
err = errors.New("'" + info.FullMethod + "()' says: " + err.Error())
statusErr, ok := status.FromError(err)
if ok {
err = status.Error(statusErr.Code(), "'" + info.FullMethod + "()' says: " + err.Error())
} else {
err = errors.New("'" + info.FullMethod + "()' says: " + err.Error())
}
}
return result, err
}

View File

@@ -16,7 +16,9 @@ import (
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
type BaseService struct {
@@ -229,7 +231,7 @@ func (this *BaseService) PermissionError() error {
}
func (this *BaseService) NotImplementedYet() error {
return errors.New("not implemented yet")
return status.Error(codes.Unimplemented, "not implemented yet")
}
// NullTx 空的数据库事务

View File

@@ -1180,6 +1180,16 @@ func (this *NodeClusterService) FindEnabledNodeClusterConfigInfo(ctx context.Con
result.HasHTTPPagesPolicy = pagesPolicy.IsOn && len(pagesPolicy.Pages) > 0
}
// HTTP/3
if models.IsNotNull(cluster.Http3) {
var http3Policy = nodeconfigs.NewHTTP3Policy()
err = json.Unmarshal(cluster.Http3, http3Policy)
if err != nil {
return nil, err
}
result.Http3IsOn = http3Policy.IsOn
}
return result, nil
}
@@ -1301,7 +1311,6 @@ func (this *NodeClusterService) UpdateNodeClusterUAMPolicy(ctx context.Context,
return this.Success()
}
// FindEnabledNodeClusterHTTPCCPolicy 读取集群HTTP CC策略
func (this *NodeClusterService) FindEnabledNodeClusterHTTPCCPolicy(ctx context.Context, req *pb.FindEnabledNodeClusterHTTPCCPolicyRequest) (*pb.FindEnabledNodeClusterHTTPCCPolicyResponse, error) {
if !teaconst.IsPlus {
@@ -1502,3 +1511,33 @@ func (this *NodeClusterService) UpdateNodeClusterHTTPPagesPolicy(ctx context.Con
return this.Success()
}
// UpdateNodeClusterHTTP3Policy 修改集群的HTTP3设置
func (this *NodeClusterService) UpdateNodeClusterHTTP3Policy(ctx context.Context, req *pb.UpdateNodeClusterHTTP3PolicyRequest) (*pb.RPCSuccess, error) {
if !teaconst.IsPlus {
return nil, this.NotImplementedYet()
}
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var http3Policy = nodeconfigs.NewHTTP3Policy()
err = json.Unmarshal(req.Http3PolicyJSON, http3Policy)
if err != nil {
return nil, err
}
err = http3Policy.Init()
if err != nil {
return nil, errors.New("validate http3 policy failed: " + err.Error())
}
var tx = this.NullTx()
err = models.SharedNodeClusterDAO.UpdateClusterHTTP3Policy(tx, req.NodeClusterId, http3Policy)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,13 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package services
import (
"context"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
func (this *NodeClusterService) FindNodeClusterHTTP3Policy(ctx context.Context, req *pb.FindNodeClusterHTTP3PolicyRequest) (*pb.FindNodeClusterHTTP3PolicyResponse, error) {
return nil, this.NotImplementedYet()
}

View File

@@ -16,6 +16,10 @@ func (this *NodeService) FindNodeHTTPCCPolicies(ctx context.Context, req *pb.Fin
return nil, this.NotImplementedYet()
}
func (this *NodeService) FindNodeHTTP3Policies(ctx context.Context, req *pb.FindNodeHTTP3PoliciesRequest) (*pb.FindNodeHTTP3PoliciesResponse, error) {
return nil, this.NotImplementedYet()
}
func (this *NodeService) FindNodeHTTPPagesPolicies(ctx context.Context, req *pb.FindNodeHTTPPagesPoliciesRequest) (*pb.FindNodeHTTPPagesPoliciesResponse, error) {
return nil, this.NotImplementedYet()
}

View File

@@ -6,6 +6,7 @@ import (
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ossconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
"github.com/iwind/TeaGo/maps"
@@ -23,15 +24,26 @@ func (this *OriginService) CreateOrigin(ctx context.Context, req *pb.CreateOrigi
return nil, err
}
// 源站地址设置
if req.Addr == nil {
return nil, errors.New("'addr' can not be nil")
}
addrMap := maps.Map{
var addrMap = maps.Map{
"protocol": req.Addr.Protocol,
"portRange": req.Addr.PortRange,
"host": req.Addr.Host,
}
// OSS设置
var ossConfig *ossconfigs.OSSConfig
if len(req.OssJSON) > 0 {
ossConfig = ossconfigs.NewOSSConfig()
err = json.Unmarshal(req.OssJSON, ossConfig)
if err != nil {
return nil, err
}
}
var tx = this.NullTx()
// 校验参数
@@ -72,7 +84,7 @@ func (this *OriginService) CreateOrigin(ctx context.Context, req *pb.CreateOrigi
}
}
originId, err := models.SharedOriginDAO.CreateOrigin(tx, adminId, userId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns, certRef, req.Domains, req.Host, req.FollowPort)
originId, err := models.SharedOriginDAO.CreateOrigin(tx, adminId, userId, req.Name, addrMap.AsJSON(), ossConfig, req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns, certRef, req.Domains, req.Host, req.FollowPort)
if err != nil {
return nil, err
}
@@ -95,6 +107,8 @@ func (this *OriginService) UpdateOrigin(ctx context.Context, req *pb.UpdateOrigi
return nil, err
}
}
// 源站地址设置
if req.Addr == nil {
return nil, errors.New("'addr' can not be nil")
}
@@ -104,6 +118,16 @@ func (this *OriginService) UpdateOrigin(ctx context.Context, req *pb.UpdateOrigi
"host": req.Addr.Host,
}
// OSS设置
var ossConfig *ossconfigs.OSSConfig
if len(req.OssJSON) > 0 {
ossConfig = ossconfigs.NewOSSConfig()
err = json.Unmarshal(req.OssJSON, ossConfig)
if err != nil {
return nil, err
}
}
// 校验参数
var connTimeout = &shared.TimeDuration{}
if len(req.ConnTimeoutJSON) > 0 {
@@ -142,7 +166,7 @@ func (this *OriginService) UpdateOrigin(ctx context.Context, req *pb.UpdateOrigi
}
}
err = models.SharedOriginDAO.UpdateOrigin(tx, req.OriginId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns, certRef, req.Domains, req.Host, req.FollowPort)
err = models.SharedOriginDAO.UpdateOrigin(tx, req.OriginId, req.Name, addrMap.AsJSON(), ossConfig, req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns, certRef, req.Domains, req.Host, req.FollowPort)
if err != nil {
return nil, err
}

View File

@@ -9,13 +9,17 @@ import (
"github.com/TeaOSLab/EdgeAPI/internal/db/models/clients"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeAPI/internal/utils/domainutils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
"net"
"net/url"
"regexp"
"strings"
)
@@ -184,6 +188,505 @@ func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServe
return &pb.CreateServerResponse{ServerId: serverId}, nil
}
// CreateBasicHTTPServer 快速创建基本的HTTP网站
func (this *ServerService) CreateBasicHTTPServer(ctx context.Context, req *pb.CreateBasicHTTPServerRequest) (*pb.CreateBasicHTTPServerResponse, error) {
adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
// 集群
var tx = this.NullTx()
if userId > 0 {
req.UserId = userId
nodeClusterId, err := models.SharedUserDAO.FindUserClusterId(tx, userId)
if err != nil {
return nil, err
}
req.NodeClusterId = nodeClusterId
} else if adminId > 0 && req.UserId > 0 && req.NodeClusterId <= 0 {
// check user
existUser, err := models.SharedUserDAO.Exist(tx, req.UserId)
if err != nil {
return nil, err
}
if !existUser {
return nil, errors.New("user id '" + types.String(req.UserId) + "' not found")
}
nodeClusterId, err := models.SharedUserDAO.FindUserClusterId(tx, userId)
if err != nil {
return nil, err
}
req.NodeClusterId = nodeClusterId
}
if req.NodeClusterId <= 0 {
return nil, errors.New("invalid 'nodeClusterId'")
}
if len(req.Domains) == 0 {
return nil, errors.New("'domains' should not be empty")
}
var serverNames = []*serverconfigs.ServerNameConfig{}
for _, domain := range req.Domains {
domain = strings.ToLower(domain)
if !domainutils.ValidateDomainFormat(domain) {
return nil, errors.New("invalid domain format '" + domain + "'")
}
// 检查域名是否已存在
existServerName, err := models.SharedServerDAO.ExistServerNameInCluster(tx, req.NodeClusterId, domain, 0, true)
if err != nil {
return nil, err
}
if existServerName {
return nil, errors.New("domain '" + domain + "' already created by other server")
}
serverNames = append(serverNames, &serverconfigs.ServerNameConfig{Name: domain})
}
serverNamesJSON, err := json.Marshal(serverNames)
if err != nil {
return nil, errors.New("encode 'serverNames' failed: " + err.Error())
}
// 是否需要审核
var isAuditing = false
var auditingServerNamesJSON = []byte("[]")
if userId > 0 {
// 如果域名不为空的时候需要审核
if len(serverNamesJSON) > 0 && string(serverNamesJSON) != "[]" {
globalConfig, err := models.SharedSysSettingDAO.ReadGlobalConfig(tx)
if err != nil {
return nil, err
}
if globalConfig != nil && globalConfig.HTTPAll.DomainAuditingIsOn {
isAuditing = true
serverNamesJSON = []byte("[]")
auditingServerNamesJSON = serverNamesJSON
}
}
}
// HTTP
var httpConfig = &serverconfigs.HTTPProtocolConfig{
BaseProtocol: serverconfigs.BaseProtocol{
IsOn: true,
Listen: []*serverconfigs.NetworkAddressConfig{
{
Protocol: "http",
PortRange: "80",
},
},
},
}
httpJSON, err := json.Marshal(httpConfig)
if err != nil {
return nil, err
}
// HTTPS
var certRefs = []*sslconfigs.SSLCertRef{}
for _, certId := range req.SslCertIds {
// 检查所有权
if userId > 0 {
err = models.SharedSSLCertDAO.CheckUserCert(tx, certId, userId)
if err != nil {
return nil, errors.New("check cert permission failed: " + err.Error())
}
} else {
existCert, err := models.SharedSSLCertDAO.Exist(tx, certId)
if err != nil {
return nil, err
}
if !existCert {
return nil, errors.New("cert '" + types.String(certId) + "' not found")
}
}
certRefs = append(certRefs, &sslconfigs.SSLCertRef{
IsOn: true,
CertId: certId,
})
}
certRefsJSON, err := json.Marshal(certRefs)
if err != nil {
return nil, err
}
sslPolicyId, err := models.SharedSSLPolicyDAO.CreatePolicy(tx, adminId, req.UserId, false, false, "TLS 1.0", certRefsJSON, nil, false, 0, nil, false, nil)
if err != nil {
return nil, err
}
var httpsConfig = &serverconfigs.HTTPSProtocolConfig{
BaseProtocol: serverconfigs.BaseProtocol{
IsOn: true,
Listen: []*serverconfigs.NetworkAddressConfig{
{
Protocol: "https",
PortRange: "443",
},
},
},
SSLPolicyRef: &sslconfigs.SSLPolicyRef{
IsOn: true,
SSLPolicyId: sslPolicyId,
},
}
httpsJSON, err := json.Marshal(httpsConfig)
if err != nil {
return nil, err
}
// Reverse Proxy
var reverseProxyScheduleConfig = &serverconfigs.SchedulingConfig{
Code: "random",
Options: nil,
}
reverseProxyScheduleJSON, err := json.Marshal(reverseProxyScheduleConfig)
var primaryOrigins = []*serverconfigs.OriginRef{}
for _, originAddr := range req.OriginAddrs {
u, err := url.Parse(originAddr)
if err != nil {
return nil, errors.New("parse origin address '" + originAddr + "' failed: " + err.Error())
}
if len(u.Scheme) == 0 || (u.Scheme != "http" && u.Scheme != "https" /** 特意不支持大写形式 **/) {
return nil, errors.New("invalid scheme in origin address '" + originAddr + "'")
}
if len(u.Host) == 0 {
return nil, errors.New("invalid host address '" + originAddr + "', contains no host")
}
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
err = nil // ignore error
if domainutils.ValidateDomainFormat(u.Host) { // host with no port
host = u.Host
port = ""
} else {
return nil, errors.New("invalid host address '" + originAddr + "', invalid host format")
}
}
if len(port) == 0 {
switch u.Scheme {
case "http":
port = "80"
case "https":
port = "443"
}
}
var addr = &serverconfigs.NetworkAddressConfig{
Protocol: serverconfigs.Protocol(u.Scheme),
Host: host,
PortRange: port,
}
addrJSON, err := json.Marshal(addr)
if err != nil {
return nil, err
}
originId, err := models.SharedOriginDAO.CreateOrigin(tx, adminId, req.UserId, "", addrJSON, nil, "", 10, true, nil, nil, nil, 0, 0, nil, nil, u.Host, false)
if err != nil {
return nil, err
}
primaryOrigins = append(primaryOrigins, &serverconfigs.OriginRef{
IsOn: true,
OriginId: originId,
})
}
primaryOriginsJSON, err := json.Marshal(primaryOrigins)
if err != nil {
return nil, err
}
reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(tx, adminId, req.UserId, reverseProxyScheduleJSON, primaryOriginsJSON, nil)
if err != nil {
return nil, err
}
reverseProxyJSON, err := json.Marshal(&serverconfigs.ReverseProxyRef{
IsPrior: false,
IsOn: true,
ReverseProxyId: reverseProxyId,
})
if err != nil {
return nil, err
}
// Web
webId, err := models.SharedHTTPWebDAO.CreateWeb(tx, adminId, req.UserId, nil)
if err != nil {
return nil, err
}
// Enable websocket
if req.EnableWebsocket {
websocketId, err := models.SharedHTTPWebsocketDAO.CreateWebsocket(tx, nil, true, nil, true, "")
if err != nil {
return nil, err
}
websocketRef, err := json.Marshal(&serverconfigs.HTTPWebsocketRef{
IsPrior: false,
IsOn: true,
WebsocketId: websocketId,
})
if err != nil {
return nil, err
}
err = models.SharedHTTPWebDAO.UpdateWebsocket(tx, webId, websocketRef)
if err != nil {
return nil, err
}
}
// finally, we create ...
serverId, err := models.SharedServerDAO.CreateServer(tx, adminId, req.UserId, serverconfigs.ServerTypeHTTPProxy, req.Domains[0], "", serverNamesJSON, isAuditing, auditingServerNamesJSON, httpJSON, httpsJSON, nil, nil, nil, nil, webId, reverseProxyJSON, req.NodeClusterId, nil, nil, nil, 0)
if err != nil {
return nil, err
}
return &pb.CreateBasicHTTPServerResponse{ServerId: serverId}, nil
}
// CreateBasicTCPServer 快速创建基本的TCP网站
func (this *ServerService) CreateBasicTCPServer(ctx context.Context, req *pb.CreateBasicTCPServerRequest) (*pb.CreateBasicTCPServerResponse, error) {
adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
// 集群
var tx = this.NullTx()
if userId > 0 {
req.UserId = userId
nodeClusterId, err := models.SharedUserDAO.FindUserClusterId(tx, userId)
if err != nil {
return nil, err
}
req.NodeClusterId = nodeClusterId
} else if adminId > 0 && req.UserId > 0 && req.NodeClusterId <= 0 {
// check user
existUser, err := models.SharedUserDAO.Exist(tx, req.UserId)
if err != nil {
return nil, err
}
if !existUser {
return nil, errors.New("user id '" + types.String(req.UserId) + "' not found")
}
nodeClusterId, err := models.SharedUserDAO.FindUserClusterId(tx, userId)
if err != nil {
return nil, err
}
req.NodeClusterId = nodeClusterId
}
if req.NodeClusterId <= 0 {
return nil, errors.New("invalid 'nodeClusterId'")
}
// 检查用户权限
if userId > 0 {
features, err := models.SharedUserDAO.FindUserFeatures(tx, userId)
if err != nil {
return nil, err
}
var canSpecifyTCPPort = false
for _, feature := range features {
if feature.Code == "server.tcp.port" {
canSpecifyTCPPort = true
break
}
}
if !canSpecifyTCPPort {
if len(req.TcpPorts) > 0 || len(req.TlsPorts) > 0 {
return nil, errors.New("no permission to specify tcp/tls ports")
}
}
}
if len(req.TcpPorts) == 0 || len(req.TlsPorts) == 0 {
// TODO 未来支持自动创建端口
return nil, errors.New("no ports valid")
}
// TCP
var tcpConfig = &serverconfigs.HTTPProtocolConfig{
BaseProtocol: serverconfigs.BaseProtocol{
IsOn: true,
Listen: []*serverconfigs.NetworkAddressConfig{},
},
}
for _, port := range req.TcpPorts {
existPort, err := models.SharedServerDAO.CheckPortIsUsing(tx, req.NodeClusterId, "tcp", int(port), 0, "")
if err != nil {
return nil, err
}
if existPort {
return nil, errors.New("port '" + types.String(port) + "' already used by other server")
}
tcpConfig.BaseProtocol.Listen = append(tcpConfig.BaseProtocol.Listen, &serverconfigs.NetworkAddressConfig{
Protocol: "tcp",
PortRange: types.String(port),
})
}
tcpJSON, err := json.Marshal(tcpConfig)
if err != nil {
return nil, err
}
// TLS
var tlsConfig = &serverconfigs.HTTPSProtocolConfig{
BaseProtocol: serverconfigs.BaseProtocol{
IsOn: true,
Listen: []*serverconfigs.NetworkAddressConfig{},
},
}
for _, port := range req.TlsPorts {
existPort, err := models.SharedServerDAO.CheckPortIsUsing(tx, req.NodeClusterId, "tcp", int(port), 0, "")
if err != nil {
return nil, err
}
if existPort {
return nil, errors.New("port '" + types.String(port) + "' already used by other server")
}
tlsConfig.BaseProtocol.Listen = append(tlsConfig.BaseProtocol.Listen, &serverconfigs.NetworkAddressConfig{
Protocol: "tls",
PortRange: types.String(port),
})
}
var certRefs = []*sslconfigs.SSLCertRef{}
for _, certId := range req.SslCertIds {
// 检查所有权
if userId > 0 {
err = models.SharedSSLCertDAO.CheckUserCert(tx, certId, userId)
if err != nil {
return nil, errors.New("check cert permission failed: " + err.Error())
}
} else {
existCert, err := models.SharedSSLCertDAO.Exist(tx, certId)
if err != nil {
return nil, err
}
if !existCert {
return nil, errors.New("cert '" + types.String(certId) + "' not found")
}
}
certRefs = append(certRefs, &sslconfigs.SSLCertRef{
IsOn: true,
CertId: certId,
})
}
certRefsJSON, err := json.Marshal(certRefs)
if err != nil {
return nil, err
}
sslPolicyId, err := models.SharedSSLPolicyDAO.CreatePolicy(tx, adminId, req.UserId, false, false, "TLS 1.0", certRefsJSON, nil, false, 0, nil, false, nil)
if err != nil {
return nil, err
}
tlsConfig.SSLPolicyRef = &sslconfigs.SSLPolicyRef{
IsOn: true,
SSLPolicyId: sslPolicyId,
}
tlsJSON, err := json.Marshal(tlsConfig)
if err != nil {
return nil, err
}
// Reverse Proxy
var reverseProxyScheduleConfig = &serverconfigs.SchedulingConfig{
Code: "random",
Options: nil,
}
reverseProxyScheduleJSON, err := json.Marshal(reverseProxyScheduleConfig)
var primaryOrigins = []*serverconfigs.OriginRef{}
for _, originAddr := range req.OriginAddrs {
u, err := url.Parse(originAddr)
if err != nil {
return nil, errors.New("parse origin address '" + originAddr + "' failed: " + err.Error())
}
if len(u.Scheme) == 0 || (u.Scheme != "tcp" && u.Scheme != "tls" && u.Scheme != "ssl" /** 特意不支持大写形式 **/) {
return nil, errors.New("invalid scheme in origin address '" + originAddr + "'")
}
if len(u.Host) == 0 {
return nil, errors.New("invalid host address '" + originAddr + "', contains no host")
}
host, port, err := net.SplitHostPort(u.Host)
if err != nil || len(host) == 0 || len(port) == 0 {
err = nil // ignore error
return nil, errors.New("invalid host address '" + originAddr + "', invalid host format")
}
if u.Scheme == "ssl" {
u.Scheme = "tls"
}
var addr = &serverconfigs.NetworkAddressConfig{
Protocol: serverconfigs.Protocol(u.Scheme),
Host: host,
PortRange: port,
}
addrJSON, err := json.Marshal(addr)
if err != nil {
return nil, err
}
originId, err := models.SharedOriginDAO.CreateOrigin(tx, adminId, req.UserId, "", addrJSON, nil, "", 10, true, nil, nil, nil, 0, 0, nil, nil, "", false)
if err != nil {
return nil, err
}
primaryOrigins = append(primaryOrigins, &serverconfigs.OriginRef{
IsOn: true,
OriginId: originId,
})
}
primaryOriginsJSON, err := json.Marshal(primaryOrigins)
if err != nil {
return nil, err
}
reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(tx, adminId, req.UserId, reverseProxyScheduleJSON, primaryOriginsJSON, nil)
if err != nil {
return nil, err
}
reverseProxyJSON, err := json.Marshal(&serverconfigs.ReverseProxyRef{
IsPrior: false,
IsOn: true,
ReverseProxyId: reverseProxyId,
})
if err != nil {
return nil, err
}
// finally, we create ...
serverId, err := models.SharedServerDAO.CreateServer(tx, adminId, req.UserId, serverconfigs.ServerTypeTCPProxy, "TCP Service", "", nil, false, nil, nil, nil, tcpJSON, tlsJSON, nil, nil, 0, reverseProxyJSON, req.NodeClusterId, nil, nil, nil, 0)
if err != nil {
return nil, err
}
return &pb.CreateBasicTCPServerResponse{ServerId: serverId}, nil
}
// UpdateServerBasic 修改服务基本信息
func (this *ServerService) UpdateServerBasic(ctx context.Context, req *pb.UpdateServerBasicRequest) (*pb.RPCSuccess, error) {
// 校验请求

View File

@@ -44,7 +44,7 @@ func (this *SSLPolicyService) CreateSSLPolicy(ctx context.Context, req *pb.Creat
// TODO
}
policyId, err := models.SharedSSLPolicyDAO.CreatePolicy(tx, adminId, userId, req.Http2Enabled, req.MinVersion, req.SslCertsJSON, req.HstsJSON, req.OcspIsOn, req.ClientAuthType, req.ClientCACertsJSON, req.CipherSuitesIsOn, req.CipherSuites)
policyId, err := models.SharedSSLPolicyDAO.CreatePolicy(tx, adminId, userId, req.Http2Enabled, req.Http3Enabled, req.MinVersion, req.SslCertsJSON, req.HstsJSON, req.OcspIsOn, req.ClientAuthType, req.ClientCACertsJSON, req.CipherSuitesIsOn, req.CipherSuites)
if err != nil {
return nil, err
}
@@ -63,13 +63,13 @@ func (this *SSLPolicyService) UpdateSSLPolicy(ctx context.Context, req *pb.Updat
var tx = this.NullTx()
if userId > 0 {
err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, userId, req.SslPolicyId)
err = models.SharedSSLPolicyDAO.CheckUserPolicy(tx, userId, req.SslPolicyId)
if err != nil {
return nil, errors.New("check ssl policy failed: " + err.Error())
}
}
err = models.SharedSSLPolicyDAO.UpdatePolicy(tx, req.SslPolicyId, req.Http2Enabled, req.MinVersion, req.SslCertsJSON, req.HstsJSON, req.OcspIsOn, req.ClientAuthType, req.ClientCACertsJSON, req.CipherSuitesIsOn, req.CipherSuites)
err = models.SharedSSLPolicyDAO.UpdatePolicy(tx, req.SslPolicyId, req.Http2Enabled, req.Http3Enabled, req.MinVersion, req.SslCertsJSON, req.HstsJSON, req.OcspIsOn, req.ClientAuthType, req.ClientCACertsJSON, req.CipherSuitesIsOn, req.CipherSuites)
if err != nil {
return nil, err
}

File diff suppressed because it is too large Load Diff

View File

@@ -260,7 +260,7 @@ func (this *HealthCheckExecutor) runNode(healthCheckConfig *serverconfigs.Health
func (this *HealthCheckExecutor) runNodeOnce(healthCheckConfig *serverconfigs.HealthCheckConfig, result *HealthCheckResult) error {
// 支持IPv6
if utils.IsIPv6(result.NodeAddr) {
result.NodeAddr = "[" + result.NodeAddr + "]"
result.NodeAddr = configutils.QuoteIP(result.NodeAddr)
}
if len(healthCheckConfig.URL) == 0 {

View File

@@ -0,0 +1,31 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package domainutils
import (
"regexp"
"strings"
)
// ValidateDomainFormat 校验域名格式
func ValidateDomainFormat(domain string) bool {
pieces := strings.Split(domain, ".")
for _, piece := range pieces {
if piece == "-" ||
strings.HasPrefix(piece, "-") ||
strings.HasSuffix(piece, "-") ||
//strings.Contains(piece, "--") ||
len(piece) > 63 ||
// 支持中文、大写字母、下划线
!regexp.MustCompile(`^[\p{Han}_a-zA-Z0-9-]+$`).MatchString(piece) {
return false
}
}
// 最后一段不能是全数字
if regexp.MustCompile(`^(\d+)$`).MatchString(pieces[len(pieces)-1]) {
return false
}
return true
}

View File

@@ -3,20 +3,27 @@ package utils
import (
"errors"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/utils/taskutils"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/fsnotify/fsnotify"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/logs"
"github.com/miekg/dns"
"sync"
)
var sharedDNSClient *dns.Client
var sharedDNSConfig *dns.ClientConfig
var sharedDNSLocker = &sync.RWMutex{}
func init() {
if !teaconst.IsMain {
return
}
config, err := dns.ClientConfigFromFile("/etc/resolv.conf")
var resolvConfFile = "/etc/resolv.conf"
config, err := dns.ClientConfigFromFile(resolvConfFile)
if err != nil {
logs.Println("ERROR: configure dns client failed: " + err.Error())
return
@@ -25,6 +32,21 @@ func init() {
sharedDNSConfig = config
sharedDNSClient = &dns.Client{}
// 监视文件变化,以便及时更新配置
go func() {
watcher, watcherErr := fsnotify.NewWatcher()
if watcherErr == nil {
err = watcher.Add(resolvConfFile)
for range watcher.Events {
newConfig, err := dns.ClientConfigFromFile(resolvConfFile)
if err == nil && newConfig != nil {
sharedDNSLocker.Lock()
sharedDNSConfig = newConfig
sharedDNSLocker.Unlock()
}
}
}
}()
}
// LookupCNAME 查询CNAME记录
@@ -40,8 +62,10 @@ func LookupCNAME(host string) (string, error) {
m.RecursionDesired = true
var lastErr error
for _, serverAddr := range sharedDNSConfig.Servers {
r, _, err := sharedDNSClient.Exchange(m, configutils.QuoteIP(serverAddr)+":"+sharedDNSConfig.Port)
var serverAddrs = composeDNSResolverAddrs(nil)
for _, serverAddr := range serverAddrs {
r, _, err := sharedDNSClient.Exchange(m, serverAddr)
if err != nil {
lastErr = err
continue
@@ -56,8 +80,7 @@ func LookupCNAME(host string) (string, error) {
}
// LookupNS 查询NS记录
// TODO 可以设置使用的DNS主机地址
func LookupNS(host string) ([]string, error) {
func LookupNS(host string, extraResolvers []*dnsconfigs.DNSResolver) ([]string, error) {
var m = new(dns.Msg)
m.SetQuestion(host+".", dns.TypeNS)
@@ -67,23 +90,36 @@ func LookupNS(host string) ([]string, error) {
var lastErr error
var hasValidServer = false
for _, serverAddr := range sharedDNSConfig.Servers {
r, _, err := sharedDNSClient.Exchange(m, configutils.QuoteIP(serverAddr)+":"+sharedDNSConfig.Port)
var serverAddrs = composeDNSResolverAddrs(extraResolvers)
if len(serverAddrs) == 0 {
return nil, nil
}
taskErr := taskutils.RunConcurrent(serverAddrs, taskutils.DefaultConcurrent, func(task any, locker *sync.RWMutex) {
var serverAddr = task.(string)
r, _, err := sharedDNSClient.Exchange(m, serverAddr)
if err != nil {
lastErr = err
continue
return
}
hasValidServer = true
if len(r.Answer) == 0 {
continue
return
}
for _, answer := range r.Answer {
result = append(result, answer.(*dns.NS).Ns)
var value = answer.(*dns.NS).Ns
locker.Lock()
if len(value) > 0 && !lists.ContainsString(result, value) {
result = append(result, value)
}
locker.Unlock()
}
break
})
if taskErr != nil {
return result, taskErr
}
if hasValidServer {
@@ -94,8 +130,7 @@ func LookupNS(host string) ([]string, error) {
}
// LookupTXT 获取CNAME
// TODO 可以设置使用的DNS主机地址
func LookupTXT(host string) ([]string, error) {
func LookupTXT(host string, extraResolvers []*dnsconfigs.DNSResolver) ([]string, error) {
var m = new(dns.Msg)
m.SetQuestion(host+".", dns.TypeTXT)
@@ -104,23 +139,36 @@ func LookupTXT(host string) ([]string, error) {
var lastErr error
var result = []string{}
var hasValidServer = false
for _, serverAddr := range sharedDNSConfig.Servers {
r, _, err := sharedDNSClient.Exchange(m, configutils.QuoteIP(serverAddr)+":"+sharedDNSConfig.Port)
var serverAddrs = composeDNSResolverAddrs(extraResolvers)
if len(serverAddrs) == 0 {
return nil, nil
}
taskErr := taskutils.RunConcurrent(serverAddrs, taskutils.DefaultConcurrent, func(task any, locker *sync.RWMutex) {
var serverAddr = task.(string)
r, _, err := sharedDNSClient.Exchange(m, serverAddr)
if err != nil {
lastErr = err
continue
return
}
hasValidServer = true
if len(r.Answer) == 0 {
continue
return
}
for _, answer := range r.Answer {
result = append(result, answer.(*dns.TXT).Txt...)
for _, txt := range answer.(*dns.TXT).Txt {
locker.Lock()
if len(txt) > 0 && !lists.ContainsString(result, txt) {
result = append(result, txt)
}
locker.Unlock()
}
}
break
})
if taskErr != nil {
return result, taskErr
}
if hasValidServer {
@@ -129,3 +177,22 @@ func LookupTXT(host string) ([]string, error) {
return nil, lastErr
}
// 组合DNS解析服务器地址
func composeDNSResolverAddrs(extraResolvers []*dnsconfigs.DNSResolver) []string {
sharedDNSLocker.RLock()
defer sharedDNSLocker.RUnlock()
// 这里不处理重复,方便我们可以多次重试
var servers = sharedDNSConfig.Servers
var port = sharedDNSConfig.Port
var serverAddrs = []string{}
for _, serverAddr := range servers {
serverAddrs = append(serverAddrs, configutils.QuoteIP(serverAddr)+":"+port)
}
for _, resolver := range extraResolvers {
serverAddrs = append(serverAddrs, resolver.Addr())
}
return serverAddrs
}

View File

@@ -4,6 +4,7 @@ package utils_test
import (
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"testing"
)
@@ -12,9 +13,25 @@ func TestLookupCNAME(t *testing.T) {
}
func TestLookupNS(t *testing.T) {
t.Log(utils.LookupNS("goedge.cn"))
t.Log(utils.LookupNS("goedge.cn", nil))
}
func TestLookupNSExtra(t *testing.T) {
t.Log(utils.LookupNS("goedge.cn", []*dnsconfigs.DNSResolver{
{
Host: "192.168.2.2",
},
{
Host: "192.168.2.2",
Port: 58,
},
{
Host: "8.8.8.8",
Port: 53,
},
}))
}
func TestLookupTXT(t *testing.T) {
t.Log(utils.LookupTXT("yanzheng.goedge.cn"))
t.Log(utils.LookupTXT("yanzheng.goedge.cn", nil))
}

View File

@@ -9,3 +9,7 @@ var (
YYYYMMDD = regexp.MustCompile(`^\d{8}$`)
YYYYMM = regexp.MustCompile(`^\d{6}$`)
)
var (
HTTPProtocol = regexp.MustCompile("^(?i)(http|https)://")
)

View File

@@ -0,0 +1,61 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package taskutils
import (
"errors"
"reflect"
"sync"
)
const DefaultConcurrent = 16
func RunConcurrent(tasks any, concurrent int, f func(task any, locker *sync.RWMutex)) error {
if tasks == nil {
return nil
}
var tasksValue = reflect.ValueOf(tasks)
if tasksValue.Type().Kind() != reflect.Slice {
return errors.New("ony works for slice")
}
var countTasks = tasksValue.Len()
if countTasks == 0 {
return nil
}
if concurrent <= 0 {
concurrent = 8
}
if concurrent > countTasks {
concurrent = countTasks
}
var taskChan = make(chan any, countTasks)
for i := 0; i < countTasks; i++ {
taskChan <- tasksValue.Index(i).Interface()
}
var wg = &sync.WaitGroup{}
wg.Add(concurrent)
var locker = &sync.RWMutex{}
for i := 0; i < concurrent; i++ {
go func() {
defer wg.Done()
for {
select {
case task := <-taskChan:
f(task, locker)
default:
return
}
}
}()
}
wg.Wait()
return nil
}

View File

@@ -0,0 +1,18 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package taskutils_test
import (
"github.com/TeaOSLab/EdgeAPI/internal/utils/taskutils"
"sync"
"testing"
)
func TestRunConcurrent(t *testing.T) {
err := taskutils.RunConcurrent([]string{"a", "b", "c", "d", "e"}, 3, func(task any, locker *sync.RWMutex) {
t.Log("run", task)
})
if err != nil {
t.Fatal(err)
}
}