Compare commits

..

40 Commits

Author SHA1 Message Date
刘祥超
6dfecd69b4 提供区域监控上报结果接口 2021-09-06 08:12:48 +08:00
刘祥超
dbc97bc8de 实现基本的区域监控终端管理功能 2021-09-05 11:10:18 +08:00
刘祥超
8308e2e83d 修复边缘节点无法下载文件的Bug 2021-09-03 15:15:27 +08:00
刘祥超
8227138168 修复DNS节点升级文件无法下载的Bug 2021-08-31 22:36:36 +08:00
刘祥超
2227a14ba4 增加独立的IP地址管理功能 2021-08-31 17:24:52 +08:00
刘祥超
6137b44408 企业认证信息中增加节点数限制 2021-08-30 18:57:11 +08:00
刘祥超
f654c65626 查询指标数据时增加索引 2021-08-30 15:23:51 +08:00
刘祥超
674574a6af Update go.mod 2021-08-30 10:56:44 +08:00
刘祥超
f02ab1aae2 优化数据库节点管理 2021-08-30 10:56:31 +08:00
刘祥超
55527bba09 健康检查失败10分钟内不重复提醒 2021-08-30 09:47:17 +08:00
刘祥超
821b607ef2 当修改节点在线状态时重置上线检查次数 2021-08-29 16:57:50 +08:00
刘祥超
afeee89a88 节点返回数据增加isUp字段 2021-08-29 16:57:20 +08:00
刘祥超
a983615464 修复节点转移集群后没有删除老的DNS记录的问题 2021-08-29 16:41:42 +08:00
刘祥超
bd596816d5 将健康检查连续下线次数从1升级为3/修复健康检查可能导致DNS不断同步的问题 2021-08-29 16:01:31 +08:00
刘祥超
ca233c3573 修复一个因为SQL_CACHE而导致子查询产生的错误 2021-08-29 14:07:12 +08:00
刘祥超
fd1f990a0e 访问日志表增加requestBody, responseBody(预留) 2021-08-29 10:37:42 +08:00
刘祥超
1796bb8f96 更新TeaGo,提升SQL解析效率、自动开启SQL查询缓存 2021-08-29 10:21:42 +08:00
刘祥超
3f5e4babc7 改进编译脚本 2021-08-26 16:29:19 +08:00
刘祥超
f508c16f92 删除plus文件 2021-08-26 14:40:07 +08:00
刘祥超
ac0bbd0b99 DNS服务商支持搜索 2021-08-25 18:47:01 +08:00
刘祥超
9017176efb 集群支持使用域名搜索 2021-08-25 18:39:17 +08:00
刘祥超
8b40634e74 节点如果没有设置DNS线路就使用默认线路 2021-08-25 17:16:24 +08:00
刘祥超
cf476f79d6 Admin看板增加默认集群ID 2021-08-25 11:41:23 +08:00
刘祥超
fc38a6ab7e 创建集群时自动创建缓存策略和WAF策略 2021-08-25 11:18:37 +08:00
刘祥超
e4f0dafc1a 增加忽略相似消息周期设置 2021-08-24 20:45:12 +08:00
刘祥超
b7fda0b9cc 消息接收人可以设置接收消息时间段 2021-08-24 17:46:11 +08:00
刘祥超
f4cc5aa087 通知媒介可以设置发送频率 2021-08-24 15:46:53 +08:00
刘祥超
f58724065d 通知媒介增加任务队列查看功能 2021-08-24 14:22:44 +08:00
刘祥超
1e6b42c00c 优化WAF日志查询速度 2021-08-22 16:20:40 +08:00
刘祥超
8d759a104b 优化WAF日志访问速度 2021-08-22 16:00:32 +08:00
刘祥超
72d7ceb94e 提升节点组合配置效率 2021-08-22 11:35:33 +08:00
刘祥超
53f7a0b77e Dashboard可以提示API节点升级 2021-08-21 19:43:46 +08:00
刘祥超
56000a8b8a 优化服务配置更新机制 2021-08-21 17:24:29 +08:00
刘祥超
ab7b2fee3a 自建DNS支持递归查询 2021-08-21 16:46:41 +08:00
刘祥超
d768d46854 DNS访问日志显示匹配的线路 2021-08-20 11:27:16 +08:00
刘祥超
b86c9aad6f 节点排行增加条数限制 2021-08-20 10:10:55 +08:00
刘祥超
df667c6ee6 添加DNS账号时自动读取DNS服务商下域名 2021-08-19 14:26:34 +08:00
刘祥超
70331805d7 节点IP地址可以设置阈值 2021-08-18 16:19:16 +08:00
刘祥超
71dbf86572 节点IP增加是否启用、是否在线状态 2021-08-18 09:24:18 +08:00
刘祥超
0df358d70d 调整版本为0.3.0 2021-08-17 09:44:22 +08:00
132 changed files with 3324 additions and 746 deletions

View File

@@ -7,6 +7,7 @@ function build() {
OS=${1} OS=${1}
ARCH=${2} ARCH=${2}
TAG=${3} TAG=${3}
NODE_ARCHITECTS=("amd64" "386" "arm64" "mips64" "mips64le")
if [ -z $OS ]; then if [ -z $OS ]; then
echo "usage: build.sh OS ARCH" echo "usage: build.sh OS ARCH"
@@ -33,8 +34,7 @@ function build() {
fi fi
cd $ROOT"/../../EdgeNode/build" cd $ROOT"/../../EdgeNode/build"
echo "==============================" echo "=============================="
architects=("amd64" "386" "arm64" "mips64" "mips64le") for arch in "${NODE_ARCHITECTS[@]}"; do
for arch in "${architects[@]}"; do
if [ ! -f $ROOT"/../../EdgeNode/dist/edge-node-linux-${arch}-${TAG}-v${NodeVersion}.zip" ]; then if [ ! -f $ROOT"/../../EdgeNode/dist/edge-node-linux-${arch}-${TAG}-v${NodeVersion}.zip" ]; then
./build.sh linux $arch $TAG ./build.sh linux $arch $TAG
else else
@@ -45,7 +45,7 @@ function build() {
cd - cd -
rm -f $ROOT/deploy/*.zip rm -f $ROOT/deploy/*.zip
for arch in "${architects[@]}"; do for arch in "${NODE_ARCHITECTS[@]}"; do
cp $ROOT"/../../EdgeNode/dist/edge-node-linux-${arch}-${TAG}-v${NodeVersion}.zip" $ROOT/deploy/edge-node-linux-${arch}-v${NodeVersion}.zip cp $ROOT"/../../EdgeNode/dist/edge-node-linux-${arch}-${TAG}-v${NodeVersion}.zip" $ROOT/deploy/edge-node-linux-${arch}-v${NodeVersion}.zip
done done

3
go.mod
View File

@@ -4,6 +4,7 @@ go 1.15
replace github.com/TeaOSLab/EdgeCommon => ../EdgeCommon replace github.com/TeaOSLab/EdgeCommon => ../EdgeCommon
require ( require (
github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d // indirect github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d // indirect
github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000 github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000
@@ -14,7 +15,7 @@ require (
github.com/go-sql-driver/mysql v1.5.0 github.com/go-sql-driver/mysql v1.5.0
github.com/go-yaml/yaml v2.1.0+incompatible github.com/go-yaml/yaml v2.1.0+incompatible
github.com/golang/protobuf v1.5.2 github.com/golang/protobuf v1.5.2
github.com/iwind/TeaGo v0.0.0-20210809112119-a57ed0e84e34 github.com/iwind/TeaGo v0.0.0-20210831140440-a2a442471b13
github.com/iwind/gosock v0.0.0-20210722083328-12b2d66abec3 github.com/iwind/gosock v0.0.0-20210722083328-12b2d66abec3
github.com/json-iterator/go v1.1.11 // indirect github.com/json-iterator/go v1.1.11 // indirect
github.com/lionsoul2014/ip2region v2.2.0-release+incompatible github.com/lionsoul2014/ip2region v2.2.0-release+incompatible

6
go.sum
View File

@@ -186,6 +186,12 @@ github.com/iwind/TeaGo v0.0.0-20210806054428-5534da0db9d1 h1:AZKkwTNEZYrpyv62zIk
github.com/iwind/TeaGo v0.0.0-20210806054428-5534da0db9d1/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc= github.com/iwind/TeaGo v0.0.0-20210806054428-5534da0db9d1/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc=
github.com/iwind/TeaGo v0.0.0-20210809112119-a57ed0e84e34 h1:ZCNQXLiGF5Z1cV3Pi03zCWzwwjPfsI5XhcrNhTvCFIU= github.com/iwind/TeaGo v0.0.0-20210809112119-a57ed0e84e34 h1:ZCNQXLiGF5Z1cV3Pi03zCWzwwjPfsI5XhcrNhTvCFIU=
github.com/iwind/TeaGo v0.0.0-20210809112119-a57ed0e84e34/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc= github.com/iwind/TeaGo v0.0.0-20210809112119-a57ed0e84e34/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc=
github.com/iwind/TeaGo v0.0.0-20210824034952-1a56ad7d0b5e h1:GDCU57lQD6W9u5KT2834MmK022FSeAbskb7H0p2eaJY=
github.com/iwind/TeaGo v0.0.0-20210824034952-1a56ad7d0b5e/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc=
github.com/iwind/TeaGo v0.0.0-20210829020150-9c36d31301a5 h1:ybjIXGT3E/ZbfkRhIb903WMfLyt2Uv5p4niAqi8jwvM=
github.com/iwind/TeaGo v0.0.0-20210829020150-9c36d31301a5/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc=
github.com/iwind/TeaGo v0.0.0-20210831140440-a2a442471b13 h1:HuEJ5xJfujW1Q6rNDhOu5LQXEBB2qLPah3jYslT8Gz4=
github.com/iwind/TeaGo v0.0.0-20210831140440-a2a442471b13/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc=
github.com/iwind/gosock v0.0.0-20210722083328-12b2d66abec3 h1:aBSonas7vFcgTj9u96/bWGILGv1ZbUSTLiOzcI1ZT6c= github.com/iwind/gosock v0.0.0-20210722083328-12b2d66abec3 h1:aBSonas7vFcgTj9u96/bWGILGv1ZbUSTLiOzcI1ZT6c=
github.com/iwind/gosock v0.0.0-20210722083328-12b2d66abec3/go.mod h1:H5Q7SXwbx3a97ecJkaS2sD77gspzE7HFUafBO0peEyA= github.com/iwind/gosock v0.0.0-20210722083328-12b2d66abec3/go.mod h1:H5Q7SXwbx3a97ecJkaS2sD77gspzE7HFUafBO0peEyA=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=

View File

@@ -1,7 +1,7 @@
package teaconst package teaconst
const ( const (
Version = "0.2.9" Version = "0.3.0"
ProductName = "Edge API" ProductName = "Edge API"
ProcessName = "edge-api" ProcessName = "edge-api"
@@ -18,9 +18,9 @@ const (
// 其他节点版本号,用来检测是否有需要升级的节点 // 其他节点版本号,用来检测是否有需要升级的节点
NodeVersion = "0.2.8" NodeVersion = "0.3.0"
UserNodeVersion = "0.0.10" UserNodeVersion = "0.0.10"
AuthorityNodeVersion = "0.0.2" AuthorityNodeVersion = "0.0.2"
MonitorNodeVersion = "0.0.2" MonitorNodeVersion = "0.0.3"
DNSNodeVersion = "0.1.0" DNSNodeVersion = "0.2.0"
) )

View File

@@ -3,5 +3,6 @@
package teaconst package teaconst
var ( var (
IsPlus = false IsPlus = false
MaxNodes int32 = 0
) )

View File

@@ -4,9 +4,10 @@ import (
"encoding/json" "encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
) )
// 解析HTTP配置 // DecodeHTTP 解析HTTP配置
func (this *APINode) DecodeHTTP() (*serverconfigs.HTTPProtocolConfig, error) { func (this *APINode) DecodeHTTP() (*serverconfigs.HTTPProtocolConfig, error) {
if !IsNotNull(this.Http) { if !IsNotNull(this.Http) {
return nil, nil return nil, nil
@@ -25,8 +26,12 @@ func (this *APINode) DecodeHTTP() (*serverconfigs.HTTPProtocolConfig, error) {
return config, nil return config, nil
} }
// 解析HTTPS配置 // DecodeHTTPS 解析HTTPS配置
func (this *APINode) DecodeHTTPS(tx *dbs.Tx) (*serverconfigs.HTTPSProtocolConfig, error) { func (this *APINode) DecodeHTTPS(tx *dbs.Tx, cacheMap maps.Map) (*serverconfigs.HTTPSProtocolConfig, error) {
if cacheMap == nil {
cacheMap = maps.Map{}
}
if !IsNotNull(this.Https) { if !IsNotNull(this.Https) {
return nil, nil return nil, nil
} }
@@ -44,7 +49,7 @@ func (this *APINode) DecodeHTTPS(tx *dbs.Tx) (*serverconfigs.HTTPSProtocolConfig
if config.SSLPolicyRef != nil { if config.SSLPolicyRef != nil {
policyId := config.SSLPolicyRef.SSLPolicyId policyId := config.SSLPolicyRef.SSLPolicyId
if policyId > 0 { if policyId > 0 {
sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(tx, policyId) sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(tx, policyId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -62,7 +67,7 @@ func (this *APINode) DecodeHTTPS(tx *dbs.Tx) (*serverconfigs.HTTPSProtocolConfig
return config, nil return config, nil
} }
// 解析访问地址 // DecodeAccessAddrs 解析访问地址
func (this *APINode) DecodeAccessAddrs() ([]*serverconfigs.NetworkAddressConfig, error) { func (this *APINode) DecodeAccessAddrs() ([]*serverconfigs.NetworkAddressConfig, error) {
if !IsNotNull(this.AccessAddrs) { if !IsNotNull(this.AccessAddrs) {
return nil, nil return nil, nil
@@ -82,7 +87,7 @@ func (this *APINode) DecodeAccessAddrs() ([]*serverconfigs.NetworkAddressConfig,
return addrConfigs, nil return addrConfigs, nil
} }
// 解析访问地址,并返回字符串形式 // DecodeAccessAddrStrings 解析访问地址,并返回字符串形式
func (this *APINode) DecodeAccessAddrStrings() ([]string, error) { func (this *APINode) DecodeAccessAddrStrings() ([]string, error) {
addrs, err := this.DecodeAccessAddrs() addrs, err := this.DecodeAccessAddrs()
if err != nil { if err != nil {
@@ -95,7 +100,7 @@ func (this *APINode) DecodeAccessAddrStrings() ([]string, error) {
return result, nil return result, nil
} }
// 解析Rest HTTP配置 // DecodeRestHTTP 解析Rest HTTP配置
func (this *APINode) DecodeRestHTTP() (*serverconfigs.HTTPProtocolConfig, error) { func (this *APINode) DecodeRestHTTP() (*serverconfigs.HTTPProtocolConfig, error) {
if this.RestIsOn != 1 { if this.RestIsOn != 1 {
return nil, nil return nil, nil
@@ -117,8 +122,11 @@ func (this *APINode) DecodeRestHTTP() (*serverconfigs.HTTPProtocolConfig, error)
return config, nil return config, nil
} }
// 解析HTTPS配置 // DecodeRestHTTPS 解析HTTPS配置
func (this *APINode) DecodeRestHTTPS(tx *dbs.Tx) (*serverconfigs.HTTPSProtocolConfig, error) { func (this *APINode) DecodeRestHTTPS(tx *dbs.Tx, cacheMap maps.Map) (*serverconfigs.HTTPSProtocolConfig, error) {
if cacheMap == nil {
cacheMap = maps.Map{}
}
if this.RestIsOn != 1 { if this.RestIsOn != 1 {
return nil, nil return nil, nil
} }
@@ -139,7 +147,7 @@ func (this *APINode) DecodeRestHTTPS(tx *dbs.Tx) (*serverconfigs.HTTPSProtocolCo
if config.SSLPolicyRef != nil { if config.SSLPolicyRef != nil {
policyId := config.SSLPolicyRef.SSLPolicyId policyId := config.SSLPolicyRef.SSLPolicyId
if policyId > 0 { if policyId > 0 {
sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(tx, policyId) sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(tx, policyId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -176,3 +176,15 @@ func (this *DBNodeDAO) DecodePassword(password string) string {
} }
return string(encrypt.MagicKeyDecode(data)) return string(encrypt.MagicKeyDecode(data))
} }
// CheckNodeIsOn 检查节点是否已经启用
func (this *DBNodeDAO) CheckNodeIsOn(tx *dbs.Tx, nodeId int64) (bool, error) {
isOn, err := this.Query(tx).
Pk(nodeId).
Result("isOn").
FindIntCol(0)
if err != nil {
return false, err
}
return isOn == 1, nil
}

View File

@@ -3,10 +3,10 @@ package models
import ( import (
"fmt" "fmt"
"github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/remotelogs"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/logs"
timeutil "github.com/iwind/TeaGo/utils/time" timeutil "github.com/iwind/TeaGo/utils/time"
"hash/crc32" "hash/crc32"
"regexp" "regexp"
@@ -193,7 +193,7 @@ func findHTTPAccessLogTable(db *dbs.DB, day string, force bool) (*httpAccessLogD
} }
// 创建表格 // 创建表格
_, err = db.Exec("CREATE TABLE `" + tableName + "` (`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT 'ID',`serverId` int(11) unsigned DEFAULT '0' COMMENT '服务ID',`nodeId` int(11) unsigned DEFAULT '0' COMMENT '节点ID',`status` int(3) unsigned DEFAULT '0' COMMENT '状态码',`createdAt` bigint(11) unsigned DEFAULT '0' COMMENT '创建时间', `content` json DEFAULT NULL COMMENT '日志内容', `requestId` varchar(128) DEFAULT NULL COMMENT '请求ID', `firewallPolicyId` int(11) unsigned DEFAULT '0' COMMENT 'WAF策略ID', `firewallRuleGroupId` int(11) unsigned DEFAULT '0' COMMENT 'WAF分组ID', `firewallRuleSetId` int(11) unsigned DEFAULT '0' COMMENT 'WAF集ID', `firewallRuleId` int(11) unsigned DEFAULT '0' COMMENT 'WAF规则ID', `remoteAddr` varchar(64) DEFAULT NULL COMMENT 'IP地址', `domain` varchar(128) DEFAULT NULL COMMENT '域名', PRIMARY KEY (`id`), KEY `serverId` (`serverId`), KEY `nodeId` (`nodeId`), KEY `serverId_status` (`serverId`,`status`), KEY `requestId` (`requestId`), KEY `firewallPolicyId` (`firewallPolicyId`), KEY `firewallRuleGroupId` (`firewallRuleGroupId`), KEY `firewallRuleSetId` (`firewallRuleSetId`), KEY `firewallRuleId` (`firewallRuleId`), KEY `remoteAddr` (`remoteAddr`), KEY `domain` (`domain`)) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='访问日志';") _, err = db.Exec("CREATE TABLE `" + tableName + "` (\n `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT 'ID',\n `serverId` int(11) unsigned DEFAULT '0' COMMENT '服务ID',\n `nodeId` int(11) unsigned DEFAULT '0' COMMENT '节点ID',\n `status` int(3) unsigned DEFAULT '0' COMMENT '状态码',\n `createdAt` bigint(11) unsigned DEFAULT '0' COMMENT '创建时间',\n `content` json DEFAULT NULL COMMENT '日志内容',\n `requestId` varchar(128) DEFAULT NULL COMMENT '请求ID',\n `firewallPolicyId` int(11) unsigned DEFAULT '0' COMMENT 'WAF策略ID',\n `firewallRuleGroupId` int(11) unsigned DEFAULT '0' COMMENT 'WAF分组ID',\n `firewallRuleSetId` int(11) unsigned DEFAULT '0' COMMENT 'WAF集ID',\n `firewallRuleId` int(11) unsigned DEFAULT '0' COMMENT 'WAF规则ID',\n `remoteAddr` varchar(64) DEFAULT NULL COMMENT 'IP地址',\n `domain` varchar(128) DEFAULT NULL COMMENT '域名',\n `requestBody` blob COMMENT '请求内容',\n `responseBody` blob COMMENT '响应内容',\n PRIMARY KEY (`id`),\n KEY `serverId` (`serverId`),\n KEY `nodeId` (`nodeId`),\n KEY `serverId_status` (`serverId`,`status`),\n KEY `requestId` (`requestId`),\n KEY `firewallPolicyId` (`firewallPolicyId`),\n KEY `firewallRuleGroupId` (`firewallRuleGroupId`),\n KEY `firewallRuleSetId` (`firewallRuleSetId`),\n KEY `firewallRuleId` (`firewallRuleId`),\n KEY `remoteAddr` (`remoteAddr`),\n KEY `domain` (`domain`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='访问日志';")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -266,7 +266,7 @@ func (this *DBNodeInitializer) Start() {
// 初始运行 // 初始运行
err := this.loop() err := this.loop()
if err != nil { if err != nil {
logs.Println("[DB_NODE]" + err.Error()) remotelogs.Error("DB_NODE", err.Error())
} }
// 定时运行 // 定时运行
@@ -274,7 +274,7 @@ func (this *DBNodeInitializer) Start() {
for range ticker.C { for range ticker.C {
err := this.loop() err := this.loop()
if err != nil { if err != nil {
logs.Println("[DB_NODE]" + err.Error()) remotelogs.Error("DB_NODE", err.Error())
} }
} }
} }
@@ -300,7 +300,7 @@ func (this *DBNodeInitializer) loop() error {
delete(accessLogDBMapping, nodeId) delete(accessLogDBMapping, nodeId)
delete(httpAccessLogDAOMapping, nodeId) delete(httpAccessLogDAOMapping, nodeId)
delete(nsAccessLogDAOMapping, nodeId) delete(nsAccessLogDAOMapping, nodeId)
logs.Println("[DB_NODE]close db node '" + strconv.FormatInt(nodeId, 10) + "'") remotelogs.Error("DB_NODE", "close db node '"+strconv.FormatInt(nodeId, 10)+"'")
} }
} }
accessLogLocker.Unlock() accessLogLocker.Unlock()
@@ -321,7 +321,7 @@ func (this *DBNodeInitializer) loop() error {
// 检查配置是否有变化 // 检查配置是否有变化
oldConfig, err := db.Config() oldConfig, err := db.Config()
if err != nil { if err != nil {
logs.Println("[DB_NODE]read database old config failed: " + err.Error()) remotelogs.Error("DB_NODE", "read database old config failed: "+err.Error())
continue continue
} }
@@ -340,7 +340,7 @@ func (this *DBNodeInitializer) loop() error {
} }
db, err := dbs.NewInstanceFromConfig(config) db, err := dbs.NewInstanceFromConfig(config)
if err != nil { if err != nil {
logs.Println("[DB_NODE]initialize database config failed: " + err.Error()) remotelogs.Error("DB_NODE", "initialize database config failed: "+err.Error())
continue continue
} }
@@ -350,12 +350,12 @@ func (this *DBNodeInitializer) loop() error {
tableDef, err := findHTTPAccessLogTable(db, timeutil.Format("Ymd"), true) tableDef, err := findHTTPAccessLogTable(db, timeutil.Format("Ymd"), true)
if err != nil { if err != nil {
if !strings.Contains(err.Error(), "1050") { // 非表格已存在错误 if !strings.Contains(err.Error(), "1050") { // 非表格已存在错误
logs.Println("[DB_NODE]create first table in database node failed: " + err.Error()) remotelogs.Error("DB_NODE", "create first table in database node failed: "+err.Error())
// 创建节点日志 // 创建节点日志
createLogErr := SharedNodeLogDAO.CreateLog(nil, nodeconfigs.NodeRoleDatabase, nodeId, 0, 0, "error", "ACCESS_LOG", "can not create access log table: "+err.Error(), time.Now().Unix()) createLogErr := SharedNodeLogDAO.CreateLog(nil, nodeconfigs.NodeRoleDatabase, nodeId, 0, 0, "error", "ACCESS_LOG", "can not create access log table: "+err.Error(), time.Now().Unix())
if createLogErr != nil { if createLogErr != nil {
logs.Println("[NODE_LOG]" + createLogErr.Error()) remotelogs.Error("NODE_LOG", createLogErr.Error())
} }
continue continue
@@ -373,7 +373,7 @@ func (this *DBNodeInitializer) loop() error {
} }
err = daoObject.Init() err = daoObject.Init()
if err != nil { if err != nil {
logs.Println("[DB_NODE]initialize dao failed: " + err.Error()) remotelogs.Error("DB_NODE", "initialize dao failed: "+err.Error())
continue continue
} }
@@ -394,12 +394,12 @@ func (this *DBNodeInitializer) loop() error {
tableName, err := findNSAccessLogTable(db, timeutil.Format("Ymd"), false) tableName, err := findNSAccessLogTable(db, timeutil.Format("Ymd"), false)
if err != nil { if err != nil {
if !strings.Contains(err.Error(), "1050") { // 非表格已存在错误 if !strings.Contains(err.Error(), "1050") { // 非表格已存在错误
logs.Println("[DB_NODE]create first table in database node failed: " + err.Error()) remotelogs.Error("DB_NODE", "create first table in database node failed: "+err.Error())
// 创建节点日志 // 创建节点日志
createLogErr := SharedNodeLogDAO.CreateLog(nil, nodeconfigs.NodeRoleDatabase, nodeId, 0, 0, "error", "ACCESS_LOG", "can not create access log table: "+err.Error(), time.Now().Unix()) createLogErr := SharedNodeLogDAO.CreateLog(nil, nodeconfigs.NodeRoleDatabase, nodeId, 0, 0, "error", "ACCESS_LOG", "can not create access log table: "+err.Error(), time.Now().Unix())
if createLogErr != nil { if createLogErr != nil {
logs.Println("[NODE_LOG]" + createLogErr.Error()) remotelogs.Error("NODE_LOG", createLogErr.Error())
} }
continue continue
@@ -417,7 +417,7 @@ func (this *DBNodeInitializer) loop() error {
} }
err = daoObject.Init() err = daoObject.Init()
if err != nil { if err != nil {
logs.Println("[DB_NODE]initialize dao failed: " + err.Error()) remotelogs.Error("DB_NODE", "initialize dao failed: "+err.Error())
continue continue
} }

View File

@@ -58,14 +58,24 @@ func (this *DNSDomainDAO) DisableDNSDomain(tx *dbs.Tx, id int64) error {
} }
// FindEnabledDNSDomain 查找启用中的条目 // FindEnabledDNSDomain 查找启用中的条目
func (this *DNSDomainDAO) FindEnabledDNSDomain(tx *dbs.Tx, id int64) (*DNSDomain, error) { func (this *DNSDomainDAO) FindEnabledDNSDomain(tx *dbs.Tx, domainId int64, cacheMap maps.Map) (*DNSDomain, error) {
if cacheMap == nil {
cacheMap = maps.Map{}
}
var cacheKey = this.Table + ":record:" + types.String(domainId)
var cache = cacheMap.Get(cacheKey)
if cache != nil {
return cache.(*DNSDomain), nil
}
result, err := this.Query(tx). result, err := this.Query(tx).
Pk(id). Pk(domainId).
Attr("state", DNSDomainStateEnabled). Attr("state", DNSDomainStateEnabled).
Find() Find()
if result == nil { if result == nil {
return nil, err return nil, err
} }
cacheMap[cacheKey] = result
return result.(*DNSDomain), err return result.(*DNSDomain), err
} }
@@ -86,6 +96,7 @@ func (this *DNSDomainDAO) CreateDomain(tx *dbs.Tx, adminId int64, userId int64,
op.Name = name op.Name = name
op.State = DNSDomainStateEnabled op.State = DNSDomainStateEnabled
op.IsOn = true op.IsOn = true
op.IsUp = true
err := this.Save(tx, op) err := this.Save(tx, op)
if err != nil { if err != nil {
return 0, err return 0, err
@@ -203,7 +214,7 @@ func (this *DNSDomainDAO) FindDomainRouteName(tx *dbs.Tx, domainId int64, routeC
// ExistAvailableDomains 判断是否有域名可选 // ExistAvailableDomains 判断是否有域名可选
func (this *DNSDomainDAO) ExistAvailableDomains(tx *dbs.Tx) (bool, error) { func (this *DNSDomainDAO) ExistAvailableDomains(tx *dbs.Tx) (bool, error) {
subQuery, err := SharedDNSProviderDAO.Query(tx). subQuery, err := SharedDNSProviderDAO.Query(tx).
Where("state=1"). // 这里要使用非变量 Where("state=1"). // 这里要使用非变量
ResultPk(). ResultPk().
AsSQL() AsSQL()
if err != nil { if err != nil {
@@ -247,3 +258,25 @@ func (this *DNSDomainDAO) ExistDomainRecord(tx *dbs.Tx, domainId int64, recordNa
Param("query", query.AsJSON()). Param("query", query.AsJSON()).
Exist() Exist()
} }
// FindEnabledDomainWithName 根据名称查找某个域名
func (this *DNSDomainDAO) FindEnabledDomainWithName(tx *dbs.Tx, providerId int64, domainName string) (*DNSDomain, error) {
one, err := this.Query(tx).
State(DNSDomainStateEnabled).
Attr("isOn", true).
Attr("providerId", providerId).
Attr("name", domainName).
Find()
if one != nil {
return one.(*DNSDomain), nil
}
return nil, err
}
// UpdateDomainIsUp 设置是否在线
func (this *DNSDomainDAO) UpdateDomainIsUp(tx *dbs.Tx, domainId int64, isUp bool) error {
return this.Query(tx).
Pk(domainId).
Set("isUp", isUp).
UpdateQuickly()
}

View File

@@ -1,6 +1,6 @@
package dns package dns
// 管理的域名 // DNSDomain 管理的域名
type DNSDomain struct { type DNSDomain struct {
Id uint32 `field:"id"` // ID Id uint32 `field:"id"` // ID
AdminId uint32 `field:"adminId"` // 管理员ID AdminId uint32 `field:"adminId"` // 管理员ID
@@ -14,6 +14,7 @@ type DNSDomain struct {
Data string `field:"data"` // 原始数据信息 Data string `field:"data"` // 原始数据信息
Records string `field:"records"` // 所有解析记录 Records string `field:"records"` // 所有解析记录
Routes string `field:"routes"` // 线路数据 Routes string `field:"routes"` // 线路数据
IsUp uint8 `field:"isUp"` // 是否在线
State uint8 `field:"state"` // 状态 State uint8 `field:"state"` // 状态
} }
@@ -30,6 +31,7 @@ type DNSDomainOperator struct {
Data interface{} // 原始数据信息 Data interface{} // 原始数据信息
Records interface{} // 所有解析记录 Records interface{} // 所有解析记录
Routes interface{} // 线路数据 Routes interface{} // 线路数据
IsUp interface{} // 是否在线
State interface{} // 状态 State interface{} // 状态
} }

View File

@@ -36,7 +36,7 @@ func init() {
}) })
} }
// 启用条目 // EnableDNSProvider 启用条目
func (this *DNSProviderDAO) EnableDNSProvider(tx *dbs.Tx, id int64) error { func (this *DNSProviderDAO) EnableDNSProvider(tx *dbs.Tx, id int64) error {
_, err := this.Query(tx). _, err := this.Query(tx).
Pk(id). Pk(id).
@@ -45,7 +45,7 @@ func (this *DNSProviderDAO) EnableDNSProvider(tx *dbs.Tx, id int64) error {
return err return err
} }
// 禁用条目 // DisableDNSProvider 禁用条目
func (this *DNSProviderDAO) DisableDNSProvider(tx *dbs.Tx, id int64) error { func (this *DNSProviderDAO) DisableDNSProvider(tx *dbs.Tx, id int64) error {
_, err := this.Query(tx). _, err := this.Query(tx).
Pk(id). Pk(id).
@@ -54,7 +54,7 @@ func (this *DNSProviderDAO) DisableDNSProvider(tx *dbs.Tx, id int64) error {
return err return err
} }
// 查找启用中的条目 // FindEnabledDNSProvider 查找启用中的条目
func (this *DNSProviderDAO) FindEnabledDNSProvider(tx *dbs.Tx, id int64) (*DNSProvider, error) { func (this *DNSProviderDAO) FindEnabledDNSProvider(tx *dbs.Tx, id int64) (*DNSProvider, error) {
result, err := this.Query(tx). result, err := this.Query(tx).
Pk(id). Pk(id).
@@ -66,7 +66,7 @@ func (this *DNSProviderDAO) FindEnabledDNSProvider(tx *dbs.Tx, id int64) (*DNSPr
return result.(*DNSProvider), err return result.(*DNSProvider), err
} }
// 创建服务商 // CreateDNSProvider 创建服务商
func (this *DNSProviderDAO) CreateDNSProvider(tx *dbs.Tx, adminId int64, userId int64, providerType string, name string, apiParamsJSON []byte) (int64, error) { func (this *DNSProviderDAO) CreateDNSProvider(tx *dbs.Tx, adminId int64, userId int64, providerType string, name string, apiParamsJSON []byte) (int64, error) {
op := NewDNSProviderOperator() op := NewDNSProviderOperator()
op.AdminId = adminId op.AdminId = adminId
@@ -84,7 +84,7 @@ func (this *DNSProviderDAO) CreateDNSProvider(tx *dbs.Tx, adminId int64, userId
return types.Int64(op.Id), nil return types.Int64(op.Id), nil
} }
// 修改服务商 // UpdateDNSProvider 修改服务商
func (this *DNSProviderDAO) UpdateDNSProvider(tx *dbs.Tx, dnsProviderId int64, name string, apiParamsJSON []byte) error { func (this *DNSProviderDAO) UpdateDNSProvider(tx *dbs.Tx, dnsProviderId int64, name string, apiParamsJSON []byte) error {
if dnsProviderId <= 0 { if dnsProviderId <= 0 {
return errors.New("invalid dnsProviderId") return errors.New("invalid dnsProviderId")
@@ -106,16 +106,25 @@ func (this *DNSProviderDAO) UpdateDNSProvider(tx *dbs.Tx, dnsProviderId int64, n
return nil return nil
} }
// 计算服务商数量 // CountAllEnabledDNSProviders 计算服务商数量
func (this *DNSProviderDAO) CountAllEnabledDNSProviders(tx *dbs.Tx, adminId int64, userId int64) (int64, error) { func (this *DNSProviderDAO) CountAllEnabledDNSProviders(tx *dbs.Tx, adminId int64, userId int64, keyword string) (int64, error) {
return dbutils.NewQuery(tx, this, adminId, userId). var query = dbutils.NewQuery(tx, this, adminId, userId)
State(DNSProviderStateEnabled). if len(keyword) > 0 {
query.Where("(name LIKE :keyword)").
Param("keyword", "%"+keyword+"%")
}
return query.State(DNSProviderStateEnabled).
Count() Count()
} }
// 列出单页服务商 // ListEnabledDNSProviders 列出单页服务商
func (this *DNSProviderDAO) ListEnabledDNSProviders(tx *dbs.Tx, adminId int64, userId int64, offset int64, size int64) (result []*DNSProvider, err error) { func (this *DNSProviderDAO) ListEnabledDNSProviders(tx *dbs.Tx, adminId int64, userId int64, keyword string, offset int64, size int64) (result []*DNSProvider, err error) {
_, err = dbutils.NewQuery(tx, this, adminId, userId). var query = dbutils.NewQuery(tx, this, adminId, userId)
if len(keyword) > 0 {
query.Where("(name LIKE :keyword)").
Param("keyword", "%"+keyword+"%")
}
_, err = query.
State(DNSProviderStateEnabled). State(DNSProviderStateEnabled).
Offset(offset). Offset(offset).
Limit(size). Limit(size).
@@ -125,7 +134,7 @@ func (this *DNSProviderDAO) ListEnabledDNSProviders(tx *dbs.Tx, adminId int64, u
return return
} }
// 列出所有服务商 // FindAllEnabledDNSProviders 列出所有服务商
func (this *DNSProviderDAO) FindAllEnabledDNSProviders(tx *dbs.Tx, adminId int64, userId int64) (result []*DNSProvider, err error) { func (this *DNSProviderDAO) FindAllEnabledDNSProviders(tx *dbs.Tx, adminId int64, userId int64) (result []*DNSProvider, err error) {
_, err = dbutils.NewQuery(tx, this, adminId, userId). _, err = dbutils.NewQuery(tx, this, adminId, userId).
State(DNSProviderStateEnabled). State(DNSProviderStateEnabled).
@@ -135,7 +144,7 @@ func (this *DNSProviderDAO) FindAllEnabledDNSProviders(tx *dbs.Tx, adminId int64
return return
} }
// 查询某个类型下的所有服务商 // FindAllEnabledDNSProvidersWithType 查询某个类型下的所有服务商
func (this *DNSProviderDAO) FindAllEnabledDNSProvidersWithType(tx *dbs.Tx, providerType string) (result []*DNSProvider, err error) { func (this *DNSProviderDAO) FindAllEnabledDNSProvidersWithType(tx *dbs.Tx, providerType string) (result []*DNSProvider, err error) {
_, err = this.Query(tx). _, err = this.Query(tx).
State(DNSProviderStateEnabled). State(DNSProviderStateEnabled).
@@ -146,7 +155,7 @@ func (this *DNSProviderDAO) FindAllEnabledDNSProvidersWithType(tx *dbs.Tx, provi
return return
} }
// 更新数据更新时间 // UpdateProviderDataUpdatedTime 更新数据更新时间
func (this *DNSProviderDAO) UpdateProviderDataUpdatedTime(tx *dbs.Tx, providerId int64) error { func (this *DNSProviderDAO) UpdateProviderDataUpdatedTime(tx *dbs.Tx, providerId int64) error {
_, err := this.Query(tx). _, err := this.Query(tx).
Pk(providerId). Pk(providerId).

View File

@@ -0,0 +1,207 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package dnsutils
import (
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns"
"github.com/TeaOSLab/EdgeAPI/internal/dnsclients"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/dbs"
)
// CheckClusterDNS 检查集群的DNS问题
// 藏这么深是避免package循环引用的问题
func CheckClusterDNS(tx *dbs.Tx, cluster *models.NodeCluster) (issues []*pb.DNSIssue, err error) {
clusterId := int64(cluster.Id)
domainId := int64(cluster.DnsDomainId)
// 检查域名
domain, err := dns.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, domainId, nil)
if err != nil {
return nil, err
}
if domain == nil {
issues = append(issues, &pb.DNSIssue{
Target: cluster.Name,
TargetId: clusterId,
Type: "cluster",
Description: "域名选择错误,需要重新选择",
Params: nil,
MustFix: true,
})
return
}
// Provider
provider, err := dns.SharedDNSProviderDAO.FindEnabledDNSProvider(tx, int64(domain.ProviderId))
if err != nil {
return nil, err
}
if provider == nil {
issues = append(issues, &pb.DNSIssue{
Target: cluster.Name,
TargetId: clusterId,
Type: "cluster",
Description: "域名服务商不可用,需要重新选择",
Params: nil,
MustFix: true,
})
return
}
paramsMap, err := provider.DecodeAPIParams()
if err != nil {
issues = append(issues, &pb.DNSIssue{
Target: cluster.Name,
TargetId: clusterId,
Type: "cluster",
Description: "域名服务商参数配置错误,需要重新配置",
Params: nil,
MustFix: true,
})
return
}
var dnsProvider = dnsclients.FindProvider(provider.Type)
if dnsProvider == nil {
issues = append(issues, &pb.DNSIssue{
Target: cluster.Name,
TargetId: clusterId,
Type: "cluster",
Description: "目前不支持\"" + provider.Type + "\"服务商,需要重新配置",
Params: nil,
MustFix: true,
})
return
}
err = dnsProvider.Auth(paramsMap)
if err != nil {
return
}
var defaultRoute = dnsProvider.DefaultRoute()
var hasDefaultRoute = len(defaultRoute) > 0
// 检查二级域名
if len(cluster.DnsName) == 0 {
issues = append(issues, &pb.DNSIssue{
Target: cluster.Name,
TargetId: clusterId,
Type: "cluster",
Description: "没有设置二级域名",
Params: nil,
MustFix: true,
})
return
}
// TODO 检查域名格式
// TODO 检查域名是否已解析
// 检查节点
nodes, err := models.SharedNodeDAO.FindAllEnabledNodesDNSWithClusterId(tx, clusterId, true)
if err != nil {
return nil, err
}
// TODO 检查节点数量不能为0
for _, node := range nodes {
nodeId := int64(node.Id)
routeCodes, err := node.DNSRouteCodesForDomainId(domainId)
if err != nil {
return nil, err
}
if len(routeCodes) == 0 && !hasDefaultRoute {
issues = append(issues, &pb.DNSIssue{
Target: node.Name,
TargetId: nodeId,
Type: "node",
Description: "没有选择节点所属线路",
Params: map[string]string{
"clusterName": cluster.Name,
"clusterId": numberutils.FormatInt64(clusterId),
},
MustFix: true,
})
continue
}
// 检查线路是否在已有线路中
for _, routeCode := range routeCodes {
routeOk, err := domain.ContainsRouteCode(routeCode)
if err != nil {
return nil, err
}
if !routeOk {
issues = append(issues, &pb.DNSIssue{
Target: node.Name,
TargetId: nodeId,
Type: "node",
Description: "线路已经失效,请重新选择",
Params: map[string]string{
"clusterName": cluster.Name,
"clusterId": numberutils.FormatInt64(clusterId),
},
MustFix: true,
})
continue
}
}
// 检查IP地址
ipAddr, err := models.SharedNodeIPAddressDAO.FindFirstNodeAccessIPAddress(tx, nodeId, nodeconfigs.NodeRoleNode)
if err != nil {
return nil, err
}
if len(ipAddr) == 0 {
issues = append(issues, &pb.DNSIssue{
Target: node.Name,
TargetId: nodeId,
Type: "node",
Description: "没有设置IP地址",
Params: map[string]string{
"clusterName": cluster.Name,
"clusterId": numberutils.FormatInt64(clusterId),
},
MustFix: true,
})
continue
}
// TODO 检查是否有解析记录
}
return
}
// FindDefaultDomainRoute 获取域名默认的线路
func FindDefaultDomainRoute(tx *dbs.Tx, domain *dns.DNSDomain) (string, error) {
if domain == nil {
return "", errors.New("can not find domain")
}
provider, err := dns.SharedDNSProviderDAO.FindEnabledDNSProvider(tx, int64(domain.ProviderId))
if err != nil {
return "", err
}
if provider == nil {
return "", errors.New("provider not found")
}
paramsMap, err := provider.DecodeAPIParams()
if err != nil {
return "", errors.New("decode provider params failed: " + err.Error())
}
var dnsProvider = dnsclients.FindProvider(provider.Type)
if dnsProvider == nil {
return "", errors.New("not supported provider type '" + provider.Type + "'")
}
err = dnsProvider.Auth(paramsMap)
if err != nil {
return "", err
}
return dnsProvider.DefaultRoute(), nil
}

View File

@@ -0,0 +1,29 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package dnsutils
import (
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/logs"
"testing"
)
func TestNodeClusterDAO_CheckClusterDNS(t *testing.T) {
dbs.NotifyReady()
var tx *dbs.Tx
cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(tx, 34)
if err != nil {
t.Fatal(err)
}
if cluster == nil {
t.Log("cluster not found, skip the test")
return
}
issues, err := CheckClusterDNS(tx, cluster)
if err != nil {
t.Fatal(err)
}
logs.PrintAsJSON(issues, t)
}

View File

@@ -238,6 +238,7 @@ func (this *HTTPAccessLogDAO) listAccessLogs(tx *dbs.Tx, lastRequestId string, s
} }
if hasFirewallPolicy { if hasFirewallPolicy {
query.Where("firewallPolicyId>0") query.Where("firewallPolicyId>0")
query.UseIndex("firewallPolicyId")
} }
// keyword // keyword
@@ -254,6 +255,7 @@ func (this *HTTPAccessLogDAO) listAccessLogs(tx *dbs.Tx, lastRequestId string, s
} }
} else { } else {
query.Attr("remoteAddr", ip) query.Attr("remoteAddr", ip)
query.UseIndex("remoteAddr")
} }
} else { } else {
query.Where("JSON_EXTRACT(content, '$.remoteAddr')=:ip1"). query.Where("JSON_EXTRACT(content, '$.remoteAddr')=:ip1").
@@ -269,6 +271,7 @@ func (this *HTTPAccessLogDAO) listAccessLogs(tx *dbs.Tx, lastRequestId string, s
Param("host2", domain) Param("host2", domain)
} else { } else {
query.Attr("domain", domain) query.Attr("domain", domain)
query.UseIndex("domain")
} }
} else { } else {
query.Where("JSON_EXTRACT(content, '$.host')=:host1"). query.Where("JSON_EXTRACT(content, '$.host')=:host1").

View File

@@ -15,6 +15,8 @@ type HTTPAccessLog struct {
FirewallRuleId uint32 `field:"firewallRuleId"` // WAF规则ID FirewallRuleId uint32 `field:"firewallRuleId"` // WAF规则ID
RemoteAddr string `field:"remoteAddr"` // IP地址 RemoteAddr string `field:"remoteAddr"` // IP地址
Domain string `field:"domain"` // 域名 Domain string `field:"domain"` // 域名
RequestBody string `field:"requestBody"` // 请求内容
ResponseBody string `field:"responseBody"` // 响应内容
} }
type HTTPAccessLogOperator struct { type HTTPAccessLogOperator struct {
@@ -31,6 +33,8 @@ type HTTPAccessLogOperator struct {
FirewallRuleId interface{} // WAF规则ID FirewallRuleId interface{} // WAF规则ID
RemoteAddr interface{} // IP地址 RemoteAddr interface{} // IP地址
Domain interface{} // 域名 Domain interface{} // 域名
RequestBody interface{} // 请求内容
ResponseBody interface{} // 响应内容
} }
func NewHTTPAccessLogOperator() *HTTPAccessLogOperator { func NewHTTPAccessLogOperator() *HTTPAccessLogOperator {

View File

@@ -7,6 +7,8 @@ import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
) )
const ( const (
@@ -94,7 +96,16 @@ func (this *HTTPAuthPolicyDAO) UpdateHTTPAuthPolicy(tx *dbs.Tx, policyId int64,
} }
// ComposePolicyConfig 组合配置 // ComposePolicyConfig 组合配置
func (this *HTTPAuthPolicyDAO) ComposePolicyConfig(tx *dbs.Tx, policyId int64) (*serverconfigs.HTTPAuthPolicy, error) { func (this *HTTPAuthPolicyDAO) ComposePolicyConfig(tx *dbs.Tx, policyId int64, cacheMap maps.Map) (*serverconfigs.HTTPAuthPolicy, error) {
if cacheMap == nil {
cacheMap = maps.Map{}
}
var cacheKey = this.Table + ":config:" + types.String(policyId)
var cache = cacheMap.Get(cacheKey)
if cache != nil {
return cache.(*serverconfigs.HTTPAuthPolicy), nil
}
policy, err := this.FindEnabledHTTPAuthPolicy(tx, policyId) policy, err := this.FindEnabledHTTPAuthPolicy(tx, policyId)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -119,6 +130,8 @@ func (this *HTTPAuthPolicyDAO) ComposePolicyConfig(tx *dbs.Tx, policyId int64) (
} }
config.Params = params config.Params = params
cacheMap[cacheKey] = config
return config, nil return config, nil
} }

View File

@@ -8,6 +8,7 @@ import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
) )
@@ -155,6 +156,44 @@ func (this *HTTPCachePolicyDAO) CreateCachePolicy(tx *dbs.Tx, isOn bool, name st
return types.Int64(op.Id), nil return types.Int64(op.Id), nil
} }
// CreateDefaultCachePolicy 创建默认的缓存策略
func (this *HTTPCachePolicyDAO) CreateDefaultCachePolicy(tx *dbs.Tx, name string) (int64, error) {
var capacity = &shared.SizeCapacity{
Count: 64,
Unit: shared.SizeCapacityUnitGB,
}
capacityJSON, err := capacity.AsJSON()
if err != nil {
return 0, err
}
var maxSize = &shared.SizeCapacity{
Count: 256,
Unit: shared.SizeCapacityUnitMB,
}
if err != nil {
return 0, err
}
maxSizeJSON, err := maxSize.AsJSON()
if err != nil {
return 0, err
}
var storageOptions = &serverconfigs.HTTPFileCacheStorage{
Dir: "/opt/cache",
}
storageOptionsJSON, err := json.Marshal(storageOptions)
if err != nil {
return 0, err
}
policyId, err := this.CreateCachePolicy(tx, true, "\""+name+"\"缓存策略", "默认创建的缓存策略", capacityJSON, 0, maxSizeJSON, serverconfigs.CachePolicyStorageFile, storageOptionsJSON)
if err != nil {
return 0, err
}
return policyId, nil
}
// UpdateCachePolicy 修改缓存策略 // UpdateCachePolicy 修改缓存策略
func (this *HTTPCachePolicyDAO) UpdateCachePolicy(tx *dbs.Tx, policyId int64, isOn bool, name string, description string, capacityJSON []byte, maxKeys int64, maxSizeJSON []byte, storageType string, storageOptionsJSON []byte) error { func (this *HTTPCachePolicyDAO) UpdateCachePolicy(tx *dbs.Tx, policyId int64, isOn bool, name string, description string, capacityJSON []byte, maxKeys int64, maxSizeJSON []byte, storageType string, storageOptionsJSON []byte) error {
if policyId <= 0 { if policyId <= 0 {
@@ -185,7 +224,16 @@ func (this *HTTPCachePolicyDAO) UpdateCachePolicy(tx *dbs.Tx, policyId int64, is
} }
// ComposeCachePolicy 组合配置 // ComposeCachePolicy 组合配置
func (this *HTTPCachePolicyDAO) ComposeCachePolicy(tx *dbs.Tx, policyId int64) (*serverconfigs.HTTPCachePolicy, error) { func (this *HTTPCachePolicyDAO) ComposeCachePolicy(tx *dbs.Tx, policyId int64, cacheMap maps.Map) (*serverconfigs.HTTPCachePolicy, error) {
if cacheMap == nil {
cacheMap = maps.Map{}
}
var cacheKey = this.Table + ":config:" + types.String(policyId)
var cache = cacheMap.Get(cacheKey)
if cache != nil {
return cache.(*serverconfigs.HTTPCachePolicy), nil
}
policy, err := this.FindEnabledHTTPCachePolicy(tx, policyId) policy, err := this.FindEnabledHTTPCachePolicy(tx, policyId)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -243,6 +291,8 @@ func (this *HTTPCachePolicyDAO) ComposeCachePolicy(tx *dbs.Tx, policyId int64) (
config.CacheRefs = refs config.CacheRefs = refs
} }
cacheMap[cacheKey] = config
return config, nil return config, nil
} }
@@ -284,7 +334,7 @@ func (this *HTTPCachePolicyDAO) ListEnabledHTTPCachePolicies(tx *dbs.Tx, keyword
cachePolicies := []*serverconfigs.HTTPCachePolicy{} cachePolicies := []*serverconfigs.HTTPCachePolicy{}
for _, policyId := range cachePolicyIds { for _, policyId := range cachePolicyIds {
cachePolicyConfig, err := this.ComposeCachePolicy(tx, policyId) cachePolicyConfig, err := this.ComposeCachePolicy(tx, policyId, nil)
if err != nil { if err != nil {
return nil, errors.Wrap(err) return nil, errors.Wrap(err)
} }

View File

@@ -7,6 +7,7 @@ import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
) )
@@ -113,8 +114,73 @@ func (this *HTTPFirewallPolicyDAO) CreateFirewallPolicy(tx *dbs.Tx, userId int64
return types.Int64(op.Id), err return types.Int64(op.Id), err
} }
// CreateDefaultFirewallPolicy 创建默认的WAF策略
func (this *HTTPFirewallPolicyDAO) CreateDefaultFirewallPolicy(tx *dbs.Tx, name string) (int64, error) {
policyId, err := this.CreateFirewallPolicy(tx, 0, 0, true, "\""+name+"\"WAF策略", "默认创建的WAF策略", nil, nil)
if err != nil {
return 0, err
}
// 初始化
var groupCodes = []string{}
templatePolicy := firewallconfigs.HTTPFirewallTemplate()
for _, group := range templatePolicy.AllRuleGroups() {
groupCodes = append(groupCodes, group.Code)
}
inboundConfig := &firewallconfigs.HTTPFirewallInboundConfig{IsOn: true}
outboundConfig := &firewallconfigs.HTTPFirewallOutboundConfig{IsOn: true}
if templatePolicy.Inbound != nil {
for _, group := range templatePolicy.Inbound.Groups {
isOn := lists.ContainsString(groupCodes, group.Code)
group.IsOn = isOn
groupId, err := SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, group)
if err != nil {
return 0, err
}
inboundConfig.GroupRefs = append(inboundConfig.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{
IsOn: true,
GroupId: groupId,
})
}
}
if templatePolicy.Outbound != nil {
for _, group := range templatePolicy.Outbound.Groups {
isOn := lists.ContainsString(groupCodes, group.Code)
group.IsOn = isOn
groupId, err := SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, group)
if err != nil {
return 0, err
}
outboundConfig.GroupRefs = append(outboundConfig.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{
IsOn: true,
GroupId: groupId,
})
}
}
inboundConfigJSON, err := json.Marshal(inboundConfig)
if err != nil {
return 0, err
}
outboundConfigJSON, err := json.Marshal(outboundConfig)
if err != nil {
return 0, err
}
err = this.UpdateFirewallPolicyInboundAndOutbound(tx, policyId, inboundConfigJSON, outboundConfigJSON, false)
if err != nil {
return 0, err
}
return policyId, nil
}
// UpdateFirewallPolicyInboundAndOutbound 修改策略的Inbound和Outbound // UpdateFirewallPolicyInboundAndOutbound 修改策略的Inbound和Outbound
func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicyInboundAndOutbound(tx *dbs.Tx, policyId int64, inboundJSON []byte, outboundJSON []byte) error { func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicyInboundAndOutbound(tx *dbs.Tx, policyId int64, inboundJSON []byte, outboundJSON []byte, shouldNotify bool) error {
if policyId <= 0 { if policyId <= 0 {
return errors.New("invalid policyId") return errors.New("invalid policyId")
} }
@@ -135,7 +201,11 @@ func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicyInboundAndOutbound(tx *db
return err return err
} }
return this.NotifyUpdate(tx, policyId) if shouldNotify {
return this.NotifyUpdate(tx, policyId)
}
return nil
} }
// UpdateFirewallPolicyInbound 修改策略的Inbound // UpdateFirewallPolicyInbound 修改策略的Inbound
@@ -223,7 +293,16 @@ func (this *HTTPFirewallPolicyDAO) ListEnabledFirewallPolicies(tx *dbs.Tx, keywo
} }
// ComposeFirewallPolicy 组合策略配置 // ComposeFirewallPolicy 组合策略配置
func (this *HTTPFirewallPolicyDAO) ComposeFirewallPolicy(tx *dbs.Tx, policyId int64) (*firewallconfigs.HTTPFirewallPolicy, error) { func (this *HTTPFirewallPolicyDAO) ComposeFirewallPolicy(tx *dbs.Tx, policyId int64, cacheMap maps.Map) (*firewallconfigs.HTTPFirewallPolicy, error) {
if cacheMap == nil {
cacheMap = maps.Map{}
}
var cacheKey = this.Table + ":config:" + types.String(policyId)
var cache = cacheMap.Get(cacheKey)
if cache != nil {
return cache.(*firewallconfigs.HTTPFirewallPolicy), nil
}
policy, err := this.FindEnabledHTTPFirewallPolicy(tx, policyId) policy, err := this.FindEnabledHTTPFirewallPolicy(tx, policyId)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -304,6 +383,8 @@ func (this *HTTPFirewallPolicyDAO) ComposeFirewallPolicy(tx *dbs.Tx, policyId in
config.BlockOptions = blockAction config.BlockOptions = blockAction
} }
cacheMap[cacheKey] = config
return config, nil return config, nil
} }

View File

@@ -131,7 +131,16 @@ func (this *HTTPLocationDAO) UpdateLocation(tx *dbs.Tx, locationId int64, name s
} }
// ComposeLocationConfig 组合配置 // ComposeLocationConfig 组合配置
func (this *HTTPLocationDAO) ComposeLocationConfig(tx *dbs.Tx, locationId int64) (*serverconfigs.HTTPLocationConfig, error) { func (this *HTTPLocationDAO) ComposeLocationConfig(tx *dbs.Tx, locationId int64, cacheMap maps.Map) (*serverconfigs.HTTPLocationConfig, error) {
if cacheMap == nil {
cacheMap = maps.Map{}
}
var cacheKey = this.Table + ":config:" + types.String(locationId)
var cacheConfig = cacheMap.Get(cacheKey)
if cacheConfig != nil {
return cacheConfig.(*serverconfigs.HTTPLocationConfig), nil
}
location, err := this.FindEnabledHTTPLocation(tx, locationId) location, err := this.FindEnabledHTTPLocation(tx, locationId)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -151,7 +160,7 @@ func (this *HTTPLocationDAO) ComposeLocationConfig(tx *dbs.Tx, locationId int64)
// web // web
if location.WebId > 0 { if location.WebId > 0 {
webConfig, err := SharedHTTPWebDAO.ComposeWebConfig(tx, int64(location.WebId)) webConfig, err := SharedHTTPWebDAO.ComposeWebConfig(tx, int64(location.WebId), cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -167,7 +176,7 @@ func (this *HTTPLocationDAO) ComposeLocationConfig(tx *dbs.Tx, locationId int64)
} }
config.ReverseProxyRef = ref config.ReverseProxyRef = ref
if ref.ReverseProxyId > 0 { if ref.ReverseProxyId > 0 {
reverseProxyConfig, err := SharedReverseProxyDAO.ComposeReverseProxyConfig(tx, ref.ReverseProxyId) reverseProxyConfig, err := SharedReverseProxyDAO.ComposeReverseProxyConfig(tx, ref.ReverseProxyId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -185,6 +194,8 @@ func (this *HTTPLocationDAO) ComposeLocationConfig(tx *dbs.Tx, locationId int64)
config.Conds = conds config.Conds = conds
} }
cacheMap[cacheKey] = config
return config, nil return config, nil
} }
@@ -248,13 +259,13 @@ func (this *HTTPLocationDAO) UpdateLocationWeb(tx *dbs.Tx, locationId int64, web
} }
// ConvertLocationRefs 转换引用为配置 // ConvertLocationRefs 转换引用为配置
func (this *HTTPLocationDAO) ConvertLocationRefs(tx *dbs.Tx, refs []*serverconfigs.HTTPLocationRef) (locations []*serverconfigs.HTTPLocationConfig, err error) { func (this *HTTPLocationDAO) ConvertLocationRefs(tx *dbs.Tx, refs []*serverconfigs.HTTPLocationRef, cacheMap maps.Map) (locations []*serverconfigs.HTTPLocationConfig, err error) {
for _, ref := range refs { for _, ref := range refs {
config, err := this.ComposeLocationConfig(tx, ref.LocationId) config, err := this.ComposeLocationConfig(tx, ref.LocationId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
children, err := this.ConvertLocationRefs(tx, ref.Children) children, err := this.ConvertLocationRefs(tx, ref.Children, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -286,7 +297,6 @@ func (this *HTTPLocationDAO) FindEnabledLocationIdWithReverseProxyId(tx *dbs.Tx,
FindInt64Col(0) FindInt64Col(0)
} }
// NotifyUpdate 通知更新 // NotifyUpdate 通知更新
func (this *HTTPLocationDAO) NotifyUpdate(tx *dbs.Tx, locationId int64) error { func (this *HTTPLocationDAO) NotifyUpdate(tx *dbs.Tx, locationId int64) error {
webId, err := SharedHTTPWebDAO.FindEnabledWebIdWithLocationId(tx, locationId) webId, err := SharedHTTPWebDAO.FindEnabledWebIdWithLocationId(tx, locationId)

View File

@@ -7,6 +7,7 @@ import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
) )
@@ -36,12 +37,12 @@ func init() {
}) })
} }
// 初始化 // Init 初始化
func (this *HTTPPageDAO) Init() { func (this *HTTPPageDAO) Init() {
_ = this.DAOObject.Init() _ = this.DAOObject.Init()
} }
// 启用条目 // EnableHTTPPage 启用条目
func (this *HTTPPageDAO) EnableHTTPPage(tx *dbs.Tx, pageId int64) error { func (this *HTTPPageDAO) EnableHTTPPage(tx *dbs.Tx, pageId int64) error {
_, err := this.Query(tx). _, err := this.Query(tx).
Pk(pageId). Pk(pageId).
@@ -53,7 +54,7 @@ func (this *HTTPPageDAO) EnableHTTPPage(tx *dbs.Tx, pageId int64) error {
return this.NotifyUpdate(tx, pageId) return this.NotifyUpdate(tx, pageId)
} }
// 禁用条目 // DisableHTTPPage 禁用条目
func (this *HTTPPageDAO) DisableHTTPPage(tx *dbs.Tx, id int64) error { func (this *HTTPPageDAO) DisableHTTPPage(tx *dbs.Tx, id int64) error {
_, err := this.Query(tx). _, err := this.Query(tx).
Pk(id). Pk(id).
@@ -62,7 +63,7 @@ func (this *HTTPPageDAO) DisableHTTPPage(tx *dbs.Tx, id int64) error {
return err return err
} }
// 查找启用中的条目 // FindEnabledHTTPPage 查找启用中的条目
func (this *HTTPPageDAO) FindEnabledHTTPPage(tx *dbs.Tx, id int64) (*HTTPPage, error) { func (this *HTTPPageDAO) FindEnabledHTTPPage(tx *dbs.Tx, id int64) (*HTTPPage, error) {
result, err := this.Query(tx). result, err := this.Query(tx).
Pk(id). Pk(id).
@@ -74,7 +75,7 @@ func (this *HTTPPageDAO) FindEnabledHTTPPage(tx *dbs.Tx, id int64) (*HTTPPage, e
return result.(*HTTPPage), err return result.(*HTTPPage), err
} }
// 创建Page // CreatePage 创建Page
func (this *HTTPPageDAO) CreatePage(tx *dbs.Tx, statusList []string, url string, newStatus int) (pageId int64, err error) { func (this *HTTPPageDAO) CreatePage(tx *dbs.Tx, statusList []string, url string, newStatus int) (pageId int64, err error) {
op := NewHTTPPageOperator() op := NewHTTPPageOperator()
op.IsOn = true op.IsOn = true
@@ -97,7 +98,7 @@ func (this *HTTPPageDAO) CreatePage(tx *dbs.Tx, statusList []string, url string,
return types.Int64(op.Id), nil return types.Int64(op.Id), nil
} }
// 修改Page // UpdatePage 修改Page
func (this *HTTPPageDAO) UpdatePage(tx *dbs.Tx, pageId int64, statusList []string, url string, newStatus int) error { func (this *HTTPPageDAO) UpdatePage(tx *dbs.Tx, pageId int64, statusList []string, url string, newStatus int) error {
if pageId <= 0 { if pageId <= 0 {
return errors.New("invalid pageId") return errors.New("invalid pageId")
@@ -126,8 +127,17 @@ func (this *HTTPPageDAO) UpdatePage(tx *dbs.Tx, pageId int64, statusList []strin
return this.NotifyUpdate(tx, pageId) return this.NotifyUpdate(tx, pageId)
} }
// 组合配置 // ComposePageConfig 组合配置
func (this *HTTPPageDAO) ComposePageConfig(tx *dbs.Tx, pageId int64) (*serverconfigs.HTTPPageConfig, error) { func (this *HTTPPageDAO) ComposePageConfig(tx *dbs.Tx, pageId int64, cacheMap maps.Map) (*serverconfigs.HTTPPageConfig, error) {
if cacheMap == nil {
cacheMap = maps.Map{}
}
var cacheKey = this.Table + ":config:" + types.String(pageId)
var cache = cacheMap.Get(cacheKey)
if cache != nil {
return cache.(*serverconfigs.HTTPPageConfig), nil
}
page, err := this.FindEnabledHTTPPage(tx, pageId) page, err := this.FindEnabledHTTPPage(tx, pageId)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -154,10 +164,12 @@ func (this *HTTPPageDAO) ComposePageConfig(tx *dbs.Tx, pageId int64) (*servercon
} }
} }
cacheMap[cacheKey] = config
return config, nil return config, nil
} }
// 通知更新 // NotifyUpdate 通知更新
func (this *HTTPPageDAO) NotifyUpdate(tx *dbs.Tx, pageId int64) error { func (this *HTTPPageDAO) NotifyUpdate(tx *dbs.Tx, pageId int64) error {
webId, err := SharedHTTPWebDAO.FindEnabledWebIdWithPageId(tx, pageId) webId, err := SharedHTTPWebDAO.FindEnabledWebIdWithPageId(tx, pageId)
if err != nil { if err != nil {

View File

@@ -8,6 +8,7 @@ import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
) )
@@ -76,7 +77,16 @@ func (this *HTTPRewriteRuleDAO) FindEnabledHTTPRewriteRule(tx *dbs.Tx, id int64)
} }
// ComposeRewriteRule 构造配置 // ComposeRewriteRule 构造配置
func (this *HTTPRewriteRuleDAO) ComposeRewriteRule(tx *dbs.Tx, rewriteRuleId int64) (*serverconfigs.HTTPRewriteRule, error) { func (this *HTTPRewriteRuleDAO) ComposeRewriteRule(tx *dbs.Tx, rewriteRuleId int64, cacheMap maps.Map) (*serverconfigs.HTTPRewriteRule, error) {
if cacheMap == nil {
cacheMap = maps.Map{}
}
var cacheKey = this.Table + ":config:" + types.String(rewriteRuleId)
var cache = cacheMap.Get(cacheKey)
if cache != nil {
return cache.(*serverconfigs.HTTPRewriteRule), nil
}
rule, err := this.FindEnabledHTTPRewriteRule(tx, rewriteRuleId) rule, err := this.FindEnabledHTTPRewriteRule(tx, rewriteRuleId)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -105,6 +115,9 @@ func (this *HTTPRewriteRuleDAO) ComposeRewriteRule(tx *dbs.Tx, rewriteRuleId int
} }
config.Conds = conds config.Conds = conds
} }
cacheMap[cacheKey] = config
return config, nil return config, nil
} }

View File

@@ -75,7 +75,16 @@ func (this *HTTPWebDAO) FindEnabledHTTPWeb(tx *dbs.Tx, id int64) (*HTTPWeb, erro
} }
// ComposeWebConfig 组合配置 // ComposeWebConfig 组合配置
func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64) (*serverconfigs.HTTPWebConfig, error) { func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64, cacheMap maps.Map) (*serverconfigs.HTTPWebConfig, error) {
if cacheMap == nil {
cacheMap = maps.Map{}
}
var cacheKey = this.Table + ":config:" + types.String(webId)
var cache = cacheMap.Get(cacheKey)
if cache != nil {
return cache.(*serverconfigs.HTTPWebConfig), nil
}
web, err := SharedHTTPWebDAO.FindEnabledHTTPWeb(tx, webId) web, err := SharedHTTPWebDAO.FindEnabledHTTPWeb(tx, webId)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -181,7 +190,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64) (*serverconfig
return nil, err return nil, err
} }
for index, page := range pages { for index, page := range pages {
pageConfig, err := SharedHTTPPageDAO.ComposePageConfig(tx, page.Id) pageConfig, err := SharedHTTPPageDAO.ComposePageConfig(tx, page.Id, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -235,7 +244,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64) (*serverconfig
// 自定义防火墙设置 // 自定义防火墙设置
if firewallRef.FirewallPolicyId > 0 { if firewallRef.FirewallPolicyId > 0 {
firewallPolicy, err := SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, firewallRef.FirewallPolicyId) firewallPolicy, err := SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, firewallRef.FirewallPolicyId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -257,7 +266,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64) (*serverconfig
if len(refs) > 0 { if len(refs) > 0 {
config.LocationRefs = refs config.LocationRefs = refs
locations, err := SharedHTTPLocationDAO.ConvertLocationRefs(tx, refs) locations, err := SharedHTTPLocationDAO.ConvertLocationRefs(tx, refs, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -302,7 +311,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64) (*serverconfig
return nil, err return nil, err
} }
for _, ref := range refs { for _, ref := range refs {
rewriteRule, err := SharedHTTPRewriteRuleDAO.ComposeRewriteRule(tx, ref.RewriteRuleId) rewriteRule, err := SharedHTTPRewriteRuleDAO.ComposeRewriteRule(tx, ref.RewriteRuleId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -356,7 +365,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64) (*serverconfig
} }
var newRefs []*serverconfigs.HTTPAuthPolicyRef var newRefs []*serverconfigs.HTTPAuthPolicyRef
for _, ref := range authConfig.PolicyRefs { for _, ref := range authConfig.PolicyRefs {
policyConfig, err := SharedHTTPAuthPolicyDAO.ComposePolicyConfig(tx, ref.AuthPolicyId) policyConfig, err := SharedHTTPAuthPolicyDAO.ComposePolicyConfig(tx, ref.AuthPolicyId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -368,6 +377,8 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64) (*serverconfig
config.Auth = authConfig config.Auth = authConfig
} }
cacheMap[cacheKey] = config
return config, nil return config, nil
} }

View File

@@ -40,9 +40,14 @@ const (
MessageTypeServerNamesAuditingFailed MessageType = "ServerNamesAuditingFailed" // 服务域名审核失败 MessageTypeServerNamesAuditingFailed MessageType = "ServerNamesAuditingFailed" // 服务域名审核失败
MessageTypeThresholdSatisfied MessageType = "ThresholdSatisfied" // 满足阈值 MessageTypeThresholdSatisfied MessageType = "ThresholdSatisfied" // 满足阈值
MessageTypeFirewallEvent MessageType = "FirewallEvent" // 防火墙事件 MessageTypeFirewallEvent MessageType = "FirewallEvent" // 防火墙事件
MessageTypeIPAddrUp MessageType = "IPAddrUp" // IP地址上线
MessageTypeIPAddrDown MessageType = "IPAddrDown" // IP地址下线
MessageTypeNSNodeInactive MessageType = "NSNodeInactive" // 边缘节点不活跃 MessageTypeNSNodeInactive MessageType = "NSNodeInactive" // NS节点不活跃
MessageTypeNSNodeActive MessageType = "NSNodeActive" // 边缘节点活跃 MessageTypeNSNodeActive MessageType = "NSNodeActive" // NS节点活跃
MessageTypeReportNodeInactive MessageType = "ReportNodeInactive" // 区域监控节点节点不活跃
MessageTypeReportNodeActive MessageType = "ReportNodeActive" // 区域监控节点活跃
) )
type MessageDAO dbs.DAO type MessageDAO dbs.DAO
@@ -104,11 +109,7 @@ func (this *MessageDAO) CreateClusterMessage(tx *dbs.Tx, role string, clusterId
} }
// 发送给媒介接收人 // 发送给媒介接收人
err = SharedMessageTaskDAO.CreateMessageTasks(tx, MessageTaskTarget{ err = SharedMessageTaskDAO.CreateMessageTasks(tx, role, 0, 0, 0, messageType, subject, body)
ClusterId: clusterId,
NodeId: 0,
ServerId: 0,
}, messageType, subject, body)
if err != nil { if err != nil {
return err return err
} }
@@ -136,29 +137,10 @@ func (this *MessageDAO) CreateNodeMessage(tx *dbs.Tx, role string, clusterId int
return err return err
} }
// TODO 目前只支持边缘节点发送消息将来要支持NS节点 // 发送给媒介接收人 - 集群
if role == nodeconfigs.NodeRoleNode { err = SharedMessageTaskDAO.CreateMessageTasks(tx, role, clusterId, nodeId, 0, messageType, subject, body)
// 发送给媒介接收人 - 集群 if err != nil {
err = SharedMessageTaskDAO.CreateMessageTasks(tx, MessageTaskTarget{ return err
ClusterId: clusterId,
NodeId: 0,
ServerId: 0,
}, messageType, subject, body)
if err != nil {
return err
}
// 发送给媒介接收人 - 节点
if nodeId > 0 {
err = SharedMessageTaskDAO.CreateMessageTasks(tx, MessageTaskTarget{
ClusterId: clusterId,
NodeId: nodeId,
ServerId: 0,
}, messageType, subject, body)
if err != nil {
return err
}
}
} }
return nil return nil

View File

@@ -7,6 +7,7 @@ import (
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
) )
const ( const (
@@ -35,7 +36,7 @@ func init() {
}) })
} }
// 启用条目 // EnableMessageMediaInstance 启用条目
func (this *MessageMediaInstanceDAO) EnableMessageMediaInstance(tx *dbs.Tx, id int64) error { func (this *MessageMediaInstanceDAO) EnableMessageMediaInstance(tx *dbs.Tx, id int64) error {
_, err := this.Query(tx). _, err := this.Query(tx).
Pk(id). Pk(id).
@@ -44,7 +45,7 @@ func (this *MessageMediaInstanceDAO) EnableMessageMediaInstance(tx *dbs.Tx, id i
return err return err
} }
// 禁用条目 // DisableMessageMediaInstance 禁用条目
func (this *MessageMediaInstanceDAO) DisableMessageMediaInstance(tx *dbs.Tx, id int64) error { func (this *MessageMediaInstanceDAO) DisableMessageMediaInstance(tx *dbs.Tx, id int64) error {
_, err := this.Query(tx). _, err := this.Query(tx).
Pk(id). Pk(id).
@@ -53,20 +54,32 @@ func (this *MessageMediaInstanceDAO) DisableMessageMediaInstance(tx *dbs.Tx, id
return err return err
} }
// 查找启用中的条目 // FindEnabledMessageMediaInstance 查找启用中的条目
func (this *MessageMediaInstanceDAO) FindEnabledMessageMediaInstance(tx *dbs.Tx, id int64) (*MessageMediaInstance, error) { func (this *MessageMediaInstanceDAO) FindEnabledMessageMediaInstance(tx *dbs.Tx, instanceId int64, cacheMap maps.Map) (*MessageMediaInstance, error) {
if cacheMap == nil {
cacheMap = maps.Map{}
}
var cacheKey = this.Table + ":record:" + types.String(instanceId)
var cache = cacheMap.Get(cacheKey)
if cache != nil {
return cache.(*MessageMediaInstance), nil
}
result, err := this.Query(tx). result, err := this.Query(tx).
Pk(id). Pk(instanceId).
Attr("state", MessageMediaInstanceStateEnabled). Attr("state", MessageMediaInstanceStateEnabled).
Find() Find()
if result == nil { if result == nil {
return nil, err return nil, err
} }
cacheMap[cacheKey] = result
return result.(*MessageMediaInstance), err return result.(*MessageMediaInstance), err
} }
// 创建媒介实例 // CreateMediaInstance 创建媒介实例
func (this *MessageMediaInstanceDAO) CreateMediaInstance(tx *dbs.Tx, name string, mediaType string, params maps.Map, description string) (int64, error) { func (this *MessageMediaInstanceDAO) CreateMediaInstance(tx *dbs.Tx, name string, mediaType string, params maps.Map, description string, rateJSON []byte, hashLifeSeconds int32) (int64, error) {
op := NewMessageMediaInstanceOperator() op := NewMessageMediaInstanceOperator()
op.Name = name op.Name = name
op.MediaType = mediaType op.MediaType = mediaType
@@ -83,13 +96,18 @@ func (this *MessageMediaInstanceDAO) CreateMediaInstance(tx *dbs.Tx, name string
op.Description = description op.Description = description
if len(rateJSON) > 0 {
op.Rate = rateJSON
}
op.HashLife = hashLifeSeconds
op.IsOn = true op.IsOn = true
op.State = MessageMediaInstanceStateEnabled op.State = MessageMediaInstanceStateEnabled
return this.SaveInt64(tx, op) return this.SaveInt64(tx, op)
} }
// 修改媒介实例 // UpdateMediaInstance 修改媒介实例
func (this *MessageMediaInstanceDAO) UpdateMediaInstance(tx *dbs.Tx, instanceId int64, name string, mediaType string, params maps.Map, description string, isOn bool) error { func (this *MessageMediaInstanceDAO) UpdateMediaInstance(tx *dbs.Tx, instanceId int64, name string, mediaType string, params maps.Map, description string, rateJSON []byte, hashLifeSeconds int32, isOn bool) error {
if instanceId <= 0 { if instanceId <= 0 {
return errors.New("invalid instanceId") return errors.New("invalid instanceId")
} }
@@ -109,12 +127,18 @@ func (this *MessageMediaInstanceDAO) UpdateMediaInstance(tx *dbs.Tx, instanceId
} }
op.Params = paramsJSON op.Params = paramsJSON
if len(rateJSON) > 0 {
op.Rate = rateJSON
}
op.HashLife = hashLifeSeconds
op.Description = description op.Description = description
op.IsOn = isOn op.IsOn = isOn
return this.Save(tx, op) return this.Save(tx, op)
} }
// 计算接收人数量 // CountAllEnabledMediaInstances 计算接收人数量
func (this *MessageMediaInstanceDAO) CountAllEnabledMediaInstances(tx *dbs.Tx, mediaType string, keyword string) (int64, error) { func (this *MessageMediaInstanceDAO) CountAllEnabledMediaInstances(tx *dbs.Tx, mediaType string, keyword string) (int64, error) {
query := this.Query(tx) query := this.Query(tx)
if len(mediaType) > 0 { if len(mediaType) > 0 {
@@ -130,7 +154,7 @@ func (this *MessageMediaInstanceDAO) CountAllEnabledMediaInstances(tx *dbs.Tx, m
Count() Count()
} }
// 列出单页接收人 // ListAllEnabledMediaInstances 列出单页接收人
func (this *MessageMediaInstanceDAO) ListAllEnabledMediaInstances(tx *dbs.Tx, mediaType string, keyword string, offset int64, size int64) (result []*MessageMediaInstance, err error) { func (this *MessageMediaInstanceDAO) ListAllEnabledMediaInstances(tx *dbs.Tx, mediaType string, keyword string, offset int64, size int64) (result []*MessageMediaInstance, err error) {
query := this.Query(tx) query := this.Query(tx)
if len(mediaType) > 0 { if len(mediaType) > 0 {
@@ -150,3 +174,15 @@ func (this *MessageMediaInstanceDAO) ListAllEnabledMediaInstances(tx *dbs.Tx, me
FindAll() FindAll()
return return
} }
// FindInstanceHashLifeSeconds 获取单个实例的HashLife
func (this *MessageMediaInstanceDAO) FindInstanceHashLifeSeconds(tx *dbs.Tx, instanceId int64) (int32, error) {
hashLife, err := this.Query(tx).
Pk(instanceId).
Result("hashLife").
FindIntCol(0)
if err != nil {
return 0, err
}
return types.Int32(hashLife), nil
}

View File

@@ -1,6 +1,6 @@
package models package models
// 消息媒介接收人 // MessageMediaInstance 消息媒介接收人
type MessageMediaInstance struct { type MessageMediaInstance struct {
Id uint32 `field:"id"` // ID Id uint32 `field:"id"` // ID
Name string `field:"name"` // 名称 Name string `field:"name"` // 名称
@@ -8,7 +8,9 @@ type MessageMediaInstance struct {
MediaType string `field:"mediaType"` // 媒介类型 MediaType string `field:"mediaType"` // 媒介类型
Params string `field:"params"` // 媒介参数 Params string `field:"params"` // 媒介参数
Description string `field:"description"` // 备注 Description string `field:"description"` // 备注
Rate string `field:"rate"` // 发送频率
State uint8 `field:"state"` // 状态 State uint8 `field:"state"` // 状态
HashLife int32 `field:"hashLife"` // HASH有效期
} }
type MessageMediaInstanceOperator struct { type MessageMediaInstanceOperator struct {
@@ -18,7 +20,9 @@ type MessageMediaInstanceOperator struct {
MediaType interface{} // 媒介类型 MediaType interface{} // 媒介类型
Params interface{} // 媒介参数 Params interface{} // 媒介参数
Description interface{} // 备注 Description interface{} // 备注
Rate interface{} // 发送频率
State interface{} // 状态 State interface{} // 状态
HashLife interface{} // HASH有效期
} }
func NewMessageMediaInstanceOperator() *MessageMediaInstanceOperator { func NewMessageMediaInstanceOperator() *MessageMediaInstanceOperator {

View File

@@ -75,11 +75,12 @@ func (this *MessageReceiverDAO) DisableReceivers(tx *dbs.Tx, clusterId int64, no
} }
// CreateReceiver 创建接收人 // CreateReceiver 创建接收人
func (this *MessageReceiverDAO) CreateReceiver(tx *dbs.Tx, target MessageTaskTarget, messageType MessageType, params maps.Map, recipientId int64, recipientGroupId int64) (int64, error) { func (this *MessageReceiverDAO) CreateReceiver(tx *dbs.Tx, role string, clusterId int64, nodeId int64, serverId int64, messageType MessageType, params maps.Map, recipientId int64, recipientGroupId int64) (int64, error) {
op := NewMessageReceiverOperator() op := NewMessageReceiverOperator()
op.ClusterId = target.ClusterId op.Role = role
op.NodeId = target.NodeId op.ClusterId = clusterId
op.ServerId = target.ServerId op.NodeId = nodeId
op.ServerId = serverId
op.Type = messageType op.Type = messageType
if params == nil { if params == nil {
@@ -98,63 +99,120 @@ func (this *MessageReceiverDAO) CreateReceiver(tx *dbs.Tx, target MessageTaskTar
} }
// FindAllEnabledReceivers 查询接收人 // FindAllEnabledReceivers 查询接收人
func (this *MessageReceiverDAO) FindAllEnabledReceivers(tx *dbs.Tx, target MessageTaskTarget, messageType string) (result []*MessageReceiver, err error) { func (this *MessageReceiverDAO) FindAllEnabledReceivers(tx *dbs.Tx, role string, clusterId int64, nodeId int64, serverId int64, messageType string) (result []*MessageReceiver, err error) {
query := this.Query(tx) query := this.Query(tx)
if len(messageType) > 0 { if len(messageType) > 0 {
query.Attr("type", []string{"*", messageType}) // *表示所有的 query.Attr("type", []string{"*", messageType}) // *表示所有的
} }
_, err = query. _, err = query.
Attr("clusterId", target.ClusterId). Attr("role", role).
Attr("nodeId", target.NodeId). Attr("clusterId", clusterId).
Attr("serverId", target.ServerId). Attr("nodeId", nodeId).
Attr("serverId", serverId).
State(MessageReceiverStateEnabled). State(MessageReceiverStateEnabled).
AscPk(). AscPk().
Slice(&result). Slice(&result).
FindAll() FindAll()
if err != nil {
return nil, err
}
if len(result) == 0 {
// 去掉类型再试试
query := this.Query(tx)
_, err = query.
Attr("clusterId", target.ClusterId).
Attr("nodeId", target.NodeId).
Attr("serverId", target.ServerId).
State(MessageReceiverStateEnabled).
AscPk().
Slice(&result).
FindAll()
if err != nil {
return nil, err
}
// 去掉服务和节点再试试
if len(result) == 0 {
query := this.Query(tx)
_, err = query.
Attr("clusterId", target.ClusterId).
State(MessageReceiverStateEnabled).
AscPk().
Slice(&result).
FindAll()
}
}
return return
} }
// CountAllEnabledReceivers 计算接收人数量 // CountAllEnabledReceivers 计算接收人数量
func (this *MessageReceiverDAO) CountAllEnabledReceivers(tx *dbs.Tx, target MessageTaskTarget, messageType string) (int64, error) { func (this *MessageReceiverDAO) CountAllEnabledReceivers(tx *dbs.Tx, role string, clusterId int64, nodeId int64, serverId int64, messageType string) (int64, error) {
query := this.Query(tx) query := this.Query(tx)
if len(messageType) > 0 { if len(messageType) > 0 {
query.Attr("type", []string{"*", messageType}) // *表示所有的 query.Attr("type", []string{"*", messageType}) // *表示所有的
} }
return query. return query.
Attr("clusterId", target.ClusterId). Attr("role", role).
Attr("nodeId", target.NodeId). Attr("clusterId", clusterId).
Attr("serverId", target.ServerId). Attr("nodeId", nodeId).
Attr("serverId", serverId).
State(MessageReceiverStateEnabled). State(MessageReceiverStateEnabled).
Count() Count()
} }
// FindEnabledBestFitReceivers 查询最适合的接收人
func (this *MessageReceiverDAO) FindEnabledBestFitReceivers(tx *dbs.Tx, role string, clusterId int64, nodeId int64, serverId int64, messageType string) (result []*MessageReceiver, err error) {
// serverId优先
query := this.Query(tx)
if len(messageType) > 0 {
query.Attr("type", []string{"*", messageType}) // *表示所有的
}
if len(role) > 0 {
query.Attr("role", role)
}
if serverId > 0 {
query.Attr("serverId", serverId)
} else if nodeId > 0 {
query.Attr("nodeId", nodeId)
} else if clusterId > 0 {
query.Attr("clusterId", clusterId)
}
_, err = query.
State(MessageReceiverStateEnabled).
AscPk().
Slice(&result).
FindAll()
if err != nil || len(result) > 0 {
return
}
// nodeId优先
if serverId > 0 && nodeId > 0 {
query = this.Query(tx)
if len(messageType) > 0 {
query.Attr("type", []string{"*", messageType}) // *表示所有的
}
if len(role) > 0 {
query.Attr("role", role)
}
query.Attr("nodeId", nodeId)
_, err = query.
State(MessageReceiverStateEnabled).
AscPk().
Slice(&result).
FindAll()
if err != nil || len(result) > 0 {
return
}
}
// clusterId优先
if (serverId > 0 || nodeId > 0) && clusterId > 0 {
query = this.Query(tx)
if len(messageType) > 0 {
query.Attr("type", []string{"*", messageType}) // *表示所有的
}
if len(role) > 0 {
query.Attr("role", role)
}
query.Attr("clusterId", clusterId)
_, err = query.
State(MessageReceiverStateEnabled).
AscPk().
Slice(&result).
FindAll()
if err != nil || len(result) > 0 {
return
}
}
// 去掉集群ID
query = this.Query(tx)
if len(messageType) > 0 {
query.Attr("type", []string{"*", messageType}) // *表示所有的
}
if len(role) > 0 {
query.Attr("role", role)
}
_, err = query.
State(MessageReceiverStateEnabled).
AscPk().
Slice(&result).
FindAll()
if err != nil || len(result) > 0 {
return
}
return
}

View File

@@ -1,6 +1,30 @@
package models package models
import ( import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/iwind/TeaGo/bootstrap" _ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/logs"
"testing"
) )
func TestMessageReceiverDAO_FindEnabledBestFitReceivers(t *testing.T) {
var tx *dbs.Tx
{
receivers, err := NewMessageReceiverDAO().FindEnabledBestFitReceivers(tx, nodeconfigs.NodeRoleNode, 18, 1, 2, "*")
if err != nil {
t.Fatal(err)
}
logs.PrintAsJSON(receivers, t)
}
{
receivers, err := NewMessageReceiverDAO().FindEnabledBestFitReceivers(tx, nodeconfigs.NodeRoleNode, 30, 1, 2, "*")
if err != nil {
t.Fatal(err)
}
logs.PrintAsJSON(receivers, t)
}
}

View File

@@ -3,6 +3,7 @@ package models
// MessageReceiver 消息通知接收人 // MessageReceiver 消息通知接收人
type MessageReceiver struct { type MessageReceiver struct {
Id uint32 `field:"id"` // ID Id uint32 `field:"id"` // ID
Role string `field:"role"` // 节点角色
ClusterId uint32 `field:"clusterId"` // 集群ID ClusterId uint32 `field:"clusterId"` // 集群ID
NodeId uint32 `field:"nodeId"` // 节点ID NodeId uint32 `field:"nodeId"` // 节点ID
ServerId uint32 `field:"serverId"` // 服务ID ServerId uint32 `field:"serverId"` // 服务ID
@@ -15,6 +16,7 @@ type MessageReceiver struct {
type MessageReceiverOperator struct { type MessageReceiverOperator struct {
Id interface{} // ID Id interface{} // ID
Role interface{} // 节点角色
ClusterId interface{} // 集群ID ClusterId interface{} // 集群ID
NodeId interface{} // 节点ID NodeId interface{} // 节点ID
ServerId interface{} // 服务ID ServerId interface{} // 服务ID

View File

@@ -7,6 +7,9 @@ import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"regexp"
) )
const ( const (
@@ -54,19 +57,32 @@ func (this *MessageRecipientDAO) DisableMessageRecipient(tx *dbs.Tx, id int64) e
} }
// FindEnabledMessageRecipient 查找启用中的条目 // FindEnabledMessageRecipient 查找启用中的条目
func (this *MessageRecipientDAO) FindEnabledMessageRecipient(tx *dbs.Tx, id int64) (*MessageRecipient, error) { func (this *MessageRecipientDAO) FindEnabledMessageRecipient(tx *dbs.Tx, recipientId int64, cacheMap maps.Map,
) (*MessageRecipient, error) {
if cacheMap == nil {
cacheMap = maps.Map{}
}
var cacheKey = this.Table + ":record:" + types.String(recipientId)
var cache = cacheMap.Get(cacheKey)
if cache != nil {
return cache.(*MessageRecipient), nil
}
result, err := this.Query(tx). result, err := this.Query(tx).
Pk(id). Pk(recipientId).
Attr("state", MessageRecipientStateEnabled). Attr("state", MessageRecipientStateEnabled).
Find() Find()
if result == nil { if result == nil {
return nil, err return nil, err
} }
cacheMap[cacheKey] = result
return result.(*MessageRecipient), err return result.(*MessageRecipient), err
} }
// CreateRecipient 创建接收人 // CreateRecipient 创建接收人
func (this *MessageRecipientDAO) CreateRecipient(tx *dbs.Tx, adminId int64, instanceId int64, user string, groupIds []int64, description string) (int64, error) { func (this *MessageRecipientDAO) CreateRecipient(tx *dbs.Tx, adminId int64, instanceId int64, user string, groupIds []int64, description string, timeFrom string, timeTo string) (int64, error) {
op := NewMessageRecipientOperator() op := NewMessageRecipientOperator()
op.AdminId = adminId op.AdminId = adminId
op.InstanceId = instanceId op.InstanceId = instanceId
@@ -83,13 +99,22 @@ func (this *MessageRecipientDAO) CreateRecipient(tx *dbs.Tx, adminId int64, inst
} }
op.GroupIds = groupIdsJSON op.GroupIds = groupIdsJSON
// 判断格式
var timeReg = regexp.MustCompile(`^\d+:\d+:\d+$`)
if timeReg.MatchString(timeFrom) {
op.TimeFrom = timeFrom
}
if timeReg.MatchString(timeTo) {
op.TimeTo = timeTo
}
op.IsOn = true op.IsOn = true
op.State = MessageRecipientStateEnabled op.State = MessageRecipientStateEnabled
return this.SaveInt64(tx, op) return this.SaveInt64(tx, op)
} }
// UpdateRecipient 修改接收人 // UpdateRecipient 修改接收人
func (this *MessageRecipientDAO) UpdateRecipient(tx *dbs.Tx, recipientId int64, adminId int64, instanceId int64, user string, groupIds []int64, description string, isOn bool) error { func (this *MessageRecipientDAO) UpdateRecipient(tx *dbs.Tx, recipientId int64, adminId int64, instanceId int64, user string, groupIds []int64, description string, timeFrom string, timeTo string, isOn bool) error {
if recipientId <= 0 { if recipientId <= 0 {
return errors.New("invalid recipientId") return errors.New("invalid recipientId")
} }
@@ -111,6 +136,20 @@ func (this *MessageRecipientDAO) UpdateRecipient(tx *dbs.Tx, recipientId int64,
op.GroupIds = groupIdsJSON op.GroupIds = groupIdsJSON
op.Description = description op.Description = description
// 判断格式
var timeReg = regexp.MustCompile(`^\d+:\d+:\d+$`)
if timeReg.MatchString(timeFrom) {
op.TimeFrom = timeFrom
} else {
op.TimeFrom = dbs.SQL("NULL")
}
if timeReg.MatchString(timeTo) {
op.TimeTo = timeTo
} else {
op.TimeTo = dbs.SQL("NULL")
}
op.IsOn = isOn op.IsOn = isOn
return this.Save(tx, op) return this.Save(tx, op)
} }
@@ -187,3 +226,11 @@ func (this *MessageRecipientDAO) FindAllEnabledAndOnRecipientIdsWithGroup(tx *db
} }
return result, nil return result, nil
} }
// FindRecipientInstanceId 查找接收人的媒介
func (this *MessageRecipientDAO) FindRecipientInstanceId(tx *dbs.Tx, recipientId int64) (int64, error) {
return this.Query(tx).
Pk(recipientId).
Result("instanceId").
FindInt64Col(0)
}

View File

@@ -1,6 +1,6 @@
package models package models
// 消息媒介接收人 // MessageRecipient 消息媒介接收人
type MessageRecipient struct { type MessageRecipient struct {
Id uint32 `field:"id"` // ID Id uint32 `field:"id"` // ID
AdminId uint32 `field:"adminId"` // 管理员ID AdminId uint32 `field:"adminId"` // 管理员ID
@@ -9,6 +9,8 @@ type MessageRecipient struct {
User string `field:"user"` // 接收人信息 User string `field:"user"` // 接收人信息
GroupIds string `field:"groupIds"` // 分组ID GroupIds string `field:"groupIds"` // 分组ID
State uint8 `field:"state"` // 状态 State uint8 `field:"state"` // 状态
TimeFrom string `field:"timeFrom"` // 开始时间
TimeTo string `field:"timeTo"` // 结束时间
Description string `field:"description"` // 备注 Description string `field:"description"` // 备注
} }
@@ -20,6 +22,8 @@ type MessageRecipientOperator struct {
User interface{} // 接收人信息 User interface{} // 接收人信息
GroupIds interface{} // 分组ID GroupIds interface{} // 分组ID
State interface{} // 状态 State interface{} // 状态
TimeFrom interface{} // 开始时间
TimeTo interface{} // 结束时间
Description interface{} // 备注 Description interface{} // 备注
} }

View File

@@ -1,11 +0,0 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package models
// MessageTaskTarget 消息接收对象
// 每个字段不一定都有值
type MessageTaskTarget struct {
ClusterId int64 // 集群ID
NodeId int64 // 节点ID
ServerId int64 // 服务ID
}

View File

@@ -2,9 +2,15 @@ package models
import ( import (
"github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/remotelogs"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
timeutil "github.com/iwind/TeaGo/utils/time"
"time" "time"
) )
@@ -35,6 +41,21 @@ func NewMessageTaskDAO() *MessageTaskDAO {
var SharedMessageTaskDAO *MessageTaskDAO var SharedMessageTaskDAO *MessageTaskDAO
func init() {
dbs.OnReadyDone(func() {
// 清理数据任务
var ticker = time.NewTicker(time.Duration(rands.Int(24, 48)) * time.Hour)
go func() {
for range ticker.C {
err := SharedMessageTaskDAO.CleanExpiredMessageTasks(nil, 30) // 只保留30天
if err != nil {
remotelogs.Error("SharedMessageTaskDAO", "clean expired data failed: "+err.Error())
}
}
}()
})
}
func init() { func init() {
dbs.OnReady(func() { dbs.OnReady(func() {
SharedMessageTaskDAO = NewMessageTaskDAO() SharedMessageTaskDAO = NewMessageTaskDAO()
@@ -73,13 +94,46 @@ func (this *MessageTaskDAO) FindEnabledMessageTask(tx *dbs.Tx, id int64) (*Messa
// CreateMessageTask 创建任务 // CreateMessageTask 创建任务
func (this *MessageTaskDAO) CreateMessageTask(tx *dbs.Tx, recipientId int64, instanceId int64, user string, subject string, body string, isPrimary bool) (int64, error) { func (this *MessageTaskDAO) CreateMessageTask(tx *dbs.Tx, recipientId int64, instanceId int64, user string, subject string, body string, isPrimary bool) (int64, error) {
var hash = stringutil.Md5(types.String(recipientId) + "@" + types.String(instanceId) + "@" + user + "@" + subject + "@" + types.String(isPrimary))
recipientInstanceId, err := SharedMessageRecipientDAO.FindRecipientInstanceId(tx, recipientId)
if err != nil {
return 0, err
}
if recipientInstanceId > 0 {
hashLifeSeconds, err := SharedMessageMediaInstanceDAO.FindInstanceHashLifeSeconds(tx, recipientInstanceId)
if err != nil {
return 0, err
}
if hashLifeSeconds >= 0 { // 意味着此值如果小于0则不做判断
lastMessageAt, err := this.Query(tx).
Attr("hash", hash).
Result("createdAt").
DescPk().
FindInt64Col(0)
if err != nil {
return 0, err
}
// 对于同一个人N分钟内消息不重复发送
if hashLifeSeconds <= 0 {
hashLifeSeconds = 60
}
if lastMessageAt > 0 && time.Now().Unix()-lastMessageAt < int64(hashLifeSeconds) {
return 0, nil
}
}
}
op := NewMessageTaskOperator() op := NewMessageTaskOperator()
op.RecipientId = recipientId op.RecipientId = recipientId
op.InstanceId = instanceId op.InstanceId = instanceId
op.Hash = hash
op.User = user op.User = user
op.Subject = subject op.Subject = subject
op.Body = body op.Body = body
op.IsPrimary = isPrimary op.IsPrimary = isPrimary
op.Day = timeutil.Format("Ymd")
op.Status = MessageTaskStatusNone op.Status = MessageTaskStatusNone
op.State = MessageTaskStateEnabled op.State = MessageTaskStateEnabled
return this.SaveInt64(tx, op) return this.SaveInt64(tx, op)
@@ -93,6 +147,8 @@ func (this *MessageTaskDAO) FindSendingMessageTasks(tx *dbs.Tx, size int64) (res
_, err = this.Query(tx). _, err = this.Query(tx).
State(MessageTaskStateEnabled). State(MessageTaskStateEnabled).
Attr("status", MessageTaskStatusNone). Attr("status", MessageTaskStatusNone).
Where("(recipientId=0 OR recipientId IN (SELECT id FROM "+SharedMessageRecipientDAO.Table+" WHERE state=1 AND isOn=1 AND (timeFrom IS NULL OR timeTo IS NULL OR :time BETWEEN timeFrom AND timeTo)))").
Param("time", timeutil.Format("H:i:s")).
Desc("isPrimary"). Desc("isPrimary").
AscPk(). AscPk().
Limit(size). Limit(size).
@@ -101,6 +157,28 @@ func (this *MessageTaskDAO) FindSendingMessageTasks(tx *dbs.Tx, size int64) (res
return return
} }
// CountMessageTasksWithStatus 根据状态计算任务数量
func (this *MessageTaskDAO) CountMessageTasksWithStatus(tx *dbs.Tx, status MessageTaskStatus) (int64, error) {
return this.Query(tx).
State(MessageTaskStateEnabled).
Attr("status", status).
Count()
}
// ListMessageTasksWithStatus 根据状态列出单页任务
func (this *MessageTaskDAO) ListMessageTasksWithStatus(tx *dbs.Tx, status MessageTaskStatus, offset int64, size int64) (result []*MessageTask, err error) {
_, err = this.Query(tx).
State(MessageTaskStateEnabled).
Attr("status", status).
Desc("isPrimary").
AscPk().
Offset(offset).
Limit(size).
Slice(&result).
FindAll()
return
}
// UpdateMessageTaskStatus 设置发送的状态 // UpdateMessageTaskStatus 设置发送的状态
func (this *MessageTaskDAO) UpdateMessageTaskStatus(tx *dbs.Tx, taskId int64, status MessageTaskStatus, result []byte) error { func (this *MessageTaskDAO) UpdateMessageTaskStatus(tx *dbs.Tx, taskId int64, status MessageTaskStatus, result []byte) error {
if taskId <= 0 { if taskId <= 0 {
@@ -117,8 +195,8 @@ func (this *MessageTaskDAO) UpdateMessageTaskStatus(tx *dbs.Tx, taskId int64, st
} }
// CreateMessageTasks 从集群、节点或者服务中创建任务 // CreateMessageTasks 从集群、节点或者服务中创建任务
func (this *MessageTaskDAO) CreateMessageTasks(tx *dbs.Tx, target MessageTaskTarget, messageType MessageType, subject string, body string) error { func (this *MessageTaskDAO) CreateMessageTasks(tx *dbs.Tx, role nodeconfigs.NodeRole, clusterId int64, nodeId int64, serverId int64, messageType MessageType, subject string, body string) error {
receivers, err := SharedMessageReceiverDAO.FindAllEnabledReceivers(tx, target, messageType) receivers, err := SharedMessageReceiverDAO.FindEnabledBestFitReceivers(tx, role, clusterId, nodeId, serverId, messageType)
if err != nil { if err != nil {
return err return err
} }
@@ -150,3 +228,16 @@ func (this *MessageTaskDAO) CreateMessageTasks(tx *dbs.Tx, target MessageTaskTar
return nil return nil
} }
// CleanExpiredMessageTasks 清理
func (this *MessageTaskDAO) CleanExpiredMessageTasks(tx *dbs.Tx, days int) error {
if days <= 0 {
days = 30
}
var day = timeutil.Format("Ymd", time.Now().AddDate(0, 0, -days))
_, err := this.Query(tx).
Where("(day IS NULL OR day<:day)").
Param("day", day).
Delete()
return err
}

View File

@@ -3,4 +3,20 @@ package models
import ( import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/iwind/TeaGo/bootstrap" _ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/dbs"
"testing"
) )
func TestMessageTaskDAO_FindSendingMessageTasks(t *testing.T) {
dbs.NotifyReady()
var tx *dbs.Tx
tasks, err := NewMessageTaskDAO().FindSendingMessageTasks(tx, 100)
if err != nil {
t.Fatal(err)
}
t.Log(len(tasks), "tasks")
for _, task := range tasks {
t.Log("task:", task.Id, "recipient:", task.RecipientId)
}
}

View File

@@ -1,13 +1,32 @@
package models package models
import ( import (
"github.com/TeaOSLab/EdgeAPI/internal/remotelogs"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/rands"
timeutil "github.com/iwind/TeaGo/utils/time"
"time"
) )
type MessageTaskLogDAO dbs.DAO type MessageTaskLogDAO dbs.DAO
func init() {
dbs.OnReadyDone(func() {
// 清理数据任务
var ticker = time.NewTicker(time.Duration(rands.Int(24, 48)) * time.Hour)
go func() {
for range ticker.C {
err := SharedMessageTaskLogDAO.CleanExpiredLogs(nil, 30) // 只保留30天
if err != nil {
remotelogs.Error("SharedMessageTaskLogDAO", "clean expired data failed: "+err.Error())
}
}
}()
})
}
func NewMessageTaskLogDAO() *MessageTaskLogDAO { func NewMessageTaskLogDAO() *MessageTaskLogDAO {
return dbs.NewDAO(&MessageTaskLogDAO{ return dbs.NewDAO(&MessageTaskLogDAO{
DAOObject: dbs.DAOObject{ DAOObject: dbs.DAOObject{
@@ -27,25 +46,28 @@ func init() {
}) })
} }
// 创建日志 // CreateLog 创建日志
func (this *MessageTaskLogDAO) CreateLog(tx *dbs.Tx, taskId int64, isOk bool, errMsg string, response string) error { func (this *MessageTaskLogDAO) CreateLog(tx *dbs.Tx, taskId int64, isOk bool, errMsg string, response string) error {
op := NewMessageTaskLogOperator() op := NewMessageTaskLogOperator()
op.TaskId = taskId op.TaskId = taskId
op.IsOk = isOk op.IsOk = isOk
op.Error = errMsg op.Error = errMsg
op.Response = response op.Response = response
op.Day = timeutil.Format("Ymd")
return this.Save(tx, op) return this.Save(tx, op)
} }
// 计算日志数量 // CountLogs 计算日志数量
func (this *MessageTaskLogDAO) CountLogs(tx *dbs.Tx) (int64, error) { func (this *MessageTaskLogDAO) CountLogs(tx *dbs.Tx) (int64, error) {
return this.Query(tx). return this.Query(tx).
Where("taskId IN (SELECT id FROM " + SharedMessageTaskDAO.Table + ")").
Count() Count()
} }
// 列出单页日志 // ListLogs 列出单页日志
func (this *MessageTaskLogDAO) ListLogs(tx *dbs.Tx, offset int64, size int64) (result []*MessageTaskLog, err error) { func (this *MessageTaskLogDAO) ListLogs(tx *dbs.Tx, offset int64, size int64) (result []*MessageTaskLog, err error) {
_, err = this.Query(tx). _, err = this.Query(tx).
Where("taskId IN (SELECT id FROM " + SharedMessageTaskDAO.Table + ")").
Offset(offset). Offset(offset).
Limit(size). Limit(size).
DescPk(). DescPk().
@@ -53,3 +75,16 @@ func (this *MessageTaskLogDAO) ListLogs(tx *dbs.Tx, offset int64, size int64) (r
FindAll() FindAll()
return return
} }
// CleanExpiredLogs 清理
func (this *MessageTaskLogDAO) CleanExpiredLogs(tx *dbs.Tx, days int) error {
if days <= 0 {
days = 30
}
var day = timeutil.Format("Ymd", time.Now().AddDate(0, 0, -days))
_, err := this.Query(tx).
Where("(day IS NULL OR day<:day)").
Param("day", day).
Delete()
return err
}

View File

@@ -1,6 +1,6 @@
package models package models
// 消息发送日志 // MessageTaskLog 消息发送日志
type MessageTaskLog struct { type MessageTaskLog struct {
Id uint64 `field:"id"` // ID Id uint64 `field:"id"` // ID
TaskId uint64 `field:"taskId"` // 任务ID TaskId uint64 `field:"taskId"` // 任务ID
@@ -8,6 +8,7 @@ type MessageTaskLog struct {
IsOk uint8 `field:"isOk"` // 是否成功 IsOk uint8 `field:"isOk"` // 是否成功
Error string `field:"error"` // 错误信息 Error string `field:"error"` // 错误信息
Response string `field:"response"` // 响应信息 Response string `field:"response"` // 响应信息
Day string `field:"day"` // YYYYMMDD
} }
type MessageTaskLogOperator struct { type MessageTaskLogOperator struct {
@@ -17,6 +18,7 @@ type MessageTaskLogOperator struct {
IsOk interface{} // 是否成功 IsOk interface{} // 是否成功
Error interface{} // 错误信息 Error interface{} // 错误信息
Response interface{} // 响应信息 Response interface{} // 响应信息
Day interface{} // YYYYMMDD
} }
func NewMessageTaskLogOperator() *MessageTaskLogOperator { func NewMessageTaskLogOperator() *MessageTaskLogOperator {

View File

@@ -1,9 +1,10 @@
package models package models
// // MessageTask 消息发送相关任务
type MessageTask struct { type MessageTask struct {
Id uint64 `field:"id"` // ID Id uint64 `field:"id"` // ID
RecipientId uint32 `field:"recipientId"` // 接收人ID RecipientId uint32 `field:"recipientId"` // 接收人ID
Hash string `field:"hash"` // SUM标识
InstanceId uint32 `field:"instanceId"` // 媒介实例ID InstanceId uint32 `field:"instanceId"` // 媒介实例ID
User string `field:"user"` // 接收用户标识 User string `field:"user"` // 接收用户标识
Subject string `field:"subject"` // 标题 Subject string `field:"subject"` // 标题
@@ -13,12 +14,14 @@ type MessageTask struct {
SentAt uint64 `field:"sentAt"` // 最后一次发送时间 SentAt uint64 `field:"sentAt"` // 最后一次发送时间
State uint8 `field:"state"` // 状态 State uint8 `field:"state"` // 状态
Result string `field:"result"` // 结果 Result string `field:"result"` // 结果
Day string `field:"day"` // YYYYMMDD
IsPrimary uint8 `field:"isPrimary"` // 是否优先 IsPrimary uint8 `field:"isPrimary"` // 是否优先
} }
type MessageTaskOperator struct { type MessageTaskOperator struct {
Id interface{} // ID Id interface{} // ID
RecipientId interface{} // 接收人ID RecipientId interface{} // 接收人ID
Hash interface{} // SUM标识
InstanceId interface{} // 媒介实例ID InstanceId interface{} // 媒介实例ID
User interface{} // 接收用户标识 User interface{} // 接收用户标识
Subject interface{} // 标题 Subject interface{} // 标题
@@ -28,6 +31,7 @@ type MessageTaskOperator struct {
SentAt interface{} // 最后一次发送时间 SentAt interface{} // 最后一次发送时间
State interface{} // 状态 State interface{} // 状态
Result interface{} // 结果 Result interface{} // 结果
Day interface{} // YYYYMMDD
IsPrimary interface{} // 是否优先 IsPrimary interface{} // 是否优先
} }

View File

@@ -197,6 +197,7 @@ func (this *MetricStatDAO) FindItemStatsWithClusterIdAndLastTime(tx *dbs.Tx, clu
var lastTime = lastStat.Time var lastTime = lastStat.Time
var query = this.Query(tx). var query = this.Query(tx).
UseIndex("cluster_item_time").
Attr("clusterId", clusterId). Attr("clusterId", clusterId).
Attr("itemId", itemId). Attr("itemId", itemId).
Attr("version", version). Attr("version", version).
@@ -243,6 +244,7 @@ func (this *MetricStatDAO) FindItemStatsWithNodeIdAndLastTime(tx *dbs.Tx, nodeId
var lastStat = statOne.(*MetricStat) var lastStat = statOne.(*MetricStat)
var lastTime = lastStat.Time var lastTime = lastStat.Time
var query = this.Query(tx). var query = this.Query(tx).
UseIndex("node_item_time").
Attr("nodeId", nodeId). Attr("nodeId", nodeId).
Attr("itemId", itemId). Attr("itemId", itemId).
Attr("version", version). Attr("version", version).
@@ -290,6 +292,7 @@ func (this *MetricStatDAO) FindItemStatsWithServerIdAndLastTime(tx *dbs.Tx, serv
var lastTime = lastStat.Time var lastTime = lastStat.Time
var query = this.Query(tx). var query = this.Query(tx).
UseIndex("server_item_time").
Attr("serverId", serverId). Attr("serverId", serverId).
Attr("itemId", itemId). Attr("itemId", itemId).
Attr("version", version). Attr("version", version).

View File

@@ -3,16 +3,20 @@ package models
import ( import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/iwind/TeaGo/bootstrap" _ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
"testing" "testing"
) )
func TestNewMetricStatDAO_InsertMany(t *testing.T) { func TestNewMetricStatDAO_InsertMany(t *testing.T) {
for i := 0; i <= 1; i++ { for i := 0; i <= 10_000_000; i++ {
err := NewMetricStatDAO().CreateStat(nil, types.String(i) + "_v1", 18, 48, 23, 25, []string{"/html" + types.String(i)}, 1, "20210728", 0) err := NewMetricStatDAO().CreateStat(nil, types.String(i)+"_v1", 18, int64(rands.Int(0, 10000)), int64(rands.Int(0, 10000)), int64(rands.Int(0, 100)), []string{"/html" + types.String(i)}, 1, "20210830", 0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if i % 10000 == 0 {
t.Log(i)
}
} }
t.Log("done") t.Log("done")
} }

View File

@@ -4,10 +4,8 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns" "github.com/TeaOSLab/EdgeAPI/internal/db/models/dns"
"github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
@@ -195,7 +193,7 @@ func (this *NodeClusterDAO) CountAllEnabledClusters(tx *dbs.Tx, keyword string)
query := this.Query(tx). query := this.Query(tx).
State(NodeClusterStateEnabled) State(NodeClusterStateEnabled)
if len(keyword) > 0 { if len(keyword) > 0 {
query.Where("(name LIKE :keyword OR dnsName like :keyword)"). query.Where("(name LIKE :keyword OR dnsName like :keyword OR (dnsDomainId > 0 AND dnsDomainId IN (SELECT id FROM "+dns.SharedDNSDomainDAO.Table+" WHERE name LIKE :keyword AND state=1)))").
Param("keyword", "%"+keyword+"%") Param("keyword", "%"+keyword+"%")
} }
return query.Count() return query.Count()
@@ -206,7 +204,7 @@ func (this *NodeClusterDAO) ListEnabledClusters(tx *dbs.Tx, keyword string, offs
query := this.Query(tx). query := this.Query(tx).
State(NodeClusterStateEnabled) State(NodeClusterStateEnabled)
if len(keyword) > 0 { if len(keyword) > 0 {
query.Where("(name LIKE :keyword OR dnsName like :keyword)"). query.Where("(name LIKE :keyword OR dnsName like :keyword OR (dnsDomainId > 0 AND dnsDomainId IN (SELECT id FROM "+dns.SharedDNSDomainDAO.Table+" WHERE name LIKE :keyword AND state=1)))").
Param("keyword", "%"+keyword+"%") Param("keyword", "%"+keyword+"%")
} }
_, err = query. _, err = query.
@@ -303,11 +301,8 @@ func (this *NodeClusterDAO) UpdateClusterHealthCheck(tx *dbs.Tx, clusterId int64
op := NewNodeClusterOperator() op := NewNodeClusterOperator()
op.Id = clusterId op.Id = clusterId
op.HealthCheck = healthCheckJSON op.HealthCheck = healthCheckJSON
err := this.Save(tx, op) // 不需要通知更新
if err != nil { return this.Save(tx, op)
return err
}
return this.NotifyUpdate(tx, clusterId)
} }
// CountAllEnabledClustersWithGrantId 计算使用某个认证的集群数量 // CountAllEnabledClustersWithGrantId 计算使用某个认证的集群数量
@@ -406,7 +401,16 @@ func (this *NodeClusterDAO) FindClusterGrantId(tx *dbs.Tx, clusterId int64) (int
} }
// FindClusterDNSInfo 查找DNS信息 // FindClusterDNSInfo 查找DNS信息
func (this *NodeClusterDAO) FindClusterDNSInfo(tx *dbs.Tx, clusterId int64) (*NodeCluster, error) { func (this *NodeClusterDAO) FindClusterDNSInfo(tx *dbs.Tx, clusterId int64, cacheMap maps.Map) (*NodeCluster, error) {
if cacheMap == nil {
cacheMap = maps.Map{}
}
var cacheKey = this.Table + ":record:" + types.String(clusterId)
var cache = cacheMap.Get(cacheKey)
if cache != nil {
return cache.(*NodeCluster), nil
}
one, err := this.Query(tx). one, err := this.Query(tx).
Pk(clusterId). Pk(clusterId).
Result("id", "name", "dnsName", "dnsDomainId", "dns", "isOn"). Result("id", "name", "dnsName", "dnsDomainId", "dns", "isOn").
@@ -417,6 +421,7 @@ func (this *NodeClusterDAO) FindClusterDNSInfo(tx *dbs.Tx, clusterId int64) (*No
if one == nil { if one == nil {
return nil, nil return nil, nil
} }
cacheMap[cacheKey] = one
return one.(*NodeCluster), nil return one.(*NodeCluster), nil
} }
@@ -461,118 +466,6 @@ func (this *NodeClusterDAO) UpdateClusterDNS(tx *dbs.Tx, clusterId int64, dnsNam
return this.NotifyDNSUpdate(tx, clusterId) return this.NotifyDNSUpdate(tx, clusterId)
} }
// CheckClusterDNS 检查集群的DNS问题
func (this *NodeClusterDAO) CheckClusterDNS(tx *dbs.Tx, cluster *NodeCluster) (issues []*pb.DNSIssue, err error) {
clusterId := int64(cluster.Id)
domainId := int64(cluster.DnsDomainId)
// 检查域名
domain, err := dns.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, domainId)
if err != nil {
return nil, err
}
if domain == nil {
issues = append(issues, &pb.DNSIssue{
Target: cluster.Name,
TargetId: clusterId,
Type: "cluster",
Description: "域名选择错误,需要重新选择",
Params: nil,
})
return
}
// 检查二级域名
if len(cluster.DnsName) == 0 {
issues = append(issues, &pb.DNSIssue{
Target: cluster.Name,
TargetId: clusterId,
Type: "cluster",
Description: "没有设置二级域名",
Params: nil,
})
return
}
// TODO 检查域名格式
// TODO 检查域名是否已解析
// 检查节点
nodes, err := SharedNodeDAO.FindAllEnabledNodesDNSWithClusterId(tx, clusterId, true)
if err != nil {
return nil, err
}
// TODO 检查节点数量不能为0
for _, node := range nodes {
nodeId := int64(node.Id)
routeCodes, err := node.DNSRouteCodesForDomainId(domainId)
if err != nil {
return nil, err
}
if len(routeCodes) == 0 {
issues = append(issues, &pb.DNSIssue{
Target: node.Name,
TargetId: nodeId,
Type: "node",
Description: "没有选择节点所属线路",
Params: map[string]string{
"clusterName": cluster.Name,
"clusterId": numberutils.FormatInt64(clusterId),
},
})
continue
}
// 检查线路是否在已有线路中
for _, routeCode := range routeCodes {
routeOk, err := domain.ContainsRouteCode(routeCode)
if err != nil {
return nil, err
}
if !routeOk {
issues = append(issues, &pb.DNSIssue{
Target: node.Name,
TargetId: nodeId,
Type: "node",
Description: "线路已经失效,请重新选择",
Params: map[string]string{
"clusterName": cluster.Name,
"clusterId": numberutils.FormatInt64(clusterId),
},
})
continue
}
}
// 检查IP地址
ipAddr, err := SharedNodeIPAddressDAO.FindFirstNodeAccessIPAddress(tx, nodeId, nodeconfigs.NodeRoleNode)
if err != nil {
return nil, err
}
if len(ipAddr) == 0 {
issues = append(issues, &pb.DNSIssue{
Target: node.Name,
TargetId: nodeId,
Type: "node",
Description: "没有设置IP地址",
Params: map[string]string{
"clusterName": cluster.Name,
"clusterId": numberutils.FormatInt64(clusterId),
},
})
continue
}
// TODO 检查是否有解析记录
}
return
}
// FindClusterAdminId 查找集群所属管理员 // FindClusterAdminId 查找集群所属管理员
func (this *NodeClusterDAO) FindClusterAdminId(tx *dbs.Tx, clusterId int64) (int64, error) { func (this *NodeClusterDAO) FindClusterAdminId(tx *dbs.Tx, clusterId int64) (int64, error) {
return this.Query(tx). return this.Query(tx).
@@ -682,11 +575,27 @@ func (this *NodeClusterDAO) FindAllEnabledNodeClusterIdsWithCachePolicyId(tx *db
} }
// FindClusterHTTPFirewallPolicyId 获取集群的WAF策略ID // FindClusterHTTPFirewallPolicyId 获取集群的WAF策略ID
func (this *NodeClusterDAO) FindClusterHTTPFirewallPolicyId(tx *dbs.Tx, clusterId int64) (int64, error) { func (this *NodeClusterDAO) FindClusterHTTPFirewallPolicyId(tx *dbs.Tx, clusterId int64, cacheMap maps.Map) (int64, error) {
return this.Query(tx). if cacheMap == nil {
cacheMap = maps.Map{}
}
var cacheKey = this.Table + ":FindClusterHTTPFirewallPolicyId:" + types.String(clusterId)
var cache = cacheMap.Get(cacheKey)
if cache != nil {
return cache.(int64), nil
}
firewallPolicyId, err := this.Query(tx).
Pk(clusterId). Pk(clusterId).
Result("httpFirewallPolicyId"). Result("httpFirewallPolicyId").
FindInt64Col(0) FindInt64Col(0)
if err != nil {
return 0, err
}
cacheMap[cacheKey] = firewallPolicyId
return firewallPolicyId, nil
} }
// UpdateNodeClusterHTTPCachePolicyId 设置集群的缓存策略 // UpdateNodeClusterHTTPCachePolicyId 设置集群的缓存策略
@@ -702,11 +611,27 @@ func (this *NodeClusterDAO) UpdateNodeClusterHTTPCachePolicyId(tx *dbs.Tx, clust
} }
// FindClusterHTTPCachePolicyId 获取集群的缓存策略ID // FindClusterHTTPCachePolicyId 获取集群的缓存策略ID
func (this *NodeClusterDAO) FindClusterHTTPCachePolicyId(tx *dbs.Tx, clusterId int64) (int64, error) { func (this *NodeClusterDAO) FindClusterHTTPCachePolicyId(tx *dbs.Tx, clusterId int64, cacheMap maps.Map) (int64, error) {
return this.Query(tx). if cacheMap == nil {
cacheMap = maps.Map{}
}
var cacheKey = this.Table + ":FindClusterHTTPCachePolicyId:" + types.String(clusterId)
var cache = cacheMap.Get(cacheKey)
if cache != nil {
return cache.(int64), nil
}
cachePolicyId, err := this.Query(tx).
Pk(clusterId). Pk(clusterId).
Result("cachePolicyId"). Result("cachePolicyId").
FindInt64Col(0) FindInt64Col(0)
if err != nil {
return 0, err
}
cacheMap[cacheKey] = cachePolicyId
return cachePolicyId, nil
} }
// UpdateNodeClusterHTTPFirewallPolicyId 设置集群的WAF策略 // UpdateNodeClusterHTTPFirewallPolicyId 设置集群的WAF策略

View File

@@ -2,6 +2,7 @@ package models
import ( import (
"encoding/json" "encoding/json"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns" "github.com/TeaOSLab/EdgeAPI/internal/db/models/dns"
"github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeAPI/internal/utils"
@@ -107,6 +108,19 @@ func (this *NodeDAO) FindEnabledNode(tx *dbs.Tx, id int64) (*Node, error) {
return result.(*Node), err return result.(*Node), err
} }
// FindEnabledBasicNode 获取节点的基本信息
func (this *NodeDAO) FindEnabledBasicNode(tx *dbs.Tx, nodeId int64) (*Node, error) {
one, err := this.Query(tx).
State(NodeStateEnabled).
Pk(nodeId).
Result("id", "name", "clusterId", "isOn", "isUp").
Find()
if one == nil {
return nil, err
}
return one.(*Node), nil
}
// FindNodeName 根据主键查找名称 // FindNodeName 根据主键查找名称
func (this *NodeDAO) FindNodeName(tx *dbs.Tx, id int64) (string, error) { func (this *NodeDAO) FindNodeName(tx *dbs.Tx, id int64) (string, error) {
name, err := this.Query(tx). name, err := this.Query(tx).
@@ -118,6 +132,19 @@ func (this *NodeDAO) FindNodeName(tx *dbs.Tx, id int64) (string, error) {
// CreateNode 创建节点 // CreateNode 创建节点
func (this *NodeDAO) CreateNode(tx *dbs.Tx, adminId int64, name string, clusterId int64, groupId int64, regionId int64) (nodeId int64, err error) { func (this *NodeDAO) CreateNode(tx *dbs.Tx, adminId int64, name string, clusterId int64, groupId int64, regionId int64) (nodeId int64, err error) {
// 检查节点数量
if teaconst.MaxNodes > 0 {
count, err := this.Query(tx).
State(NodeStateEnabled).
Count()
if err != nil {
return 0, err
}
if int64(teaconst.MaxNodes) <= count {
return 0, errors.New("[企业版]超出最大节点数限制:" + types.String(teaconst.MaxNodes) + ",请购买更多配额")
}
}
uniqueId, err := this.GenUniqueId(tx) uniqueId, err := this.GenUniqueId(tx)
if err != nil { if err != nil {
return 0, err return 0, err
@@ -167,6 +194,13 @@ func (this *NodeDAO) UpdateNode(tx *dbs.Tx, nodeId int64, name string, clusterId
if nodeId <= 0 { if nodeId <= 0 {
return errors.New("invalid nodeId") return errors.New("invalid nodeId")
} }
// 老的集群
oldClusterIds, err := this.FindEnabledNodeClusterIds(tx, nodeId)
if err != nil {
return err
}
op := NewNodeOperator() op := NewNodeOperator()
op.Id = nodeId op.Id = nodeId
op.Name = name op.Name = name
@@ -210,6 +244,16 @@ func (this *NodeDAO) UpdateNode(tx *dbs.Tx, nodeId int64, name string, clusterId
return err return err
} }
// 通知老的集群更新
for _, oldClusterId := range oldClusterIds {
if oldClusterId != clusterId && !lists.ContainsInt64(secondaryClusterIds, oldClusterId) {
err = dns.SharedDNSTaskDAO.CreateClusterTask(tx, oldClusterId, dns.DNSTaskTypeClusterChange)
if err != nil {
return err
}
}
}
return this.NotifyDNSUpdate(tx, nodeId) return this.NotifyDNSUpdate(tx, nodeId)
} }
@@ -385,6 +429,30 @@ func (this *NodeDAO) FindEnabledAndOnNodeClusterIds(tx *dbs.Tx, nodeId int64) (r
return return
} }
// FindEnabledNodeClusterIds 获取节点所属所有可用的集群ID
func (this *NodeDAO) FindEnabledNodeClusterIds(tx *dbs.Tx, nodeId int64) (result []int64, err error) {
one, err := this.Query(tx).
Pk(nodeId).
Result("clusterId", "secondaryClusterIds").
Find()
if one == nil {
return nil, err
}
var clusterId = int64(one.(*Node).ClusterId)
if clusterId > 0 {
result = append(result, clusterId)
}
for _, clusterId := range one.(*Node).DecodeSecondaryClusterIds() {
if lists.ContainsInt64(result, clusterId) {
continue
}
result = append(result, clusterId)
}
return
}
// FindAllNodeIdsMatch 匹配节点并返回节点ID // FindAllNodeIdsMatch 匹配节点并返回节点ID
func (this *NodeDAO) FindAllNodeIdsMatch(tx *dbs.Tx, clusterId int64, includeSecondaryNodes bool, isOn configutils.BoolState) (result []int64, err error) { func (this *NodeDAO) FindAllNodeIdsMatch(tx *dbs.Tx, clusterId int64, includeSecondaryNodes bool, isOn configutils.BoolState) (result []int64, err error) {
query := this.Query(tx) query := this.Query(tx)
@@ -614,7 +682,7 @@ func (this *NodeDAO) UpdateNodeInstallStatus(tx *dbs.Tx, nodeId int64, status *N
// ComposeNodeConfig 组合配置 // ComposeNodeConfig 组合配置
// TODO 提升运行速度 // TODO 提升运行速度
func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64) (*nodeconfigs.NodeConfig, error) { func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64, cacheMap maps.Map) (*nodeconfigs.NodeConfig, error) {
node, err := this.FindEnabledNode(tx, nodeId) node, err := this.FindEnabledNode(tx, nodeId)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -642,11 +710,7 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64) (*nodeconfigs.N
} }
for _, server := range servers { for _, server := range servers {
if len(server.Config) == 0 { serverConfig, err := SharedServerDAO.ComposeServerConfig(tx, server, cacheMap)
continue
}
serverConfig, err := SharedServerDAO.ComposeServerConfig(tx, server)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -675,12 +739,12 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64) (*nodeconfigs.N
var clusterIds = []int64{primaryClusterId} var clusterIds = []int64{primaryClusterId}
clusterIds = append(clusterIds, node.DecodeSecondaryClusterIds()...) clusterIds = append(clusterIds, node.DecodeSecondaryClusterIds()...)
for _, clusterId := range clusterIds { for _, clusterId := range clusterIds {
httpFirewallPolicyId, err := SharedNodeClusterDAO.FindClusterHTTPFirewallPolicyId(tx, clusterId) httpFirewallPolicyId, err := SharedNodeClusterDAO.FindClusterHTTPFirewallPolicyId(tx, clusterId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if httpFirewallPolicyId > 0 { if httpFirewallPolicyId > 0 {
firewallPolicy, err := SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, httpFirewallPolicyId) firewallPolicy, err := SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, httpFirewallPolicyId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -690,12 +754,12 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64) (*nodeconfigs.N
} }
// 缓存策略 // 缓存策略
httpCachePolicyId, err := SharedNodeClusterDAO.FindClusterHTTPCachePolicyId(tx, clusterId) httpCachePolicyId, err := SharedNodeClusterDAO.FindClusterHTTPCachePolicyId(tx, clusterId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if httpCachePolicyId > 0 { if httpCachePolicyId > 0 {
cachePolicy, err := SharedHTTPCachePolicyDAO.ComposeCachePolicy(tx, httpCachePolicyId) cachePolicy, err := SharedHTTPCachePolicyDAO.ComposeCachePolicy(tx, httpCachePolicyId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1102,9 +1166,11 @@ func (this *NodeDAO) UpdateNodeUpCount(tx *dbs.Tx, nodeId int64, isUp bool, maxU
return false, err return false, err
} }
err = this.NotifyDNSUpdate(tx, nodeId) if changed {
if err != nil { err = this.NotifyDNSUpdate(tx, nodeId)
return false, err if err != nil {
return true, err
}
} }
return return
@@ -1115,15 +1181,19 @@ func (this *NodeDAO) UpdateNodeUp(tx *dbs.Tx, nodeId int64, isUp bool) error {
if nodeId <= 0 { if nodeId <= 0 {
return errors.New("invalid nodeId") return errors.New("invalid nodeId")
} }
op := NewNodeOperator() op := NewNodeOperator()
op.Id = nodeId op.Id = nodeId
op.IsUp = isUp op.IsUp = isUp
op.CountDown = 0 op.CountUp = 0
op.CountDown = 0 op.CountDown = 0
err := this.Save(tx, op) err := this.Save(tx, op)
if err != nil { if err != nil {
return err return err
} }
// TODO 只有前后状态不一致的时候才需要更新DNS
return this.NotifyDNSUpdate(tx, nodeId) return this.NotifyDNSUpdate(tx, nodeId)
} }
@@ -1302,7 +1372,7 @@ func (this *NodeDAO) NotifyDNSUpdate(tx *dbs.Tx, nodeId int64) error {
return err return err
} }
for _, clusterId := range clusterIds { for _, clusterId := range clusterIds {
dnsInfo, err := SharedNodeClusterDAO.FindClusterDNSInfo(tx, clusterId) dnsInfo, err := SharedNodeClusterDAO.FindClusterDNSInfo(tx, clusterId, nil)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -3,7 +3,9 @@ package models
import ( import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
"testing" "testing"
"time"
) )
func TestNodeDAO_FindAllNodeIdsMatch(t *testing.T) { func TestNodeDAO_FindAllNodeIdsMatch(t *testing.T) {
@@ -34,3 +36,23 @@ func TestNodeDAO_FindEnabledNodeClusterIds(t *testing.T) {
} }
t.Log(clusterIds) t.Log(clusterIds)
} }
func TestNodeDAO_ComposeNodeConfig(t *testing.T) {
dbs.NotifyReady()
before := time.Now()
defer func() {
t.Log(time.Since(before).Seconds()*1000, "ms")
}()
var tx *dbs.Tx
var cacheMap = maps.Map{}
nodeConfig, err := SharedNodeDAO.ComposeNodeConfig(tx, 48, cacheMap)
if err != nil {
t.Fatal(err)
}
t.Log(len(nodeConfig.Servers), "servers")
t.Log(len(cacheMap), "items")
// old: 77ms => new: 56ms
}

View File

@@ -2,6 +2,8 @@ package models
import ( import (
"errors" "errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
@@ -36,21 +38,27 @@ func init() {
} }
// EnableAddress 启用条目 // EnableAddress 启用条目
func (this *NodeIPAddressDAO) EnableAddress(tx *dbs.Tx, id int64) (err error) { func (this *NodeIPAddressDAO) EnableAddress(tx *dbs.Tx, addressId int64) (err error) {
_, err = this.Query(tx). _, err = this.Query(tx).
Pk(id). Pk(addressId).
Set("state", NodeIPAddressStateEnabled). Set("state", NodeIPAddressStateEnabled).
Update() Update()
return err if err != nil {
return err
}
return this.NotifyUpdate(tx, addressId)
} }
// DisableAddress 禁用IP地址 // DisableAddress 禁用IP地址
func (this *NodeIPAddressDAO) DisableAddress(tx *dbs.Tx, id int64) (err error) { func (this *NodeIPAddressDAO) DisableAddress(tx *dbs.Tx, addressId int64) (err error) {
_, err = this.Query(tx). _, err = this.Query(tx).
Pk(id). Pk(addressId).
Set("state", NodeIPAddressStateDisabled). Set("state", NodeIPAddressStateDisabled).
Update() Update()
return err if err != nil {
return err
}
return this.NotifyUpdate(tx, addressId)
} }
// DisableAllAddressesWithNodeId 禁用节点的所有的IP地址 // DisableAllAddressesWithNodeId 禁用节点的所有的IP地址
@@ -66,7 +74,11 @@ func (this *NodeIPAddressDAO) DisableAllAddressesWithNodeId(tx *dbs.Tx, nodeId i
Attr("role", role). Attr("role", role).
Set("state", NodeIPAddressStateDisabled). Set("state", NodeIPAddressStateDisabled).
Update() Update()
return err if err != nil {
return err
}
return SharedNodeDAO.NotifyDNSUpdate(tx, nodeId)
} }
// FindEnabledAddress 查找启用中的IP地址 // FindEnabledAddress 查找启用中的IP地址
@@ -90,7 +102,7 @@ func (this *NodeIPAddressDAO) FindAddressName(tx *dbs.Tx, id int64) (string, err
} }
// CreateAddress 创建IP地址 // CreateAddress 创建IP地址
func (this *NodeIPAddressDAO) CreateAddress(tx *dbs.Tx, nodeId int64, role nodeconfigs.NodeRole, name string, ip string, canAccess bool) (addressId int64, err error) { func (this *NodeIPAddressDAO) CreateAddress(tx *dbs.Tx, adminId int64, nodeId int64, role nodeconfigs.NodeRole, name string, ip string, canAccess bool, thresholdsJSON []byte) (addressId int64, err error) {
if len(role) == 0 { if len(role) == 0 {
role = nodeconfigs.NodeRoleNode role = nodeconfigs.NodeRoleNode
} }
@@ -101,8 +113,15 @@ func (this *NodeIPAddressDAO) CreateAddress(tx *dbs.Tx, nodeId int64, role nodec
op.Name = name op.Name = name
op.Ip = ip op.Ip = ip
op.CanAccess = canAccess op.CanAccess = canAccess
if len(thresholdsJSON) > 0 {
op.Thresholds = thresholdsJSON
} else {
op.Thresholds = "[]"
}
op.State = NodeIPAddressStateEnabled op.State = NodeIPAddressStateEnabled
err = this.Save(tx, op) addressId, err = this.SaveInt64(tx, op)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -112,11 +131,17 @@ func (this *NodeIPAddressDAO) CreateAddress(tx *dbs.Tx, nodeId int64, role nodec
return 0, err return 0, err
} }
return types.Int64(op.Id), nil // 创建日志
err = SharedNodeIPAddressLogDAO.CreateLog(tx, adminId, addressId, "创建IP")
if err != nil {
return 0, err
}
return addressId, nil
} }
// UpdateAddress 修改IP地址 // UpdateAddress 修改IP地址
func (this *NodeIPAddressDAO) UpdateAddress(tx *dbs.Tx, addressId int64, name string, ip string, canAccess bool) (err error) { func (this *NodeIPAddressDAO) UpdateAddress(tx *dbs.Tx, adminId int64, addressId int64, name string, ip string, canAccess bool, isOn bool, thresholdsJSON []byte) (err error) {
if addressId <= 0 { if addressId <= 0 {
return errors.New("invalid addressId") return errors.New("invalid addressId")
} }
@@ -126,9 +151,27 @@ func (this *NodeIPAddressDAO) UpdateAddress(tx *dbs.Tx, addressId int64, name st
op.Name = name op.Name = name
op.Ip = ip op.Ip = ip
op.CanAccess = canAccess op.CanAccess = canAccess
op.IsOn = isOn
if len(thresholdsJSON) > 0 {
op.Thresholds = thresholdsJSON
} else {
op.Thresholds = "[]"
}
op.State = NodeIPAddressStateEnabled // 恢复状态 op.State = NodeIPAddressStateEnabled // 恢复状态
err = this.Save(tx, op) err = this.Save(tx, op)
return err if err != nil {
return err
}
// 创建日志
err = SharedNodeIPAddressLogDAO.CreateLog(tx, adminId, addressId, "修改IP")
if err != nil {
return err
}
return this.NotifyUpdate(tx, addressId)
} }
// UpdateAddressIP 修改IP地址中的IP // UpdateAddressIP 修改IP地址中的IP
@@ -140,7 +183,11 @@ func (this *NodeIPAddressDAO) UpdateAddressIP(tx *dbs.Tx, addressId int64, ip st
op.Id = addressId op.Id = addressId
op.Ip = ip op.Ip = ip
err := this.Save(tx, op) err := this.Save(tx, op)
return err if err != nil {
return err
}
return this.NotifyUpdate(tx, addressId)
} }
// UpdateAddressNodeId 修改IP地址所属节点 // UpdateAddressNodeId 修改IP地址所属节点
@@ -209,8 +256,8 @@ func (this *NodeIPAddressDAO) FindFirstNodeAccessIPAddressId(tx *dbs.Tx, nodeId
FindInt64Col(0) FindInt64Col(0)
} }
// FindNodeAccessIPAddresses 查找节点所有的可访问的IP地址 // FindNodeAccessAndUpIPAddresses 查找节点所有的可访问的IP地址
func (this *NodeIPAddressDAO) FindNodeAccessIPAddresses(tx *dbs.Tx, nodeId int64, role nodeconfigs.NodeRole) (result []*NodeIPAddress, err error) { func (this *NodeIPAddressDAO) FindNodeAccessAndUpIPAddresses(tx *dbs.Tx, nodeId int64, role nodeconfigs.NodeRole) (result []*NodeIPAddress, err error) {
if len(role) == 0 { if len(role) == 0 {
role = nodeconfigs.NodeRoleNode role = nodeconfigs.NodeRoleNode
} }
@@ -219,9 +266,122 @@ func (this *NodeIPAddressDAO) FindNodeAccessIPAddresses(tx *dbs.Tx, nodeId int64
Attr("nodeId", nodeId). Attr("nodeId", nodeId).
State(NodeIPAddressStateEnabled). State(NodeIPAddressStateEnabled).
Attr("canAccess", true). Attr("canAccess", true).
Attr("isOn", true).
Attr("isUp", true).
Desc("order"). Desc("order").
AscPk(). AscPk().
Slice(&result). Slice(&result).
FindAll() FindAll()
return return
} }
// CountAllEnabledIPAddresses 计算IP地址数量
// TODO 目前支持边缘节点将来支持NS节点
func (this *NodeIPAddressDAO) CountAllEnabledIPAddresses(tx *dbs.Tx, role string, nodeClusterId int64, upState configutils.BoolState, keyword string) (int64, error) {
var query = this.Query(tx).
State(NodeIPAddressStateEnabled).
Attr("role", role)
// 集群
if nodeClusterId > 0 {
query.Where("nodeId IN (SELECT id FROM "+SharedNodeDAO.Table+" WHERE (clusterId=:clusterId OR JSON_CONTAINS(secondaryClusterIds, :clusterIdString)) AND state=1)").
Param("clusterId", nodeClusterId).
Param("clusterIdString", types.String(nodeClusterId))
} else {
query.Where("nodeId IN (SELECT id FROM " + SharedNodeDAO.Table + " WHERE state=1 AND clusterId IN (SELECT id FROM " + SharedNodeClusterDAO.Table + " WHERE state=1))")
}
// 在线状态
switch upState {
case configutils.BoolStateYes:
query.Attr("isUp", 1)
case configutils.BoolStateNo:
query.Attr("isUp", 0)
}
// 关键词
if len(keyword) > 0 {
query.Where("(ip LIKE :keyword OR name LIKE :keyword OR description LIKE :keyword OR nodeId IN (SELECT id FROM "+SharedNodeDAO.Table+" WHERE state=1 AND name LIKE :keyword))").
Param("keyword", "%"+keyword+"%")
}
return query.Count()
}
// ListEnabledIPAddresses 列出单页的IP地址
func (this *NodeIPAddressDAO) ListEnabledIPAddresses(tx *dbs.Tx, role string, nodeClusterId int64, upState configutils.BoolState, keyword string, offset int64, size int64) (result []*NodeIPAddress, err error) {
var query = this.Query(tx).
State(NodeIPAddressStateEnabled).
Attr("role", role)
// 集群
if nodeClusterId > 0 {
query.Where("nodeId IN (SELECT id FROM "+SharedNodeDAO.Table+" WHERE (clusterId=:clusterId OR JSON_CONTAINS(secondaryClusterIds, :clusterIdString)) AND state=1)").
Param("clusterId", nodeClusterId).
Param("clusterIdString", types.String(nodeClusterId))
} else {
query.Where("nodeId IN (SELECT id FROM " + SharedNodeDAO.Table + " WHERE state=1 AND clusterId IN (SELECT id FROM " + SharedNodeClusterDAO.Table + " WHERE state=1))")
}
// 在线状态
switch upState {
case configutils.BoolStateYes:
query.Attr("isUp", 1)
case configutils.BoolStateNo:
query.Attr("isUp", 0)
}
// 关键词
if len(keyword) > 0 {
query.Where("(ip LIKE :keyword OR name LIKE :keyword OR description LIKE :keyword OR nodeId IN (SELECT id FROM "+SharedNodeDAO.Table+" WHERE state=1 AND name LIKE :keyword))").
Param("keyword", "%"+keyword+"%")
}
_, err = query.Offset(offset).
Limit(size).
Asc("isUp").
Desc("nodeId").
Slice(&result).
FindAll()
return
}
// FindAllEnabledAndOnIPAddressesWithClusterId 列出所有的正在启用的IP地址
func (this *NodeIPAddressDAO) FindAllEnabledAndOnIPAddressesWithClusterId(tx *dbs.Tx, role string, clusterId int64) (result []*NodeIPAddress, err error) {
_, err = this.Query(tx).
State(NodeIPAddressStateEnabled).
Attr("role", role).
Attr("isOn", true).
Where("nodeId IN (SELECT id FROM "+SharedNodeDAO.Table+" WHERE state=1 AND clusterId=:clusterId)").
Param("clusterId", clusterId).
Slice(&result).
FindAll()
return
}
// NotifyUpdate 通知更新
func (this *NodeIPAddressDAO) NotifyUpdate(tx *dbs.Tx, addressId int64) error {
address, err := this.Query(tx).
Pk(addressId).
Result("nodeId", "role").
Find()
if err != nil {
return err
}
if address == nil {
return nil
}
var nodeId = int64(address.(*NodeIPAddress).NodeId)
if nodeId == 0 {
return nil
}
var role = address.(*NodeIPAddress).Role
switch role {
case nodeconfigs.NodeRoleNode:
err = dns.SharedDNSTaskDAO.CreateNodeTask(tx, nodeId, dns.DNSTaskTypeNodeChange)
}
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,15 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build community
// +build community
package models
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/iwind/TeaGo/dbs"
)
// FireThresholds 触发阈值
func (this *NodeIPAddressDAO) FireThresholds(tx *dbs.Tx, role nodeconfigs.NodeRole, nodeId int64) error {
return nil
}

View File

@@ -0,0 +1,73 @@
package models
import (
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
timeutil "github.com/iwind/TeaGo/utils/time"
)
type NodeIPAddressLogDAO dbs.DAO
func NewNodeIPAddressLogDAO() *NodeIPAddressLogDAO {
return dbs.NewDAO(&NodeIPAddressLogDAO{
DAOObject: dbs.DAOObject{
DB: Tea.Env,
Table: "edgeNodeIPAddressLogs",
Model: new(NodeIPAddressLog),
PkName: "id",
},
}).(*NodeIPAddressLogDAO)
}
var SharedNodeIPAddressLogDAO *NodeIPAddressLogDAO
func init() {
dbs.OnReady(func() {
SharedNodeIPAddressLogDAO = NewNodeIPAddressLogDAO()
})
}
// CreateLog 创建日志
func (this *NodeIPAddressLogDAO) CreateLog(tx *dbs.Tx, adminId int64, addrId int64, description string) error {
addr, err := SharedNodeIPAddressDAO.FindEnabledAddress(tx, addrId)
if err != nil {
return err
}
if addr == nil {
return nil
}
var op = NewNodeIPAddressLogOperator()
op.AdminId = adminId
op.AddressId = addrId
op.Description = description
op.CanAccess = addr.CanAccess
op.IsOn = addr.IsOn
op.IsUp = addr.IsUp
op.Day = timeutil.Format("Ymd")
return this.Save(tx, op)
}
// CountLogs 计算日志数量
func (this *NodeIPAddressLogDAO) CountLogs(tx *dbs.Tx, addrId int64) (int64, error) {
var query = this.Query(tx)
if addrId > 0 {
query.Attr("addressId", addrId)
}
return query.Count()
}
// ListLogs 列出单页日志
func (this *NodeIPAddressLogDAO) ListLogs(tx *dbs.Tx, addrId int64, offset int64, size int64) (result []*NodeIPAddressLog, err error) {
var query = this.Query(tx)
if addrId > 0 {
query.Attr("addressId", addrId)
}
_, err = query.Offset(offset).
Limit(size).
DescPk().
Slice(&result).
FindAll()
return
}

View File

@@ -0,0 +1,6 @@
package models
import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/iwind/TeaGo/bootstrap"
)

View File

@@ -0,0 +1,30 @@
package models
// NodeIPAddressLog IP状态变更日志
type NodeIPAddressLog struct {
Id uint64 `field:"id"` // ID
AddressId uint64 `field:"addressId"` // 地址ID
AdminId uint32 `field:"adminId"` // 管理员ID
Description string `field:"description"` // 描述
CreatedAt uint64 `field:"createdAt"` // 操作时间
IsUp uint8 `field:"isUp"` // 是否在线
IsOn uint8 `field:"isOn"` // 是否启用
CanAccess uint8 `field:"canAccess"` // 是否可访问
Day string `field:"day"` // YYYYMMDD用来清理
}
type NodeIPAddressLogOperator struct {
Id interface{} // ID
AddressId interface{} // 地址ID
AdminId interface{} // 管理员ID
Description interface{} // 描述
CreatedAt interface{} // 操作时间
IsUp interface{} // 是否在线
IsOn interface{} // 是否启用
CanAccess interface{} // 是否可访问
Day interface{} // YYYYMMDD用来清理
}
func NewNodeIPAddressLogOperator() *NodeIPAddressLogOperator {
return &NodeIPAddressLogOperator{}
}

View File

@@ -0,0 +1 @@
package models

View File

@@ -11,6 +11,9 @@ type NodeIPAddress struct {
State uint8 `field:"state"` // 状态 State uint8 `field:"state"` // 状态
Order uint32 `field:"order"` // 排序 Order uint32 `field:"order"` // 排序
CanAccess uint8 `field:"canAccess"` // 是否可以访问 CanAccess uint8 `field:"canAccess"` // 是否可以访问
IsOn uint8 `field:"isOn"` // 是否启用
IsUp uint8 `field:"isUp"` // 是否上线
Thresholds string `field:"thresholds"` // 上线阈值
} }
type NodeIPAddressOperator struct { type NodeIPAddressOperator struct {
@@ -23,6 +26,9 @@ type NodeIPAddressOperator struct {
State interface{} // 状态 State interface{} // 状态
Order interface{} // 排序 Order interface{} // 排序
CanAccess interface{} // 是否可以访问 CanAccess interface{} // 是否可以访问
IsOn interface{} // 是否启用
IsUp interface{} // 是否上线
Thresholds interface{} // 上线阈值
} }
func NewNodeIPAddressOperator() *NodeIPAddressOperator { func NewNodeIPAddressOperator() *NodeIPAddressOperator {

View File

@@ -1 +1,20 @@
package models package models
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/iwind/TeaGo/logs"
)
func (this *NodeIPAddress) DecodeThresholds() []*nodeconfigs.NodeValueThresholdConfig {
var result = []*nodeconfigs.NodeValueThresholdConfig{}
if len(this.Thresholds) == 0 {
return result
}
err := json.Unmarshal([]byte(this.Thresholds), &result)
if err != nil {
// 不处理错误
logs.Error(err)
}
return result
}

View File

@@ -2,6 +2,7 @@ package models
import ( import (
"github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/remotelogs"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils" "github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
@@ -33,6 +34,9 @@ var SharedNodeLogDAO *NodeLogDAO
func init() { func init() {
dbs.OnReady(func() { dbs.OnReady(func() {
SharedNodeLogDAO = NewNodeLogDAO() SharedNodeLogDAO = NewNodeLogDAO()
// 设置日志存储
remotelogs.SetDAO(SharedNodeLogDAO)
}) })
} }

View File

@@ -169,6 +169,30 @@ func (this *NSClusterDAO) FindClusterGrantId(tx *dbs.Tx, clusterId int64) (int64
FindInt64Col(0) FindInt64Col(0)
} }
// UpdateRecursion 设置递归DNS
func (this *NSClusterDAO) UpdateRecursion(tx *dbs.Tx, clusterId int64, recursionJSON []byte) error {
err := this.Query(tx).
Pk(clusterId).
Set("recursion", recursionJSON).
UpdateQuickly()
if err != nil {
return err
}
return this.NotifyUpdate(tx, clusterId)
}
// FindClusterRecursion 读取递归DNS配置
func (this *NSClusterDAO) FindClusterRecursion(tx *dbs.Tx, clusterId int64) ([]byte, error) {
recursion, err := this.Query(tx).
Result("recursion").
Pk(clusterId).
FindStringCol("")
if err != nil {
return nil, err
}
return []byte(recursion), nil
}
// NotifyUpdate 通知更改 // NotifyUpdate 通知更改
func (this *NSClusterDAO) NotifyUpdate(tx *dbs.Tx, clusterId int64) error { func (this *NSClusterDAO) NotifyUpdate(tx *dbs.Tx, clusterId int64) error {
return SharedNodeTaskDAO.CreateClusterTask(tx, nodeconfigs.NodeRoleDNS, clusterId, NSNodeTaskTypeConfigChanged) return SharedNodeTaskDAO.CreateClusterTask(tx, nodeconfigs.NodeRoleDNS, clusterId, NSNodeTaskTypeConfigChanged)

View File

@@ -9,6 +9,7 @@ type NSCluster struct {
State uint8 `field:"state"` // 状态 State uint8 `field:"state"` // 状态
AccessLog string `field:"accessLog"` // 访问日志配置 AccessLog string `field:"accessLog"` // 访问日志配置
GrantId uint32 `field:"grantId"` // 授权ID GrantId uint32 `field:"grantId"` // 授权ID
Recursion string `field:"recursion"` // 递归DNS设置
} }
type NSClusterOperator struct { type NSClusterOperator struct {
@@ -19,6 +20,7 @@ type NSClusterOperator struct {
State interface{} // 状态 State interface{} // 状态
AccessLog interface{} // 访问日志配置 AccessLog interface{} // 访问日志配置
GrantId interface{} // 授权ID GrantId interface{} // 授权ID
Recursion interface{} // 递归DNS设置
} }
func NewNSClusterOperator() *NSClusterOperator { func NewNSClusterOperator() *NSClusterOperator {

View File

@@ -417,6 +417,20 @@ func (this *NSNodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64) (*dnsconfigs.
} }
} }
// 递归DNS配置
recursionJSON, err := SharedNSClusterDAO.FindClusterRecursion(tx, int64(node.ClusterId))
if err != nil {
return nil, err
}
if len(recursionJSON) > 0 {
var recursionConfig = &dnsconfigs.RecursionConfig{}
err = json.Unmarshal(recursionJSON, recursionConfig)
if err != nil {
return nil, err
}
config.RecursionConfig = recursionConfig
}
return config, nil return config, nil
} }

View File

@@ -9,6 +9,7 @@ import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
) )
@@ -197,7 +198,16 @@ func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx, originId int64, name string, add
} }
// ComposeOriginConfig 将源站信息转换为配置 // ComposeOriginConfig 将源站信息转换为配置
func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64) (*serverconfigs.OriginConfig, error) { func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap maps.Map) (*serverconfigs.OriginConfig, error) {
if cacheMap == nil {
cacheMap = maps.Map{}
}
var cacheKey = this.Table + ":config:" + types.String(originId)
var cache = cacheMap.Get(cacheKey)
if cache != nil {
return cache.(*serverconfigs.OriginConfig), nil
}
origin, err := this.FindEnabledOrigin(tx, originId) origin, err := this.FindEnabledOrigin(tx, originId)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -313,7 +323,7 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64) (*serverc
} }
config.CertRef = ref config.CertRef = ref
if ref.CertId > 0 { if ref.CertId > 0 {
certConfig, err := SharedSSLCertDAO.ComposeCertConfig(tx, ref.CertId) certConfig, err := SharedSSLCertDAO.ComposeCertConfig(tx, ref.CertId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -325,6 +335,8 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64) (*serverc
// TODO // TODO
} }
cacheMap[cacheKey] = config
return config, nil return config, nil
} }

View File

@@ -8,7 +8,7 @@ import (
func TestOriginServerDAO_ComposeOriginConfig(t *testing.T) { func TestOriginServerDAO_ComposeOriginConfig(t *testing.T) {
var tx *dbs.Tx var tx *dbs.Tx
config, err := SharedOriginDAO.ComposeOriginConfig(tx, 1) config, err := SharedOriginDAO.ComposeOriginConfig(tx, 1, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -0,0 +1,269 @@
package models
import (
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/reporterconfigs"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/rands"
)
const (
ReportNodeStateEnabled = 1 // 已启用
ReportNodeStateDisabled = 0 // 已禁用
)
type ReportNodeDAO dbs.DAO
func NewReportNodeDAO() *ReportNodeDAO {
return dbs.NewDAO(&ReportNodeDAO{
DAOObject: dbs.DAOObject{
DB: Tea.Env,
Table: "edgeReportNodes",
Model: new(ReportNode),
PkName: "id",
},
}).(*ReportNodeDAO)
}
var SharedReportNodeDAO *ReportNodeDAO
func init() {
dbs.OnReady(func() {
SharedReportNodeDAO = NewReportNodeDAO()
})
}
// EnableReportNode 启用条目
func (this *ReportNodeDAO) EnableReportNode(tx *dbs.Tx, id int64) error {
_, err := this.Query(tx).
Pk(id).
Set("state", ReportNodeStateEnabled).
Update()
return err
}
// DisableReportNode 禁用条目
func (this *ReportNodeDAO) DisableReportNode(tx *dbs.Tx, id int64) error {
_, err := this.Query(tx).
Pk(id).
Set("state", ReportNodeStateDisabled).
Update()
return err
}
// FindEnabledReportNode 查找启用中的条目
func (this *ReportNodeDAO) FindEnabledReportNode(tx *dbs.Tx, id int64) (*ReportNode, error) {
result, err := this.Query(tx).
Pk(id).
Attr("state", ReportNodeStateEnabled).
Find()
if result == nil {
return nil, err
}
return result.(*ReportNode), err
}
// FindReportNodeName 根据主键查找名称
func (this *ReportNodeDAO) FindReportNodeName(tx *dbs.Tx, id int64) (string, error) {
return this.Query(tx).
Pk(id).
Result("name").
FindStringCol("")
}
// CreateReportNode 创建终端
func (this *ReportNodeDAO) CreateReportNode(tx *dbs.Tx, name string, location string, isp string, allowIPs []string) (int64, error) {
uniqueId, err := this.GenUniqueId(tx)
if err != nil {
return 0, err
}
secret := rands.String(32)
// 保存API Token
err = SharedApiTokenDAO.CreateAPIToken(tx, uniqueId, secret, nodeconfigs.NodeRoleReport)
if err != nil {
return 0, err
}
op := NewReportNodeOperator()
op.UniqueId = uniqueId
op.Secret = secret
op.Name = name
op.Location = location
op.Isp = isp
if len(allowIPs) > 0 {
allowIPSJSON, err := json.Marshal(allowIPs)
if err != nil {
return 0, err
}
op.AllowIPs = allowIPSJSON
} else {
op.AllowIPs = "[]"
}
op.IsOn = true
op.State = ReportNodeStateEnabled
return this.SaveInt64(tx, op)
}
// UpdateReportNode 修改终端
func (this *ReportNodeDAO) UpdateReportNode(tx *dbs.Tx, nodeId int64, name string, location string, isp string, allowIPs []string, isOn bool) error {
if nodeId <= 0 {
return errors.New("invalid nodeId")
}
op := NewReportNodeOperator()
op.Id = nodeId
op.Name = name
op.Location = location
op.Isp = isp
if len(allowIPs) > 0 {
allowIPSJSON, err := json.Marshal(allowIPs)
if err != nil {
return err
}
op.AllowIPs = allowIPSJSON
} else {
op.AllowIPs = "[]"
}
op.IsOn = isOn
return this.Save(tx, op)
}
// CountAllEnabledReportNodes 计算终端数量
func (this *ReportNodeDAO) CountAllEnabledReportNodes(tx *dbs.Tx, keyword string) (int64, error) {
var query = this.Query(tx).
State(ReportNodeStateEnabled)
if len(keyword) > 0 {
query.Where("(name LIKE :keyword OR location LIKE :keyword OR isp LIKE :keyword OR allowIPs LIKE :keyword OR (status IS NOT NULL AND JSON_EXTRACT(status, 'ip') LIKE :keyword))")
query.Param("keyword", "%"+keyword+"%")
}
return query.Count()
}
// ListEnabledReportNodes 列出单页终端
func (this *ReportNodeDAO) ListEnabledReportNodes(tx *dbs.Tx, keyword string, offset int64, size int64) (result []*ReportNode, err error) {
var query = this.Query(tx).
State(ReportNodeStateEnabled)
if len(keyword) > 0 {
query.Where(`(
name LIKE :keyword
OR location LIKE :keyword
OR isp LIKE :keyword
OR allowIPs LIKE :keyword
OR (status IS NOT NULL
AND (
JSON_EXTRACT(status, '$.ip') LIKE :keyword)
OR (LENGTH(location)=0 AND JSON_EXTRACT(status, '$.location') LIKE :keyword)
OR (LENGTH(isp)=0 AND JSON_EXTRACT(status, '$.isp') LIKE :keyword)
))`)
query.Param("keyword", "%"+keyword+"%")
}
query.Slice(&result)
_, err = query.Asc("isActive").
Offset(offset).
Limit(size).
DescPk().
FindAll()
return
}
// GenUniqueId 生成唯一ID
func (this *ReportNodeDAO) GenUniqueId(tx *dbs.Tx) (string, error) {
for {
uniqueId := rands.HexString(32)
ok, err := this.Query(tx).
Attr("uniqueId", uniqueId).
Exist()
if err != nil {
return "", err
}
if ok {
continue
}
return uniqueId, nil
}
}
// UpdateNodeActive 修改节点活跃状态
func (this *ReportNodeDAO) UpdateNodeActive(tx *dbs.Tx, nodeId int64, isActive bool) error {
if nodeId <= 0 {
return errors.New("invalid nodeId")
}
_, err := this.Query(tx).
Pk(nodeId).
Set("isActive", isActive).
Update()
return err
}
// FindNodeActive 检查节点活跃状态
func (this *ReportNodeDAO) FindNodeActive(tx *dbs.Tx, nodeId int64) (bool, error) {
isActive, err := this.Query(tx).
Pk(nodeId).
Result("isActive").
FindIntCol(0)
if err != nil {
return false, err
}
return isActive == 1, nil
}
// FindEnabledNodeIdWithUniqueId 根据唯一ID获取节点ID
func (this *ReportNodeDAO) FindEnabledNodeIdWithUniqueId(tx *dbs.Tx, uniqueId string) (int64, error) {
return this.Query(tx).
Attr("uniqueId", uniqueId).
Attr("state", ReportNodeStateEnabled).
ResultPk().
FindInt64Col(0)
}
// UpdateNodeStatus 更改节点状态
func (this ReportNodeDAO) UpdateNodeStatus(tx *dbs.Tx, nodeId int64, statusJSON []byte) error {
if statusJSON == nil {
return nil
}
_, err := this.Query(tx).
Pk(nodeId).
Set("status", string(statusJSON)).
Update()
return err
}
// ComposeConfig 组合配置
func (this *ReportNodeDAO) ComposeConfig(tx *dbs.Tx, nodeId int64) (*reporterconfigs.NodeConfig, error) {
node, err := this.FindEnabledReportNode(tx, nodeId)
if err != nil {
return nil, err
}
if node == nil {
return nil, nil
}
var config = &reporterconfigs.NodeConfig{
Id: int64(node.Id),
}
return config, nil
}
// FindNodeAllowIPs 查询节点允许的IP
func (this *ReportNodeDAO) FindNodeAllowIPs(tx *dbs.Tx, nodeId int64) ([]string, error) {
node, err := this.Query(tx).
Pk(nodeId).
Result("allowIPs").
Find()
if err != nil {
return nil, err
}
if node == nil {
return nil, nil
}
return node.(*ReportNode).DecodeAllowIPs(), nil
}

View File

@@ -0,0 +1,6 @@
package models
import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/iwind/TeaGo/bootstrap"
)

View File

@@ -0,0 +1,36 @@
package models
// ReportNode 连通性报告终端
type ReportNode struct {
Id uint32 `field:"id"` // ID
UniqueId string `field:"uniqueId"` // 唯一ID
Secret string `field:"secret"` // 密钥
IsOn uint8 `field:"isOn"` // 是否启用
Name string `field:"name"` // 名称
Location string `field:"location"` // 所在区域
Isp string `field:"isp"` // 网络服务商
AllowIPs string `field:"allowIPs"` // 允许的IP
IsActive uint8 `field:"isActive"` // 是否活跃
Status string `field:"status"` // 状态
State uint8 `field:"state"` // 状态
CreatedAt uint64 `field:"createdAt"` // 创建时间
}
type ReportNodeOperator struct {
Id interface{} // ID
UniqueId interface{} // 唯一ID
Secret interface{} // 密钥
IsOn interface{} // 是否启用
Name interface{} // 名称
Location interface{} // 所在区域
Isp interface{} // 网络服务商
AllowIPs interface{} // 允许的IP
IsActive interface{} // 是否活跃
Status interface{} // 状态
State interface{} // 状态
CreatedAt interface{} // 创建时间
}
func NewReportNodeOperator() *ReportNodeOperator {
return &ReportNodeOperator{}
}

View File

@@ -0,0 +1,12 @@
package models
import "encoding/json"
func (this *ReportNode) DecodeAllowIPs() []string {
var result = []string{}
if len(this.AllowIPs) > 0 {
// 忽略错误
_ = json.Unmarshal([]byte(this.AllowIPs), &result)
}
return result
}

View File

@@ -0,0 +1,98 @@
package models
import (
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
"time"
)
type ReportResultDAO dbs.DAO
func NewReportResultDAO() *ReportResultDAO {
return dbs.NewDAO(&ReportResultDAO{
DAOObject: dbs.DAOObject{
DB: Tea.Env,
Table: "edgeReportResults",
Model: new(ReportResult),
PkName: "id",
},
}).(*ReportResultDAO)
}
var SharedReportResultDAO *ReportResultDAO
func init() {
dbs.OnReady(func() {
SharedReportResultDAO = NewReportResultDAO()
})
}
// UpdateResult 创建结果
func (this *ReportResultDAO) UpdateResult(tx *dbs.Tx, taskType string, targetId int64, targetDesc string, reportNodeId int64, isOk bool, costMs float64, errString string) error {
var countUp interface{} = 0
var countDown interface{} = 0
if isOk {
countUp = dbs.SQL("countUp+1")
} else {
countDown = dbs.SQL("countDown+1")
}
return this.Query(tx).
InsertOrUpdateQuickly(maps.Map{
"type": taskType,
"targetId": targetId,
"targetDesc": targetDesc,
"updatedAt": time.Now().Unix(),
"reportNodeId": reportNodeId,
"isOk": isOk,
"costMs": costMs,
"error": errString,
"countUp": countUp,
"countDown": countDown,
}, maps.Map{
"targetDesc": targetDesc,
"updatedAt": time.Now().Unix(),
"isOk": isOk,
"costMs": costMs,
"error": errString,
"countUp": countUp,
"countDown": countDown,
})
}
// CountAllResults 计算结果数量
func (this *ReportResultDAO) CountAllResults(tx *dbs.Tx, reportNodeId int64, okState configutils.BoolState) (int64, error) {
var query = this.Query(tx).
Attr("reportNodeId", reportNodeId)
switch okState {
case configutils.BoolStateYes:
query.Attr("isOk", 1)
case configutils.BoolStateNo:
query.Attr("isOk", 0)
}
return query.
Count()
}
// ListResults 列出单页结果
func (this *ReportResultDAO) ListResults(tx *dbs.Tx, reportNodeId int64, okState configutils.BoolState, offset int64, size int64) (result []*ReportResult, err error) {
var query = this.Query(tx).
Attr("reportNodeId", reportNodeId)
switch okState {
case configutils.BoolStateYes:
query.Attr("isOk", 1)
case configutils.BoolStateNo:
query.Attr("isOk", 0)
}
_, err = query.
Attr("reportNodeId", reportNodeId).
Offset(offset).
Limit(size).
Desc("targetId").
Slice(&result).
FindAll()
return
}

View File

@@ -0,0 +1,6 @@
package models
import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/iwind/TeaGo/bootstrap"
)

View File

@@ -0,0 +1,34 @@
package models
// ReportResult 连通性监控结果
type ReportResult struct {
Id uint64 `field:"id"` // ID
Type string `field:"type"` // 对象类型
TargetId uint64 `field:"targetId"` // 对象ID
TargetDesc string `field:"targetDesc"` // 对象描述
UpdatedAt uint64 `field:"updatedAt"` // 更新时间
ReportNodeId uint32 `field:"reportNodeId"` // 监控节点ID
IsOk uint8 `field:"isOk"` // 是否可连接
CostMs float64 `field:"costMs"` // 单次连接花费的时间
Error string `field:"error"` // 产生的错误信息
CountUp uint32 `field:"countUp"` // 连续上线次数
CountDown uint32 `field:"countDown"` // 连续下线次数
}
type ReportResultOperator struct {
Id interface{} // ID
Type interface{} // 对象类型
TargetId interface{} // 对象ID
TargetDesc interface{} // 对象描述
UpdatedAt interface{} // 更新时间
ReportNodeId interface{} // 监控节点ID
IsOk interface{} // 是否可连接
CostMs interface{} // 单次连接花费的时间
Error interface{} // 产生的错误信息
CountUp interface{} // 连续上线次数
CountDown interface{} // 连续下线次数
}
func NewReportResultOperator() *ReportResultOperator {
return &ReportResultOperator{}
}

View File

@@ -0,0 +1 @@
package models

View File

@@ -80,7 +80,16 @@ func (this *ReverseProxyDAO) FindEnabledReverseProxy(tx *dbs.Tx, id int64) (*Rev
} }
// ComposeReverseProxyConfig 根据ID组合配置 // ComposeReverseProxyConfig 根据ID组合配置
func (this *ReverseProxyDAO) ComposeReverseProxyConfig(tx *dbs.Tx, reverseProxyId int64) (*serverconfigs.ReverseProxyConfig, error) { func (this *ReverseProxyDAO) ComposeReverseProxyConfig(tx *dbs.Tx, reverseProxyId int64, cacheMap maps.Map) (*serverconfigs.ReverseProxyConfig, error) {
if cacheMap == nil {
cacheMap = maps.Map{}
}
var cacheKey = this.Table + ":config:" + types.String(reverseProxyId)
var cache = cacheMap.Get(cacheKey)
if cache != nil {
return cache.(*serverconfigs.ReverseProxyConfig), nil
}
reverseProxy, err := this.FindEnabledReverseProxy(tx, reverseProxyId) reverseProxy, err := this.FindEnabledReverseProxy(tx, reverseProxyId)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -113,7 +122,7 @@ func (this *ReverseProxyDAO) ComposeReverseProxyConfig(tx *dbs.Tx, reverseProxyI
return nil, err return nil, err
} }
for _, ref := range originRefs { for _, ref := range originRefs {
originConfig, err := SharedOriginDAO.ComposeOriginConfig(tx, ref.OriginId) originConfig, err := SharedOriginDAO.ComposeOriginConfig(tx, ref.OriginId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -130,7 +139,7 @@ func (this *ReverseProxyDAO) ComposeReverseProxyConfig(tx *dbs.Tx, reverseProxyI
return nil, err return nil, err
} }
for _, originConfig := range originRefs { for _, originConfig := range originRefs {
originConfig, err := SharedOriginDAO.ComposeOriginConfig(tx, originConfig.OriginId) originConfig, err := SharedOriginDAO.ComposeOriginConfig(tx, originConfig.OriginId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -181,6 +190,8 @@ func (this *ReverseProxyDAO) ComposeReverseProxyConfig(tx *dbs.Tx, reverseProxyI
config.IdleTimeout = idleTimeout config.IdleTimeout = idleTimeout
} }
cacheMap[cacheKey] = config
return config, nil return config, nil
} }

View File

@@ -8,7 +8,7 @@ import (
func TestReverseProxyDAO_ComposeReverseProxyConfig(t *testing.T) { func TestReverseProxyDAO_ComposeReverseProxyConfig(t *testing.T) {
var tx *dbs.Tx var tx *dbs.Tx
config, err := SharedReverseProxyDAO.ComposeReverseProxyConfig(tx, 1) config, err := SharedReverseProxyDAO.ComposeReverseProxyConfig(tx, 1, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -1,17 +1,14 @@
package models package models
import ( import (
"crypto/md5"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns" "github.com/TeaOSLab/EdgeAPI/internal/db/models/dns"
"github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils" "github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils" "github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/systemconfigs"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
@@ -310,49 +307,6 @@ func (this *ServerDAO) UpdateServerIsOn(tx *dbs.Tx, serverId int64, isOn bool) e
return nil return nil
} }
// UpdateServerConfig 修改服务配置
func (this *ServerDAO) UpdateServerConfig(tx *dbs.Tx, serverId int64, configJSON []byte, updateMd5 bool) (isChanged bool, err error) {
if serverId <= 0 {
return false, errors.New("serverId should not be smaller than 0")
}
// 查询以前的md5
oldConfigMd5, err := this.Query(tx).
Pk(serverId).
Result("configMd5").
FindStringCol("")
if err != nil {
return false, err
}
globalConfig, err := SharedSysSettingDAO.ReadSetting(tx, systemconfigs.SettingCodeServerGlobalConfig)
if err != nil {
return false, err
}
m := md5.New()
_, _ = m.Write(configJSON) // 当前服务配置
_, _ = m.Write(globalConfig) // 全局配置
h := m.Sum(nil)
newConfigMd5 := fmt.Sprintf("%x", h)
// 如果配置相同则不更改
if oldConfigMd5 == newConfigMd5 {
return false, nil
}
op := NewServerOperator()
op.Id = serverId
op.Config = JSONBytes(configJSON)
op.Version = dbs.SQL("version+1")
if updateMd5 {
op.ConfigMd5 = newConfigMd5
}
err = this.Save(tx, op)
return true, err
}
// UpdateServerHTTP 修改HTTP配置 // UpdateServerHTTP 修改HTTP配置
func (this *ServerDAO) UpdateServerHTTP(tx *dbs.Tx, serverId int64, config []byte) error { func (this *ServerDAO) UpdateServerHTTP(tx *dbs.Tx, serverId int64, config []byte) error {
if serverId <= 0 { if serverId <= 0 {
@@ -785,15 +739,19 @@ func (this *ServerDAO) ComposeServerConfigWithServerId(tx *dbs.Tx, serverId int6
if server == nil { if server == nil {
return nil, ErrNotFound return nil, ErrNotFound
} }
return this.ComposeServerConfig(tx, server) return this.ComposeServerConfig(tx, server, nil)
} }
// ComposeServerConfig 构造服务的Config // ComposeServerConfig 构造服务的Config
func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server) (*serverconfigs.ServerConfig, error) { func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server, cacheMap maps.Map) (*serverconfigs.ServerConfig, error) {
if server == nil { if server == nil {
return nil, ErrNotFound return nil, ErrNotFound
} }
if cacheMap == nil {
cacheMap = maps.Map{}
}
config := &serverconfigs.ServerConfig{} config := &serverconfigs.ServerConfig{}
config.Id = int64(server.Id) config.Id = int64(server.Id)
config.ClusterId = int64(server.ClusterId) config.ClusterId = int64(server.ClusterId)
@@ -814,12 +772,12 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server) (*serverc
// CNAME // CNAME
if server.ClusterId > 0 && len(server.DnsName) > 0 { if server.ClusterId > 0 && len(server.DnsName) > 0 {
clusterDNS, err := SharedNodeClusterDAO.FindClusterDNSInfo(tx, int64(server.ClusterId)) clusterDNS, err := SharedNodeClusterDAO.FindClusterDNSInfo(tx, int64(server.ClusterId), cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if clusterDNS != nil && clusterDNS.DnsDomainId > 0 { if clusterDNS != nil && clusterDNS.DnsDomainId > 0 {
domain, err := dns.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, int64(clusterDNS.DnsDomainId)) domain, err := dns.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, int64(clusterDNS.DnsDomainId), cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -850,7 +808,7 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server) (*serverc
// SSL // SSL
if httpsConfig.SSLPolicyRef != nil && httpsConfig.SSLPolicyRef.SSLPolicyId > 0 { if httpsConfig.SSLPolicyRef != nil && httpsConfig.SSLPolicyRef.SSLPolicyId > 0 {
sslPolicyConfig, err := SharedSSLPolicyDAO.ComposePolicyConfig(tx, httpsConfig.SSLPolicyRef.SSLPolicyId) sslPolicyConfig, err := SharedSSLPolicyDAO.ComposePolicyConfig(tx, httpsConfig.SSLPolicyRef.SSLPolicyId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -882,7 +840,7 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server) (*serverc
// SSL // SSL
if tlsConfig.SSLPolicyRef != nil { if tlsConfig.SSLPolicyRef != nil {
sslPolicyConfig, err := SharedSSLPolicyDAO.ComposePolicyConfig(tx, tlsConfig.SSLPolicyRef.SSLPolicyId) sslPolicyConfig, err := SharedSSLPolicyDAO.ComposePolicyConfig(tx, tlsConfig.SSLPolicyRef.SSLPolicyId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -916,7 +874,7 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server) (*serverc
// Web // Web
if server.WebId > 0 { if server.WebId > 0 {
webConfig, err := SharedHTTPWebDAO.ComposeWebConfig(tx, int64(server.WebId)) webConfig, err := SharedHTTPWebDAO.ComposeWebConfig(tx, int64(server.WebId), cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -934,7 +892,7 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server) (*serverc
} }
config.ReverseProxyRef = reverseProxyRef config.ReverseProxyRef = reverseProxyRef
reverseProxyConfig, err := SharedReverseProxyDAO.ComposeReverseProxyConfig(tx, reverseProxyRef.ReverseProxyId) reverseProxyConfig, err := SharedReverseProxyDAO.ComposeReverseProxyConfig(tx, reverseProxyRef.ReverseProxyId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -945,7 +903,7 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server) (*serverc
// WAF策略 // WAF策略
clusterId := int64(server.ClusterId) clusterId := int64(server.ClusterId)
httpFirewallPolicyId, err := SharedNodeClusterDAO.FindClusterHTTPFirewallPolicyId(tx, clusterId) httpFirewallPolicyId, err := SharedNodeClusterDAO.FindClusterHTTPFirewallPolicyId(tx, clusterId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -954,7 +912,7 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server) (*serverc
} }
// 缓存策略 // 缓存策略
httpCachePolicyId, err := SharedNodeClusterDAO.FindClusterHTTPCachePolicyId(tx, clusterId) httpCachePolicyId, err := SharedNodeClusterDAO.FindClusterHTTPCachePolicyId(tx, clusterId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -965,19 +923,6 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server) (*serverc
return config, nil return config, nil
} }
// RenewServerConfig 更新服务的Config配置
func (this *ServerDAO) RenewServerConfig(tx *dbs.Tx, serverId int64, updateMd5 bool) (isChanged bool, err error) {
serverConfig, err := this.ComposeServerConfigWithServerId(tx, serverId)
if err != nil {
return false, err
}
data, err := json.Marshal(serverConfig)
if err != nil {
return false, err
}
return this.UpdateServerConfig(tx, serverId, data, updateMd5)
}
// FindReverseProxyRef 根据条件获取反向代理配置 // FindReverseProxyRef 根据条件获取反向代理配置
func (this *ServerDAO) FindReverseProxyRef(tx *dbs.Tx, serverId int64) (*serverconfigs.ReverseProxyRef, error) { func (this *ServerDAO) FindReverseProxyRef(tx *dbs.Tx, serverId int64) (*serverconfigs.ReverseProxyRef, error) {
reverseProxy, err := this.Query(tx). reverseProxy, err := this.Query(tx).
@@ -1422,14 +1367,48 @@ func (this *ServerDAO) FindLatestServers(tx *dbs.Tx, size int64) (result []*Serv
return return
} }
// NotifyUpdate 同步集群 // FindFirstHTTPOrHTTPSPortWithClusterId 获取集群中第一个HTTP或者HTTPS端口
func (this *ServerDAO) NotifyUpdate(tx *dbs.Tx, serverId int64) error { func (this *ServerDAO) FindFirstHTTPOrHTTPSPortWithClusterId(tx *dbs.Tx, clusterId int64) (int, error) {
// 更新配置 one, _, err := this.Query(tx).
_, err := this.RenewServerConfig(tx, serverId, true) Result("JSON_EXTRACT(http, '$.listen[*].portRange') AS httpPort, JSON_EXTRACT(https, '$.listen[*].portRange') AS httpsPort").
if err != nil && err != ErrNotFound { Attr("clusterId", clusterId).
return err State(ServerStateEnabled).
Attr("isOn", 1).
Where("((JSON_CONTAINS(http, :queryJSON) AND JSON_EXTRACT(http, '$.listen[*].portRange') IS NOT NULL) OR (JSON_CONTAINS(https, :queryJSON) AND JSON_EXTRACT(https, '$.listen[*].portRange') IS NOT NULL))").
Param("queryJSON", "{\"isOn\":true}").
FindOne()
if err != nil {
return 0, err
}
httpPortString := one.GetString("httpPort")
if len(httpPortString) > 0 {
var ports = []string{}
err = json.Unmarshal([]byte(httpPortString), &ports)
if err != nil {
return 0, err
}
if len(ports) > 0 {
return types.Int(ports[0]), nil
}
} }
httpsPortString := one.GetString("httpsPort")
if len(httpsPortString) > 0 {
var ports = []string{}
err = json.Unmarshal([]byte(httpsPortString), &ports)
if err != nil {
return 0, err
}
if len(ports) > 0 {
return types.Int(ports[0]), nil
}
}
return 0, nil
}
// NotifyUpdate 同步集群
func (this *ServerDAO) NotifyUpdate(tx *dbs.Tx, serverId int64) error {
// 创建任务 // 创建任务
clusterId, err := this.FindServerClusterId(tx, serverId) clusterId, err := this.FindServerClusterId(tx, serverId)
if err != nil { if err != nil {
@@ -1453,7 +1432,7 @@ func (this *ServerDAO) NotifyDNSUpdate(tx *dbs.Tx, serverId int64) error {
if clusterId <= 0 { if clusterId <= 0 {
return nil return nil
} }
dnsInfo, err := SharedNodeClusterDAO.FindClusterDNSInfo(tx, clusterId) dnsInfo, err := SharedNodeClusterDAO.FindClusterDNSInfo(tx, clusterId, nil)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -41,10 +41,7 @@ func TestServerDAO_UpdateServerConfig(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, err = SharedServerDAO.UpdateServerConfig(tx, 1, configJSON, false) t.Log(string(configJSON))
if err != nil {
t.Fatal(err)
}
t.Log("ok") t.Log("ok")
} }
@@ -147,3 +144,15 @@ func TestServerDAO_FindAllEnabledServersWithNode(t *testing.T) {
t.Log("serverId:", server.Id, "clusterId:", server.ClusterId) t.Log("serverId:", server.Id, "clusterId:", server.ClusterId)
} }
} }
func BenchmarkServerDAO_CountAllEnabledServers(b *testing.B) {
SharedServerDAO = NewServerDAO()
for i := 0; i < b.N; i++ {
result, err := SharedServerDAO.CountAllEnabledServers(nil)
if err != nil {
b.Fatal(err)
}
_ = result
}
}

View File

@@ -7,6 +7,7 @@ import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time" timeutil "github.com/iwind/TeaGo/utils/time"
"time" "time"
@@ -38,12 +39,12 @@ func init() {
}) })
} }
// 初始化 // Init 初始化
func (this *SSLCertDAO) Init() { func (this *SSLCertDAO) Init() {
_ = this.DAOObject.Init() _ = this.DAOObject.Init()
} }
// 启用条目 // EnableSSLCert 启用条目
func (this *SSLCertDAO) EnableSSLCert(tx *dbs.Tx, id int64) error { func (this *SSLCertDAO) EnableSSLCert(tx *dbs.Tx, id int64) error {
_, err := this.Query(tx). _, err := this.Query(tx).
Pk(id). Pk(id).
@@ -52,7 +53,7 @@ func (this *SSLCertDAO) EnableSSLCert(tx *dbs.Tx, id int64) error {
return err return err
} }
// 禁用条目 // DisableSSLCert 禁用条目
func (this *SSLCertDAO) DisableSSLCert(tx *dbs.Tx, certId int64) error { func (this *SSLCertDAO) DisableSSLCert(tx *dbs.Tx, certId int64) error {
_, err := this.Query(tx). _, err := this.Query(tx).
Pk(certId). Pk(certId).
@@ -64,7 +65,7 @@ func (this *SSLCertDAO) DisableSSLCert(tx *dbs.Tx, certId int64) error {
return this.NotifyUpdate(tx, certId) return this.NotifyUpdate(tx, certId)
} }
// 查找启用中的条目 // FindEnabledSSLCert 查找启用中的条目
func (this *SSLCertDAO) FindEnabledSSLCert(tx *dbs.Tx, id int64) (*SSLCert, error) { func (this *SSLCertDAO) FindEnabledSSLCert(tx *dbs.Tx, id int64) (*SSLCert, error) {
result, err := this.Query(tx). result, err := this.Query(tx).
Pk(id). Pk(id).
@@ -76,7 +77,7 @@ func (this *SSLCertDAO) FindEnabledSSLCert(tx *dbs.Tx, id int64) (*SSLCert, erro
return result.(*SSLCert), err return result.(*SSLCert), err
} }
// 根据主键查找名称 // FindSSLCertName 根据主键查找名称
func (this *SSLCertDAO) FindSSLCertName(tx *dbs.Tx, id int64) (string, error) { func (this *SSLCertDAO) FindSSLCertName(tx *dbs.Tx, id int64) (string, error) {
return this.Query(tx). return this.Query(tx).
Pk(id). Pk(id).
@@ -84,7 +85,7 @@ func (this *SSLCertDAO) FindSSLCertName(tx *dbs.Tx, id int64) (string, error) {
FindStringCol("") FindStringCol("")
} }
// 创建证书 // CreateCert 创建证书
func (this *SSLCertDAO) CreateCert(tx *dbs.Tx, adminId int64, userId int64, isOn bool, name string, description string, serverName string, isCA bool, certData []byte, keyData []byte, timeBeginAt int64, timeEndAt int64, dnsNames []string, commonNames []string) (int64, error) { func (this *SSLCertDAO) CreateCert(tx *dbs.Tx, adminId int64, userId int64, isOn bool, name string, description string, serverName string, isCA bool, certData []byte, keyData []byte, timeBeginAt int64, timeEndAt int64, dnsNames []string, commonNames []string) (int64, error) {
op := NewSSLCertOperator() op := NewSSLCertOperator()
op.AdminId = adminId op.AdminId = adminId
@@ -119,7 +120,7 @@ func (this *SSLCertDAO) CreateCert(tx *dbs.Tx, adminId int64, userId int64, isOn
return types.Int64(op.Id), nil return types.Int64(op.Id), nil
} }
// 修改证书 // UpdateCert 修改证书
func (this *SSLCertDAO) UpdateCert(tx *dbs.Tx, certId int64, isOn bool, name string, description string, serverName string, isCA bool, certData []byte, keyData []byte, timeBeginAt int64, timeEndAt int64, dnsNames []string, commonNames []string) error { func (this *SSLCertDAO) UpdateCert(tx *dbs.Tx, certId int64, isOn bool, name string, description string, serverName string, isCA bool, certData []byte, keyData []byte, timeBeginAt int64, timeEndAt int64, dnsNames []string, commonNames []string) error {
if certId <= 0 { if certId <= 0 {
return errors.New("invalid certId") return errors.New("invalid certId")
@@ -162,8 +163,17 @@ func (this *SSLCertDAO) UpdateCert(tx *dbs.Tx, certId int64, isOn bool, name str
return this.NotifyUpdate(tx, certId) return this.NotifyUpdate(tx, certId)
} }
// 组合配置 // ComposeCertConfig 组合配置
func (this *SSLCertDAO) ComposeCertConfig(tx *dbs.Tx, certId int64) (*sslconfigs.SSLCertConfig, error) { func (this *SSLCertDAO) ComposeCertConfig(tx *dbs.Tx, certId int64, cacheMap maps.Map) (*sslconfigs.SSLCertConfig, error) {
if cacheMap == nil {
cacheMap = maps.Map{}
}
var cacheKey = this.Table + ":config:" + types.String(certId)
var cache = cacheMap.Get(cacheKey)
if cache != nil {
return cache.(*sslconfigs.SSLCertConfig), nil
}
cert, err := this.FindEnabledSSLCert(tx, certId) cert, err := this.FindEnabledSSLCert(tx, certId)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -203,10 +213,12 @@ func (this *SSLCertDAO) ComposeCertConfig(tx *dbs.Tx, certId int64) (*sslconfigs
config.CommonNames = commonNames config.CommonNames = commonNames
} }
cacheMap[cacheKey] = config
return config, nil return config, nil
} }
// 计算符合条件的证书数量 // CountCerts 计算符合条件的证书数量
func (this *SSLCertDAO) CountCerts(tx *dbs.Tx, isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64) (int64, error) { func (this *SSLCertDAO) CountCerts(tx *dbs.Tx, isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64) (int64, error) {
query := this.Query(tx). query := this.Query(tx).
State(SSLCertStateEnabled) State(SSLCertStateEnabled)
@@ -236,7 +248,7 @@ func (this *SSLCertDAO) CountCerts(tx *dbs.Tx, isCA bool, isAvailable bool, isEx
return query.Count() return query.Count()
} }
// 列出符合条件的证书 // ListCertIds 列出符合条件的证书
func (this *SSLCertDAO) ListCertIds(tx *dbs.Tx, isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64, offset int64, size int64) (certIds []int64, err error) { func (this *SSLCertDAO) ListCertIds(tx *dbs.Tx, isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64, offset int64, size int64) (certIds []int64, err error) {
query := this.Query(tx). query := this.Query(tx).
State(SSLCertStateEnabled) State(SSLCertStateEnabled)
@@ -281,7 +293,7 @@ func (this *SSLCertDAO) ListCertIds(tx *dbs.Tx, isCA bool, isAvailable bool, isE
return result, nil return result, nil
} }
// 设置证书的ACME信息 // UpdateCertACME 设置证书的ACME信息
func (this *SSLCertDAO) UpdateCertACME(tx *dbs.Tx, certId int64, acmeTaskId int64) error { func (this *SSLCertDAO) UpdateCertACME(tx *dbs.Tx, certId int64, acmeTaskId int64) error {
if certId <= 0 { if certId <= 0 {
return errors.New("invalid certId") return errors.New("invalid certId")
@@ -294,7 +306,7 @@ func (this *SSLCertDAO) UpdateCertACME(tx *dbs.Tx, certId int64, acmeTaskId int6
return err return err
} }
// 查找需要自动更新的任务 // FindAllExpiringCerts 查找需要自动更新的任务
// 这里我们只返回有限的字段以节省内存 // 这里我们只返回有限的字段以节省内存
func (this *SSLCertDAO) FindAllExpiringCerts(tx *dbs.Tx, days int) (result []*SSLCert, err error) { func (this *SSLCertDAO) FindAllExpiringCerts(tx *dbs.Tx, days int) (result []*SSLCert, err error) {
if days < 0 { if days < 0 {
@@ -314,7 +326,7 @@ func (this *SSLCertDAO) FindAllExpiringCerts(tx *dbs.Tx, days int) (result []*SS
return return
} }
// 设置当前证书事件通知时间 // UpdateCertNotifiedAt 设置当前证书事件通知时间
func (this *SSLCertDAO) UpdateCertNotifiedAt(tx *dbs.Tx, certId int64) error { func (this *SSLCertDAO) UpdateCertNotifiedAt(tx *dbs.Tx, certId int64) error {
_, err := this.Query(tx). _, err := this.Query(tx).
Pk(certId). Pk(certId).
@@ -323,7 +335,7 @@ func (this *SSLCertDAO) UpdateCertNotifiedAt(tx *dbs.Tx, certId int64) error {
return err return err
} }
// 检查用户权限 // CheckUserCert 检查用户权限
func (this *SSLCertDAO) CheckUserCert(tx *dbs.Tx, certId int64, userId int64) error { func (this *SSLCertDAO) CheckUserCert(tx *dbs.Tx, certId int64, userId int64) error {
if certId <= 0 || userId <= 0 { if certId <= 0 || userId <= 0 {
return errors.New("not found") return errors.New("not found")
@@ -342,7 +354,7 @@ func (this *SSLCertDAO) CheckUserCert(tx *dbs.Tx, certId int64, userId int64) er
return nil return nil
} }
// 通知更新 // NotifyUpdate 通知更新
func (this *SSLCertDAO) NotifyUpdate(tx *dbs.Tx, certId int64) error { func (this *SSLCertDAO) NotifyUpdate(tx *dbs.Tx, certId int64) error {
policyIds, err := SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(tx, certId) policyIds, err := SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(tx, certId)
if err != nil { if err != nil {

View File

@@ -7,6 +7,7 @@ import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
"strconv" "strconv"
) )
@@ -37,12 +38,12 @@ func init() {
}) })
} }
// 初始化 // Init 初始化
func (this *SSLPolicyDAO) Init() { func (this *SSLPolicyDAO) Init() {
_ = this.DAOObject.Init() _ = this.DAOObject.Init()
} }
// 启用条目 // EnableSSLPolicy 启用条目
func (this *SSLPolicyDAO) EnableSSLPolicy(tx *dbs.Tx, id int64) error { func (this *SSLPolicyDAO) EnableSSLPolicy(tx *dbs.Tx, id int64) error {
_, err := this.Query(tx). _, err := this.Query(tx).
Pk(id). Pk(id).
@@ -51,7 +52,7 @@ func (this *SSLPolicyDAO) EnableSSLPolicy(tx *dbs.Tx, id int64) error {
return err return err
} }
// 禁用条目 // DisableSSLPolicy 禁用条目
func (this *SSLPolicyDAO) DisableSSLPolicy(tx *dbs.Tx, policyId int64) error { func (this *SSLPolicyDAO) DisableSSLPolicy(tx *dbs.Tx, policyId int64) error {
_, err := this.Query(tx). _, err := this.Query(tx).
Pk(policyId). Pk(policyId).
@@ -63,7 +64,7 @@ func (this *SSLPolicyDAO) DisableSSLPolicy(tx *dbs.Tx, policyId int64) error {
return this.NotifyUpdate(tx, policyId) return this.NotifyUpdate(tx, policyId)
} }
// 查找启用中的条目 // FindEnabledSSLPolicy 查找启用中的条目
func (this *SSLPolicyDAO) FindEnabledSSLPolicy(tx *dbs.Tx, id int64) (*SSLPolicy, error) { func (this *SSLPolicyDAO) FindEnabledSSLPolicy(tx *dbs.Tx, id int64) (*SSLPolicy, error) {
result, err := this.Query(tx). result, err := this.Query(tx).
Pk(id). Pk(id).
@@ -75,8 +76,18 @@ func (this *SSLPolicyDAO) FindEnabledSSLPolicy(tx *dbs.Tx, id int64) (*SSLPolicy
return result.(*SSLPolicy), err return result.(*SSLPolicy), err
} }
// 组合配置 // ComposePolicyConfig 组合配置
func (this *SSLPolicyDAO) ComposePolicyConfig(tx *dbs.Tx, policyId int64) (*sslconfigs.SSLPolicy, error) { func (this *SSLPolicyDAO) ComposePolicyConfig(tx *dbs.Tx, policyId int64, cacheMap maps.Map) (*sslconfigs.SSLPolicy, error) {
if cacheMap == nil {
cacheMap = maps.Map{}
}
var cacheKey = this.Table + ":config:" + types.String(policyId)
var cacheConfig = cacheMap.Get(cacheKey)
if cacheConfig != nil {
return cacheConfig.(*sslconfigs.SSLPolicy), nil
}
policy, err := this.FindEnabledSSLPolicy(tx, policyId) policy, err := this.FindEnabledSSLPolicy(tx, policyId)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -100,7 +111,7 @@ func (this *SSLPolicyDAO) ComposePolicyConfig(tx *dbs.Tx, policyId int64) (*sslc
} }
if len(refs) > 0 { if len(refs) > 0 {
for _, ref := range refs { for _, ref := range refs {
certConfig, err := SharedSSLCertDAO.ComposeCertConfig(tx, ref.CertId) certConfig, err := SharedSSLCertDAO.ComposeCertConfig(tx, ref.CertId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -122,7 +133,7 @@ func (this *SSLPolicyDAO) ComposePolicyConfig(tx *dbs.Tx, policyId int64) (*sslc
} }
if len(refs) > 0 { if len(refs) > 0 {
for _, ref := range refs { for _, ref := range refs {
certConfig, err := SharedSSLCertDAO.ComposeCertConfig(tx, ref.CertId) certConfig, err := SharedSSLCertDAO.ComposeCertConfig(tx, ref.CertId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -156,10 +167,12 @@ func (this *SSLPolicyDAO) ComposePolicyConfig(tx *dbs.Tx, policyId int64) (*sslc
config.HSTS = hstsConfig config.HSTS = hstsConfig
} }
cacheMap[cacheKey] = config
return config, nil return config, nil
} }
// 查询使用单个证书的所有策略ID // FindAllEnabledPolicyIdsWithCertId 查询使用单个证书的所有策略ID
func (this *SSLPolicyDAO) FindAllEnabledPolicyIdsWithCertId(tx *dbs.Tx, certId int64) (policyIds []int64, err error) { func (this *SSLPolicyDAO) FindAllEnabledPolicyIdsWithCertId(tx *dbs.Tx, certId int64) (policyIds []int64, err error) {
if certId <= 0 { if certId <= 0 {
return return
@@ -180,7 +193,7 @@ func (this *SSLPolicyDAO) FindAllEnabledPolicyIdsWithCertId(tx *dbs.Tx, certId i
return policyIds, nil return policyIds, nil
} }
// 创建Policy // CreatePolicy 创建Policy
func (this *SSLPolicyDAO) CreatePolicy(tx *dbs.Tx, adminId int64, userId int64, http2Enabled bool, minVersion string, certsJSON []byte, hstsJSON []byte, clientAuthType int32, clientCACertsJSON []byte, cipherSuitesIsOn bool, cipherSuites []string) (int64, error) { func (this *SSLPolicyDAO) CreatePolicy(tx *dbs.Tx, adminId int64, userId int64, http2Enabled bool, minVersion string, certsJSON []byte, hstsJSON []byte, clientAuthType int32, clientCACertsJSON []byte, cipherSuitesIsOn bool, cipherSuites []string) (int64, error) {
op := NewSSLPolicyOperator() op := NewSSLPolicyOperator()
op.State = SSLPolicyStateEnabled op.State = SSLPolicyStateEnabled
@@ -218,8 +231,7 @@ func (this *SSLPolicyDAO) CreatePolicy(tx *dbs.Tx, adminId int64, userId int64,
return types.Int64(op.Id), nil return types.Int64(op.Id), nil
} }
// 修改Policy // UpdatePolicy 修改Policy
// 创建Policy
func (this *SSLPolicyDAO) UpdatePolicy(tx *dbs.Tx, policyId int64, http2Enabled bool, minVersion string, certsJSON []byte, hstsJSON []byte, clientAuthType int32, clientCACertsJSON []byte, cipherSuitesIsOn bool, cipherSuites []string) error { func (this *SSLPolicyDAO) UpdatePolicy(tx *dbs.Tx, policyId int64, http2Enabled bool, minVersion string, certsJSON []byte, hstsJSON []byte, clientAuthType int32, clientCACertsJSON []byte, cipherSuitesIsOn bool, cipherSuites []string) error {
if policyId <= 0 { if policyId <= 0 {
return errors.New("invalid policyId") return errors.New("invalid policyId")
@@ -259,7 +271,7 @@ func (this *SSLPolicyDAO) UpdatePolicy(tx *dbs.Tx, policyId int64, http2Enabled
return this.NotifyUpdate(tx, policyId) return this.NotifyUpdate(tx, policyId)
} }
// 检查是否为用户所属策略 // CheckUserPolicy 检查是否为用户所属策略
func (this *SSLPolicyDAO) CheckUserPolicy(tx *dbs.Tx, policyId int64, userId int64) error { func (this *SSLPolicyDAO) CheckUserPolicy(tx *dbs.Tx, policyId int64, userId int64) error {
if policyId <= 0 || userId <= 0 { if policyId <= 0 || userId <= 0 {
return errors.New("not found") return errors.New("not found")
@@ -278,7 +290,7 @@ func (this *SSLPolicyDAO) CheckUserPolicy(tx *dbs.Tx, policyId int64, userId int
return nil return nil
} }
// 通知更新 // NotifyUpdate 通知更新
func (this *SSLPolicyDAO) NotifyUpdate(tx *dbs.Tx, policyId int64) error { func (this *SSLPolicyDAO) NotifyUpdate(tx *dbs.Tx, policyId int64) error {
serverIds, err := SharedServerDAO.FindAllEnabledServerIdsWithSSLPolicyIds(tx, []int64{policyId}) serverIds, err := SharedServerDAO.FindAllEnabledServerIdsWithSSLPolicyIds(tx, []int64{policyId})
if err != nil { if err != nil {

View File

@@ -118,7 +118,7 @@ func (this *NodeTrafficHourlyStatDAO) FindHourlyStatsWithClusterId(tx *dbs.Tx, c
} }
// FindTopNodeStats 取得一定时间内的节点排行数据 // FindTopNodeStats 取得一定时间内的节点排行数据
func (this *NodeTrafficHourlyStatDAO) FindTopNodeStats(tx *dbs.Tx, role string, hourFrom string, hourTo string) (result []*NodeTrafficHourlyStat, err error) { func (this *NodeTrafficHourlyStatDAO) FindTopNodeStats(tx *dbs.Tx, role string, hourFrom string, hourTo string, size int64) (result []*NodeTrafficHourlyStat, err error) {
// TODO 节点如果已经被删除,则忽略 // TODO 节点如果已经被删除,则忽略
_, err = this.Query(tx). _, err = this.Query(tx).
Attr("role", role). Attr("role", role).
@@ -126,13 +126,14 @@ func (this *NodeTrafficHourlyStatDAO) FindTopNodeStats(tx *dbs.Tx, role string,
Result("nodeId, SUM(bytes) AS bytes, SUM(cachedBytes) AS cachedBytes, SUM(countRequests) AS countRequests, SUM(countCachedRequests) AS countCachedRequests, SUM(countAttackRequests) AS countAttackRequests, SUM(attackBytes) AS attackBytes"). Result("nodeId, SUM(bytes) AS bytes, SUM(cachedBytes) AS cachedBytes, SUM(countRequests) AS countRequests, SUM(countCachedRequests) AS countCachedRequests, SUM(countAttackRequests) AS countAttackRequests, SUM(attackBytes) AS attackBytes").
Group("nodeId"). Group("nodeId").
Desc("countRequests"). Desc("countRequests").
Limit(size).
Slice(&result). Slice(&result).
FindAll() FindAll()
return return
} }
// FindTopNodeStatsWithClusterId 取得集群一定时间内的节点排行数据 // FindTopNodeStatsWithClusterId 取得集群一定时间内的节点排行数据
func (this *NodeTrafficHourlyStatDAO) FindTopNodeStatsWithClusterId(tx *dbs.Tx, role string, clusterId int64, hourFrom string, hourTo string) (result []*NodeTrafficHourlyStat, err error) { func (this *NodeTrafficHourlyStatDAO) FindTopNodeStatsWithClusterId(tx *dbs.Tx, role string, clusterId int64, hourFrom string, hourTo string, size int64) (result []*NodeTrafficHourlyStat, err error) {
// TODO 节点如果已经被删除,则忽略 // TODO 节点如果已经被删除,则忽略
_, err = this.Query(tx). _, err = this.Query(tx).
Attr("role", role). Attr("role", role).
@@ -141,6 +142,7 @@ func (this *NodeTrafficHourlyStatDAO) FindTopNodeStatsWithClusterId(tx *dbs.Tx,
Result("nodeId, SUM(bytes) AS bytes, SUM(cachedBytes) AS cachedBytes, SUM(countRequests) AS countRequests, SUM(countCachedRequests) AS countCachedRequests, SUM(countAttackRequests) AS countAttackRequests, SUM(attackBytes) AS attackBytes"). Result("nodeId, SUM(bytes) AS bytes, SUM(cachedBytes) AS cachedBytes, SUM(countRequests) AS countRequests, SUM(countCachedRequests) AS countCachedRequests, SUM(countAttackRequests) AS countAttackRequests, SUM(attackBytes) AS attackBytes").
Group("nodeId"). Group("nodeId").
Desc("countRequests"). Desc("countRequests").
Limit(size).
Slice(&result). Slice(&result).
FindAll() FindAll()
return return

View File

@@ -3,9 +3,10 @@ package models
import ( import (
"encoding/json" "encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/maps"
) )
// 解析HTTP配置 // DecodeHTTP 解析HTTP配置
func (this *UserNode) DecodeHTTP() (*serverconfigs.HTTPProtocolConfig, error) { func (this *UserNode) DecodeHTTP() (*serverconfigs.HTTPProtocolConfig, error) {
if !IsNotNull(this.Http) { if !IsNotNull(this.Http) {
return nil, nil return nil, nil
@@ -24,8 +25,8 @@ func (this *UserNode) DecodeHTTP() (*serverconfigs.HTTPProtocolConfig, error) {
return config, nil return config, nil
} }
// 解析HTTPS配置 // DecodeHTTPS 解析HTTPS配置
func (this *UserNode) DecodeHTTPS() (*serverconfigs.HTTPSProtocolConfig, error) { func (this *UserNode) DecodeHTTPS(cacheMap maps.Map) (*serverconfigs.HTTPSProtocolConfig, error) {
if !IsNotNull(this.Https) { if !IsNotNull(this.Https) {
return nil, nil return nil, nil
} }
@@ -43,7 +44,7 @@ func (this *UserNode) DecodeHTTPS() (*serverconfigs.HTTPSProtocolConfig, error)
if config.SSLPolicyRef != nil { if config.SSLPolicyRef != nil {
policyId := config.SSLPolicyRef.SSLPolicyId policyId := config.SSLPolicyRef.SSLPolicyId
if policyId > 0 { if policyId > 0 {
sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(nil, policyId) sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(nil, policyId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -61,7 +62,7 @@ func (this *UserNode) DecodeHTTPS() (*serverconfigs.HTTPSProtocolConfig, error)
return config, nil return config, nil
} }
// 解析访问地址 // DecodeAccessAddrs 解析访问地址
func (this *UserNode) DecodeAccessAddrs() ([]*serverconfigs.NetworkAddressConfig, error) { func (this *UserNode) DecodeAccessAddrs() ([]*serverconfigs.NetworkAddressConfig, error) {
if !IsNotNull(this.AccessAddrs) { if !IsNotNull(this.AccessAddrs) {
return nil, nil return nil, nil
@@ -81,7 +82,7 @@ func (this *UserNode) DecodeAccessAddrs() ([]*serverconfigs.NetworkAddressConfig
return addrConfigs, nil return addrConfigs, nil
} }
// 解析访问地址,并返回字符串形式 // DecodeAccessAddrStrings 解析访问地址,并返回字符串形式
func (this *UserNode) DecodeAccessAddrStrings() ([]string, error) { func (this *UserNode) DecodeAccessAddrStrings() ([]string, error) {
addrs, err := this.DecodeAccessAddrs() addrs, err := this.DecodeAccessAddrs()
if err != nil { if err != nil {

View File

@@ -31,6 +31,34 @@ func (this *AliDNSProvider) Auth(params maps.Map) error {
return nil return nil
} }
// GetDomains 获取所有域名列表
func (this *AliDNSProvider) GetDomains() (domains []string, err error) {
pageNumber := 1
size := 100
for {
req := alidns.CreateDescribeDomainsRequest()
req.PageNumber = requests.NewInteger(pageNumber)
req.PageSize = requests.NewInteger(size)
resp := alidns.CreateDescribeDomainsResponse()
err = this.doAPI(req, resp)
if err != nil {
return nil, err
}
for _, domain := range resp.Domains.Domain {
domains = append(domains, domain.DomainName)
}
pageNumber++
if int64((pageNumber-1)*size) >= resp.TotalCount {
break
}
}
return
}
// GetRecords 获取域名列表 // GetRecords 获取域名列表
func (this *AliDNSProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) { func (this *AliDNSProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) {
pageNumber := 1 pageNumber := 1

View File

@@ -11,6 +11,14 @@ import (
"testing" "testing"
) )
func TestAliDNSProvider_GetDomains(t *testing.T) {
provider, err := testAliDNSProvider()
if err != nil {
t.Fatal(err)
}
t.Log(provider.GetDomains())
}
func TestAliDNSProvider_GetRecords(t *testing.T) { func TestAliDNSProvider_GetRecords(t *testing.T) {
provider, err := testAliDNSProvider() provider, err := testAliDNSProvider()
if err != nil { if err != nil {

View File

@@ -59,6 +59,21 @@ func (this *CloudFlareProvider) Auth(params maps.Map) error {
return nil return nil
} }
// GetDomains 获取所有域名列表
func (this *CloudFlareProvider) GetDomains() (domains []string, err error) {
resp := new(cloudflare.ZonesResponse)
err = this.doAPI(http.MethodGet, "zones", map[string]string{}, nil, resp)
if err != nil {
return nil, err
}
for _, zone := range resp.Result {
domains = append(domains, zone.Name)
}
return
}
// GetRecords 获取域名解析记录列表 // GetRecords 获取域名解析记录列表
func (this *CloudFlareProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) { func (this *CloudFlareProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) {
zoneId, err := this.findZoneIdWithDomain(domain) zoneId, err := this.findZoneIdWithDomain(domain)

View File

@@ -12,6 +12,14 @@ import (
"testing" "testing"
) )
func TestCloudFlareProvider_GetDomains(t *testing.T) {
provider, err := testCloudFlareProvider()
if err != nil {
t.Fatal(err)
}
t.Log(provider.GetDomains())
}
func TestCloudFlareProvider_GetRecords(t *testing.T) { func TestCloudFlareProvider_GetRecords(t *testing.T) {
provider, err := testCloudFlareProvider() provider, err := testCloudFlareProvider()
if err != nil { if err != nil {

View File

@@ -49,6 +49,16 @@ func (this *CustomHTTPProvider) Auth(params maps.Map) error {
return nil return nil
} }
// GetDomains 获取所有域名列表
func (this *CustomHTTPProvider) GetDomains() (domains []string, err error) {
resp, err := this.post(maps.Map{})
if err != nil {
return nil, err
}
err = json.Unmarshal(resp, &domains)
return
}
// GetRecords 获取域名解析记录列表 // GetRecords 获取域名解析记录列表
func (this *CustomHTTPProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) { func (this *CustomHTTPProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) {
resp, err := this.post(maps.Map{ resp, err := this.post(maps.Map{

View File

@@ -7,6 +7,22 @@ import (
"testing" "testing"
) )
func TestCustomHTTPProvider_GetDomains(t *testing.T) {
provider := CustomHTTPProvider{}
err := provider.Auth(maps.Map{
"url": "http://127.0.0.1:2345/dns",
"secret": "123456",
})
if err != nil {
t.Fatal(err)
}
domains, err := provider.GetDomains()
if err != nil {
t.Fatal(err)
}
t.Log(domains)
}
func TestCustomHTTPProvider_AddRecord(t *testing.T) { func TestCustomHTTPProvider_AddRecord(t *testing.T) {
provider := CustomHTTPProvider{} provider := CustomHTTPProvider{}
err := provider.Auth(maps.Map{ err := provider.Auth(maps.Map{

View File

@@ -35,6 +35,40 @@ func (this *DNSPodProvider) Auth(params maps.Map) error {
return nil return nil
} }
// GetDomains 获取所有域名列表
func (this *DNSPodProvider) GetDomains() (domains []string, err error) {
offset := 0
size := 100
for {
domainsResp, err := this.post("/Domain.list", map[string]string{
"offset": numberutils.FormatInt(offset),
"length": numberutils.FormatInt(size),
})
if err != nil {
return nil, err
}
offset += size
domainsSlice := domainsResp.GetSlice("domains")
if len(domainsSlice) == 0 {
break
}
for _, domain := range domainsSlice {
domainMap := maps.NewMap(domain)
domains = append(domains, domainMap.GetString("name"))
}
// 检查是否到头
info := domainsResp.GetMap("info")
recordTotal := info.GetInt("record_total")
if offset >= recordTotal {
break
}
}
return
}
// GetRecords 获取域名列表 // GetRecords 获取域名列表
func (this *DNSPodProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) { func (this *DNSPodProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) {
offset := 0 offset := 0

View File

@@ -9,6 +9,18 @@ import (
"testing" "testing"
) )
func TestDNSPodProvider_GetDomains(t *testing.T) {
provider, err := testDNSPodProvider()
if err != nil {
t.Fatal(err)
}
domains, err := provider.GetDomains()
if err != nil {
t.Fatal(err)
}
t.Log(domains)
}
func TestDNSPodProvider_GetRoutes(t *testing.T) { func TestDNSPodProvider_GetRoutes(t *testing.T) {
provider, err := testDNSPodProvider() provider, err := testDNSPodProvider()
if err != nil { if err != nil {

View File

@@ -55,6 +55,21 @@ func (this *HuaweiDNSProvider) Auth(params maps.Map) error {
return nil return nil
} }
// GetDomains 获取所有域名列表
func (this *HuaweiDNSProvider) GetDomains() (domains []string, err error) {
var resp = new(huaweidns.ZonesResponse)
err = this.doAPI(http.MethodGet, "/v2/zones", map[string]string{}, nil, resp)
if err != nil {
return nil, err
}
for _, zone := range resp.Zones {
zone.Name = strings.TrimSuffix(zone.Name, ".")
domains = append(domains, zone.Name)
}
return
}
// GetRecords 获取域名解析记录列表 // GetRecords 获取域名解析记录列表
func (this *HuaweiDNSProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) { func (this *HuaweiDNSProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) {
zoneId, err := this.findZoneIdWithDomain(domain) zoneId, err := this.findZoneIdWithDomain(domain)

View File

@@ -11,6 +11,18 @@ import (
"testing" "testing"
) )
func TestHuaweiDNSProvider_GetDomains(t *testing.T) {
provider, err := testHuaweiDNSProvider()
if err != nil {
t.Fatal(err)
}
domains, err := provider.GetDomains()
if err != nil {
t.Fatal(err)
}
t.Log("domains:", domains)
}
func TestHuaweiDNSProvider_GetRecords(t *testing.T) { func TestHuaweiDNSProvider_GetRecords(t *testing.T) {
provider, err := testHuaweiDNSProvider() provider, err := testHuaweiDNSProvider()
if err != nil { if err != nil {

View File

@@ -10,6 +10,9 @@ type ProviderInterface interface {
// Auth 认证 // Auth 认证
Auth(params maps.Map) error Auth(params maps.Map) error
// GetDomains 获取所有域名列表
GetDomains() (domains []string, err error)
// GetRecords 获取域名解析记录列表 // GetRecords 获取域名解析记录列表
GetRecords(domain string) (records []*dnstypes.Record, err error) GetRecords(domain string) (records []*dnstypes.Record, err error)

View File

@@ -34,6 +34,19 @@ func (this *LocalEdgeDNSProvider) Auth(params maps.Map) error {
return nil return nil
} }
// GetDomains 获取所有域名列表
func (this *LocalEdgeDNSProvider) GetDomains() (domains []string, err error) {
var tx *dbs.Tx
domainOnes, err := nameservers.SharedNSDomainDAO.ListEnabledDomains(tx, this.clusterId, 0, "", 0, 1000)
if err != nil {
return nil, err
}
for _, domain := range domainOnes {
domains = append(domains, domain.Name)
}
return
}
// GetRecords 获取域名解析记录列表 // GetRecords 获取域名解析记录列表
func (this *LocalEdgeDNSProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) { func (this *LocalEdgeDNSProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) {
var tx *dbs.Tx var tx *dbs.Tx

View File

@@ -13,6 +13,24 @@ import (
const testClusterId = 7 const testClusterId = 7
func TestLocalEdgeDNSProvider_GetDomains(t *testing.T) {
dbs.NotifyReady()
provider := &dnsclients.LocalEdgeDNSProvider{}
err := provider.Auth(maps.Map{
"clusterId": testClusterId,
})
if err != nil {
t.Fatal(err)
}
domains, err := provider.GetDomains()
if err != nil {
t.Fatal(err)
}
t.Log("domains:", domains)
}
func TestLocalEdgeDNSProvider_GetRecords(t *testing.T) { func TestLocalEdgeDNSProvider_GetRecords(t *testing.T) {
dbs.NotifyReady() dbs.NotifyReady()

View File

@@ -16,6 +16,12 @@ func (this *UserEdgeDNSProvider) Auth(params maps.Map) error {
return nil return nil
} }
// GetDomains 获取所有域名列表
func (this *UserEdgeDNSProvider) GetDomains() (domains []string, err error) {
// TODO
return
}
// GetRecords 获取域名解析记录列表 // GetRecords 获取域名解析记录列表
func (this *UserEdgeDNSProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) { func (this *UserEdgeDNSProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) {
// TODO // TODO

View File

@@ -356,7 +356,7 @@ func (this *APINode) listenPorts(apiNode *models.APINode) (isListening bool) {
} }
// HTTPS // HTTPS
httpsConfig, err := apiNode.DecodeHTTPS(nil) httpsConfig, err := apiNode.DecodeHTTPS(nil, nil)
if err != nil { if err != nil {
remotelogs.Error("API_NODE", "decode https config: "+err.Error()) remotelogs.Error("API_NODE", "decode https config: "+err.Error())
return return
@@ -433,7 +433,7 @@ func (this *APINode) listenPorts(apiNode *models.APINode) (isListening bool) {
} }
// Rest HTTPS // Rest HTTPS
restHTTPSConfig, err := apiNode.DecodeRestHTTPS(nil) restHTTPSConfig, err := apiNode.DecodeRestHTTPS(nil, nil)
if err != nil { if err != nil {
remotelogs.Error("API_NODE", "decode REST https config: "+err.Error()) remotelogs.Error("API_NODE", "decode REST https config: "+err.Error())
return return

View File

@@ -48,6 +48,11 @@ func (this *APINode) registerServices(server *grpc.Server) {
pb.RegisterNodeIPAddressServiceServer(server, instance) pb.RegisterNodeIPAddressServiceServer(server, instance)
this.rest(instance) this.rest(instance)
} }
{
instance := this.serviceInstance(&services.NodeIPAddressLogService{}).(*services.NodeIPAddressLogService)
pb.RegisterNodeIPAddressLogServiceServer(server, instance)
this.rest(instance)
}
{ {
instance := this.serviceInstance(&services.APINodeService{}).(*services.APINodeService) instance := this.serviceInstance(&services.APINodeService{}).(*services.APINodeService)
pb.RegisterAPINodeServiceServer(server, instance) pb.RegisterAPINodeServiceServer(server, instance)

View File

@@ -0,0 +1,12 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package remotelogs
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/iwind/TeaGo/dbs"
)
type DAOInterface interface {
CreateLog(tx *dbs.Tx, nodeRole nodeconfigs.NodeRole, nodeId int64, serverId int64, originId int64, level string, tag string, description string, createdAt int64) error
}

View File

@@ -3,7 +3,6 @@ package remotelogs
import ( import (
"github.com/TeaOSLab/EdgeAPI/internal/configs" "github.com/TeaOSLab/EdgeAPI/internal/configs"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const" teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/logs" "github.com/iwind/TeaGo/logs"
@@ -11,6 +10,7 @@ import (
) )
var logChan = make(chan *pb.NodeLog, 1024) var logChan = make(chan *pb.NodeLog, 1024)
var sharedDAO DAOInterface
func init() { func init() {
// 定期上传日志 // 定期上传日志
@@ -25,7 +25,7 @@ func init() {
}() }()
} }
// 打印普通信息 // Println 打印普通信息
func Println(tag string, description string) { func Println(tag string, description string) {
logs.Println("[" + tag + "]" + description) logs.Println("[" + tag + "]" + description)
@@ -48,7 +48,7 @@ func Println(tag string, description string) {
} }
} }
// 打印警告信息 // Warn 打印警告信息
func Warn(tag string, description string) { func Warn(tag string, description string) {
logs.Println("[" + tag + "]" + description) logs.Println("[" + tag + "]" + description)
@@ -71,7 +71,7 @@ func Warn(tag string, description string) {
} }
} }
// 打印错误信息 // Error 打印错误信息
func Error(tag string, description string) { func Error(tag string, description string) {
logs.Println("[" + tag + "]" + description) logs.Println("[" + tag + "]" + description)
@@ -94,13 +94,22 @@ func Error(tag string, description string) {
} }
} }
// SetDAO 设置存储接口
func SetDAO(dao DAOInterface) {
sharedDAO = dao
}
// 上传日志 // 上传日志
func uploadLogs() error { func uploadLogs() error {
if sharedDAO == nil {
return nil
}
Loop: Loop:
for { for {
select { select {
case log := <-logChan: case log := <-logChan:
err := models.SharedNodeLogDAO.CreateLog(nil, nodeconfigs.NodeRoleAPI, log.NodeId, log.ServerId, log.OriginId, log.Level, log.Tag, log.Description, log.CreatedAt) err := sharedDAO.CreateLog(nil, nodeconfigs.NodeRoleAPI, log.NodeId, log.ServerId, log.OriginId, log.Level, log.Tag, log.Description, log.CreatedAt)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -3,9 +3,11 @@ package nameservers
import ( import (
"context" "context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/nameservers"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services" "github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/types"
) )
// NSAccessLogService 访问日志相关服务 // NSAccessLogService 访问日志相关服务
@@ -61,6 +63,27 @@ func (this *NSAccessLogService) ListNSAccessLogs(ctx context.Context, req *pb.Li
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 线路
if len(a.NsRouteCodes) > 0 {
for _, routeCode := range a.NsRouteCodes {
route, err := nameservers.SharedNSRouteDAO.FindEnabledRouteWithCode(nil, routeCode)
if err != nil {
return nil, err
}
if route != nil {
a.NsRoutes = append(a.NsRoutes, &pb.NSRoute{
Id: types.Int64(route.Id),
IsOn: route.IsOn == 1,
Name: route.Name,
Code: routeCode,
NsCluster: nil,
NsDomain: nil,
})
}
}
}
result = append(result, a) result = append(result, a)
} }

View File

@@ -4,8 +4,10 @@ package nameservers
import ( import (
"context" "context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services" "github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
) )
@@ -175,3 +177,42 @@ func (this *NSClusterService) FindAllEnabledNSClusters(ctx context.Context, req
} }
return &pb.FindAllEnabledNSClustersResponse{NsClusters: pbClusters}, nil return &pb.FindAllEnabledNSClustersResponse{NsClusters: pbClusters}, nil
} }
// UpdateNSClusterRecursionConfig 设置递归DNS配置
func (this *NSClusterService) UpdateNSClusterRecursionConfig(ctx context.Context, req *pb.UpdateNSClusterRecursionConfigRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx, 0)
if err != nil {
return nil, err
}
// 校验配置
var config = &dnsconfigs.RecursionConfig{}
err = json.Unmarshal(req.RecursionJSON, config)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedNSClusterDAO.UpdateRecursion(tx, req.NsClusterId, req.RecursionJSON)
if err != nil {
return nil, err
}
return this.Success()
}
// FindNSClusterRecursionConfig 读取递归DNS配置
func (this *NSClusterService) FindNSClusterRecursionConfig(ctx context.Context, req *pb.FindNSClusterRecursionConfigRequest) (*pb.FindNSClusterRecursionConfigResponse, error) {
_, err := this.ValidateAdmin(ctx, 0)
if err != nil {
return nil, err
}
var tx = this.NullTx()
recursion, err := models.SharedNSClusterDAO.FindClusterRecursion(tx, req.NsClusterId)
if err != nil {
return nil, err
}
return &pb.FindNSClusterRecursionConfigResponse{
RecursionJSON: recursion,
}, nil
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/logs" "github.com/iwind/TeaGo/logs"
stringutil "github.com/iwind/TeaGo/utils/string" stringutil "github.com/iwind/TeaGo/utils/string"
"io"
"path/filepath" "path/filepath"
) )
@@ -466,7 +467,7 @@ func (this *NSNodeService) DownloadNSNodeInstallationFile(ctx context.Context, r
} }
data, offset, err := file.Read(req.ChunkOffset) data, offset, err := file.Read(req.ChunkOffset)
if err != nil { if err != nil && err != io.EOF {
return nil, err return nil, err
} }

View File

@@ -0,0 +1,41 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package reporters
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/iwind/TeaGo/dbs"
"google.golang.org/grpc/peer"
"net"
)
// 校验客户端IP
func validateClient(tx *dbs.Tx, nodeId int64, ctx context.Context) error {
allowIPs, err := models.SharedReportNodeDAO.FindNodeAllowIPs(tx, nodeId)
if err != nil {
return err
}
if len(allowIPs) == 0 {
return nil
}
p, ok := peer.FromContext(ctx)
if ok {
host, _, _ := net.SplitHostPort(p.Addr.String())
if len(host) > 0 {
for _, ip := range allowIPs {
r, err := shared.ParseIPRange(ip)
if err == nil && r != nil {
if r.Contains(host) {
return nil
}
}
}
}
}
return errors.New("client was not allowed")
}

View File

@@ -474,6 +474,15 @@ func (this *AdminService) ComposeAdminDashboard(ctx context.Context, req *pb.Com
var tx = this.NullTx() var tx = this.NullTx()
// 默认集群
nodeClusters, err := models.SharedNodeClusterDAO.ListEnabledClusters(tx, "", 0, 1)
if err != nil {
return nil, err
}
if len(nodeClusters) > 0 {
result.DefaultNodeClusterId = int64(nodeClusters[0].Id)
}
// 集群数 // 集群数
countClusters, err := models.SharedNodeClusterDAO.CountAllEnabledClusters(tx, "") countClusters, err := models.SharedNodeClusterDAO.CountAllEnabledClusters(tx, "")
if err != nil { if err != nil {
@@ -620,8 +629,12 @@ func (this *AdminService) ComposeAdminDashboard(ctx context.Context, req *pb.Com
// API节点升级信息 // API节点升级信息
{ {
var apiVersion = req.ApiVersion
if len(apiVersion) == 0 {
apiVersion = teaconst.Version
}
upgradeInfo := &pb.ComposeAdminDashboardResponse_UpgradeInfo{ upgradeInfo := &pb.ComposeAdminDashboardResponse_UpgradeInfo{
NewVersion: teaconst.Version, NewVersion: apiVersion,
} }
countNodes, err := models.SharedAPINodeDAO.CountAllLowerVersionNodes(tx, upgradeInfo.NewVersion) countNodes, err := models.SharedAPINodeDAO.CountAllLowerVersionNodes(tx, upgradeInfo.NewVersion)
if err != nil { if err != nil {
@@ -662,7 +675,7 @@ func (this *AdminService) ComposeAdminDashboard(ctx context.Context, req *pb.Com
// 节点排行 // 节点排行
if isPlus { if isPlus {
topNodeStats, err := stats.SharedNodeTrafficHourlyStatDAO.FindTopNodeStats(tx, "node", hourFrom, hourTo) topNodeStats, err := stats.SharedNodeTrafficHourlyStatDAO.FindTopNodeStats(tx, "node", hourFrom, hourTo, 10)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -162,7 +162,7 @@ func (this *APINodeService) ListEnabledAPINodes(ctx context.Context, req *pb.Lis
return &pb.ListEnabledAPINodesResponse{Nodes: result}, nil return &pb.ListEnabledAPINodesResponse{Nodes: result}, nil
} }
// 根据ID查找节点 // FindEnabledAPINode 根据ID查找节点
func (this *APINodeService) FindEnabledAPINode(ctx context.Context, req *pb.FindEnabledAPINodeRequest) (*pb.FindEnabledAPINodeResponse, error) { func (this *APINodeService) FindEnabledAPINode(ctx context.Context, req *pb.FindEnabledAPINodeRequest) (*pb.FindEnabledAPINodeResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, 0, 0) _, _, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil { if err != nil {

View File

@@ -1,88 +0,0 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
// +build plus
package services
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/authority"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
plusutils "github.com/TeaOSLab/EdgePlus/pkg/utils"
)
// AuthorityKeyService 版本认证
type AuthorityKeyService struct {
BaseService
}
// UpdateAuthorityKey 设置Key
func (this *AuthorityKeyService) UpdateAuthorityKey(ctx context.Context, req *pb.UpdateAuthorityKeyRequest) (*pb.RPCSuccess, error) {
_, _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAuthority)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = authority.SharedAuthorityKeyDAO.UpdateKey(tx, req.Value, req.DayFrom, req.DayTo, req.Hostname, req.MacAddresses, req.Company)
if err != nil {
return nil, err
}
return this.Success()
}
// ReadAuthorityKey 读取Key
func (this *AuthorityKeyService) ReadAuthorityKey(ctx context.Context, req *pb.ReadAuthorityKeyRequest) (*pb.ReadAuthorityKeyResponse, error) {
_, _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeMonitor, rpcutils.UserTypeProvider, rpcutils.UserTypeDNS)
if err != nil {
return nil, err
}
var tx = this.NullTx()
key, err := authority.SharedAuthorityKeyDAO.ReadKey(tx)
if err != nil {
return nil, err
}
if key == nil {
return &pb.ReadAuthorityKeyResponse{AuthorityKey: nil}, nil
}
if len(key.Value) == 0 {
return &pb.ReadAuthorityKeyResponse{AuthorityKey: nil}, nil
}
m, err := plusutils.Decode([]byte(key.Value))
if err != nil {
return nil, err
}
macAddresses := []string{}
if len(key.MacAddresses) > 0 {
err = json.Unmarshal([]byte(key.MacAddresses), &macAddresses)
if err != nil {
return nil, err
}
}
return &pb.ReadAuthorityKeyResponse{AuthorityKey: &pb.AuthorityKey{
Value: key.Value,
DayFrom: m.GetString("dayFrom"),
DayTo: m.GetString("dayTo"),
Hostname: key.Hostname,
MacAddresses: macAddresses,
Company: key.Company,
UpdatedAt: int64(key.UpdatedAt),
}}, nil
}
// ResetAuthorityKey 重置Key
func (this *AuthorityKeyService) ResetAuthorityKey(ctx context.Context, req *pb.ResetAuthorityKeyRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx, 0)
if err != nil {
return nil, err
}
err = authority.SharedAuthorityKeyDAO.ResetKey(nil)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -194,6 +194,8 @@ func (this *BaseService) ValidateNodeId(ctx context.Context, roles ...rpcutils.U
nodeIntId, err = models.SharedMonitorNodeDAO.FindEnabledMonitorNodeIdWithUniqueId(nil, nodeId) nodeIntId, err = models.SharedMonitorNodeDAO.FindEnabledMonitorNodeIdWithUniqueId(nil, nodeId)
case rpcutils.UserTypeDNS: case rpcutils.UserTypeDNS:
nodeIntId, err = models.SharedNSNodeDAO.FindEnabledNodeIdWithUniqueId(nil, nodeId) nodeIntId, err = models.SharedNSNodeDAO.FindEnabledNodeIdWithUniqueId(nil, nodeId)
case rpcutils.UserTypeReport:
nodeIntId, err = models.SharedReportNodeDAO.FindEnabledNodeIdWithUniqueId(nil, nodeId)
case rpcutils.UserTypeAuthority: case rpcutils.UserTypeAuthority:
nodeIntId, err = authority.SharedAuthorityNodeDAO.FindEnabledAuthorityNodeIdWithUniqueId(nil, nodeId) nodeIntId, err = authority.SharedAuthorityNodeDAO.FindEnabledAuthorityNodeIdWithUniqueId(nil, nodeId)
default: default:

View File

@@ -109,6 +109,10 @@ func (this *DBNodeService) ListEnabledDBNodes(ctx context.Context, req *pb.ListE
if err != nil { if err != nil {
status.Error = err.Error() status.Error = err.Error()
} else { } else {
// 版本
version, _ := db.FindCol(0, "SELECT VERSION()")
status.Version = types.String(version)
one, err := db.FindOne("SELECT SUM(DATA_LENGTH+INDEX_LENGTH) AS size FROM information_schema.`TABLES` WHERE TABLE_SCHEMA=?", db.Name()) one, err := db.FindOne("SELECT SUM(DATA_LENGTH+INDEX_LENGTH) AS size FROM information_schema.`TABLES` WHERE TABLE_SCHEMA=?", db.Name())
if err != nil { if err != nil {
status.Error = err.Error() status.Error = err.Error()
@@ -294,3 +298,49 @@ func (this *DBNodeService) TruncateDBNodeTable(ctx context.Context, req *pb.Trun
} }
return this.Success() return this.Success()
} }
// CheckDBNodeStatus 检查数据库节点状态
func (this *DBNodeService) CheckDBNodeStatus(ctx context.Context, req *pb.CheckDBNodeStatusRequest) (*pb.CheckDBNodeStatusResponse, error) {
_, err := this.ValidateAdmin(ctx, 0)
if err != nil {
return nil, err
}
var tx = this.NullTx()
node, err := models.SharedDBNodeDAO.FindEnabledDBNode(tx, req.DbNodeId)
if err != nil {
return nil, err
}
if node == nil {
return &pb.CheckDBNodeStatusResponse{DbNodeStatus: nil}, nil
}
status := &pb.DBNodeStatus{}
// 是否能够连接
if node.IsOn == 1 {
db, err := dbs.NewInstanceFromConfig(node.DBConfig())
if err != nil {
status.Error = err.Error()
} else {
// 版本
version, _ := db.FindCol(0, "SELECT VERSION()")
status.Version = types.String(version)
one, err := db.FindOne("SELECT SUM(DATA_LENGTH+INDEX_LENGTH) AS size FROM information_schema.`TABLES` WHERE TABLE_SCHEMA=?", db.Name())
if err != nil {
status.Error = err.Error()
_ = db.Close()
} else if one == nil {
status.Error = "unable to read size from database server"
_ = db.Close()
} else {
status.IsOk = true
status.Size = one.GetInt64("size")
_ = db.Close()
}
}
}
return &pb.CheckDBNodeStatusResponse{DbNodeStatus: status}, nil
}

View File

@@ -3,6 +3,7 @@ package services
import ( import (
"context" "context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns/dnsutils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
) )
@@ -39,7 +40,7 @@ func (this *DNSService) FindAllDNSIssues(ctx context.Context, req *pb.FindAllDNS
clusters = []*models.NodeCluster{cluster} clusters = []*models.NodeCluster{cluster}
} }
for _, cluster := range clusters { for _, cluster := range clusters {
issues, err := models.SharedNodeClusterDAO.CheckClusterDNS(tx, cluster) issues, err := dnsutils.CheckClusterDNS(tx, cluster)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns" "github.com/TeaOSLab/EdgeAPI/internal/db/models/dns"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns/dnsutils"
"github.com/TeaOSLab/EdgeAPI/internal/dnsclients" "github.com/TeaOSLab/EdgeAPI/internal/dnsclients"
"github.com/TeaOSLab/EdgeAPI/internal/dnsclients/dnstypes" "github.com/TeaOSLab/EdgeAPI/internal/dnsclients/dnstypes"
"github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeAPI/internal/errors"
@@ -12,6 +13,7 @@ import (
"github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils" "github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/maps"
"net" "net"
@@ -137,7 +139,7 @@ func (this *DNSDomainService) FindEnabledDNSDomain(ctx context.Context, req *pb.
tx := this.NullTx() tx := this.NullTx()
domain, err := dns.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, req.DnsDomainId) domain, err := dns.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, req.DnsDomainId, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -145,7 +147,7 @@ func (this *DNSDomainService) FindEnabledDNSDomain(ctx context.Context, req *pb.
return &pb.FindEnabledDNSDomainResponse{DnsDomain: nil}, nil return &pb.FindEnabledDNSDomainResponse{DnsDomain: nil}, nil
} }
pbDomain, err := this.convertDomainToPB(domain) pbDomain, err := this.convertDomainToPB(tx, domain)
return &pb.FindEnabledDNSDomainResponse{DnsDomain: pbDomain}, nil return &pb.FindEnabledDNSDomainResponse{DnsDomain: pbDomain}, nil
} }
@@ -159,7 +161,7 @@ func (this *DNSDomainService) FindEnabledBasicDNSDomain(ctx context.Context, req
tx := this.NullTx() tx := this.NullTx()
domain, err := dns.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, req.DnsDomainId) domain, err := dns.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, req.DnsDomainId, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -209,7 +211,7 @@ func (this *DNSDomainService) FindAllEnabledDNSDomainsWithDNSProviderId(ctx cont
result := []*pb.DNSDomain{} result := []*pb.DNSDomain{}
for _, domain := range domains { for _, domain := range domains {
pbDomain, err := this.convertDomainToPB(domain) pbDomain, err := this.convertDomainToPB(tx, domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -300,9 +302,14 @@ func (this *DNSDomainService) ExistAvailableDomains(ctx context.Context, req *pb
} }
// 转换域名信息 // 转换域名信息
func (this *DNSDomainService) convertDomainToPB(domain *dns.DNSDomain) (*pb.DNSDomain, error) { func (this *DNSDomainService) convertDomainToPB(tx *dbs.Tx, domain *dns.DNSDomain) (*pb.DNSDomain, error) {
domainId := int64(domain.Id) domainId := int64(domain.Id)
defaultRoute, err := dnsutils.FindDefaultDomainRoute(tx, domain)
if err != nil {
return nil, err
}
records := []*dnstypes.Record{} records := []*dnstypes.Record{}
if len(domain.Records) > 0 && domain.Records != "null" { if len(domain.Records) > 0 && domain.Records != "null" {
err := json.Unmarshal([]byte(domain.Records), &records) err := json.Unmarshal([]byte(domain.Records), &records)
@@ -319,8 +326,6 @@ func (this *DNSDomainService) convertDomainToPB(domain *dns.DNSDomain) (*pb.DNSD
countServerRecords := 0 countServerRecords := 0
serversChanged := false serversChanged := false
tx := this.NullTx()
// 检查是否所有的集群都已经被解析 // 检查是否所有的集群都已经被解析
clusters, err := models.SharedNodeClusterDAO.FindAllEnabledClustersWithDNSDomainId(tx, domainId) clusters, err := models.SharedNodeClusterDAO.FindAllEnabledClustersWithDNSDomainId(tx, domainId)
if err != nil { if err != nil {
@@ -330,7 +335,8 @@ func (this *DNSDomainService) convertDomainToPB(domain *dns.DNSDomain) (*pb.DNSD
countAllNodes1 := int64(0) countAllNodes1 := int64(0)
countAllServers1 := int64(0) countAllServers1 := int64(0)
for _, cluster := range clusters { for _, cluster := range clusters {
_, nodeRecords, serverRecords, countAllNodes, countAllServers, nodesChanged2, serversChanged2, err := this.findClusterDNSChanges(cluster, records, domain.Name)
_, nodeRecords, serverRecords, countAllNodes, countAllServers, nodesChanged2, serversChanged2, err := this.findClusterDNSChanges(cluster, records, domain.Name, defaultRoute)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -364,6 +370,7 @@ func (this *DNSDomainService) convertDomainToPB(domain *dns.DNSDomain) (*pb.DNSD
ProviderId: int64(domain.ProviderId), ProviderId: int64(domain.ProviderId),
Name: domain.Name, Name: domain.Name,
IsOn: domain.IsOn == 1, IsOn: domain.IsOn == 1,
IsUp: domain.IsUp == 1,
DataUpdatedAt: int64(domain.DataUpdatedAt), DataUpdatedAt: int64(domain.DataUpdatedAt),
CountNodeRecords: int64(countNodeRecords), CountNodeRecords: int64(countNodeRecords),
NodesChanged: nodesChanged, NodesChanged: nodesChanged,
@@ -388,7 +395,7 @@ func (this *DNSDomainService) convertRecordToPB(record *dnstypes.Record) *pb.DNS
} }
// 检查集群节点变化 // 检查集群节点变化
func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster, records []*dnstypes.Record, domainName string) (result []maps.Map, doneNodeRecords []*dnstypes.Record, doneServerRecords []*dnstypes.Record, countAllNodes int64, countAllServers int64, nodesChanged bool, serversChanged bool, err error) { func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster, records []*dnstypes.Record, domainName string, defaultRoute string) (result []maps.Map, doneNodeRecords []*dnstypes.Record, doneServerRecords []*dnstypes.Record, countAllNodes int64, countAllServers int64, nodesChanged bool, serversChanged bool, err error) {
clusterId := int64(cluster.Id) clusterId := int64(cluster.Id)
clusterDnsName := cluster.DnsName clusterDnsName := cluster.DnsName
clusterDomain := clusterDnsName + "." + domainName clusterDomain := clusterDnsName + "." + domainName
@@ -413,7 +420,7 @@ func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster,
// 新增的节点域名 // 新增的节点域名
nodeKeys := []string{} nodeKeys := []string{}
for _, node := range nodes { for _, node := range nodes {
ipAddresses, err := models.SharedNodeIPAddressDAO.FindNodeAccessIPAddresses(tx, int64(node.Id), nodeconfigs.NodeRoleNode) ipAddresses, err := models.SharedNodeIPAddressDAO.FindNodeAccessAndUpIPAddresses(tx, int64(node.Id), nodeconfigs.NodeRoleNode)
if err != nil { if err != nil {
return nil, nil, nil, 0, 0, false, false, err return nil, nil, nil, 0, 0, false, false, err
} }
@@ -425,7 +432,12 @@ func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster,
return nil, nil, nil, 0, 0, false, false, err return nil, nil, nil, 0, 0, false, false, err
} }
if len(routeCodes) == 0 { if len(routeCodes) == 0 {
continue // 默认线路
if len(defaultRoute) > 0 {
routeCodes = []string{defaultRoute}
} else {
continue
}
} }
for _, route := range routeCodes { for _, route := range routeCodes {
for _, ipAddress := range ipAddresses { for _, ipAddress := range ipAddresses {
@@ -564,7 +576,7 @@ func (this *DNSDomainService) syncClusterDNS(req *pb.SyncDNSDomainDataRequest) (
} }
// 域名信息 // 域名信息
domain, err := dns.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, req.DnsDomainId) domain, err := dns.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, req.DnsDomainId, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -616,7 +628,7 @@ func (this *DNSDomainService) syncClusterDNS(req *pb.SyncDNSDomainDataRequest) (
// 检查集群设置 // 检查集群设置
for _, cluster := range clusters { for _, cluster := range clusters {
issues, err := models.SharedNodeClusterDAO.CheckClusterDNS(tx, cluster) issues, err := dnsutils.CheckClusterDNS(tx, cluster)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -642,7 +654,7 @@ func (this *DNSDomainService) syncClusterDNS(req *pb.SyncDNSDomainDataRequest) (
// 对比变化 // 对比变化
allChanges := []maps.Map{} allChanges := []maps.Map{}
for _, cluster := range clusters { for _, cluster := range clusters {
changes, _, _, _, _, _, _, err := this.findClusterDNSChanges(cluster, records, domainName) changes, _, _, _, _, _, _, err := this.findClusterDNSChanges(cluster, records, domainName, manager.DefaultRoute())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -693,7 +705,7 @@ func (this *DNSDomainService) syncClusterDNS(req *pb.SyncDNSDomainDataRequest) (
}, nil }, nil
} }
// 检查域名是否在记录中 // ExistDNSDomainRecord 检查域名是否在记录中
func (this *DNSDomainService) ExistDNSDomainRecord(ctx context.Context, req *pb.ExistDNSDomainRecordRequest) (*pb.ExistDNSDomainRecordResponse, error) { func (this *DNSDomainService) ExistDNSDomainRecord(ctx context.Context, req *pb.ExistDNSDomainRecordRequest) (*pb.ExistDNSDomainRecordResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, 0, 0) _, _, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil { if err != nil {
@@ -708,3 +720,84 @@ func (this *DNSDomainService) ExistDNSDomainRecord(ctx context.Context, req *pb.
} }
return &pb.ExistDNSDomainRecordResponse{IsOk: isOk}, nil return &pb.ExistDNSDomainRecordResponse{IsOk: isOk}, nil
} }
// SyncDNSDomainsFromProvider 从服务商同步域名
func (this *DNSDomainService) SyncDNSDomainsFromProvider(ctx context.Context, req *pb.SyncDNSDomainsFromProviderRequest) (*pb.SyncDNSDomainsFromProviderResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
tx := this.NullTx()
provider, err := dns.SharedDNSProviderDAO.FindEnabledDNSProvider(tx, req.DnsProviderId)
if err != nil {
return nil, err
}
if provider == nil {
return nil, errors.New("can not find provider")
}
// 下线不存在的域名
oldDomains, err := dns.SharedDNSDomainDAO.FindAllEnabledDomainsWithProviderId(tx, req.DnsProviderId)
if err != nil {
return nil, err
}
dnsProvider := dnsclients.FindProvider(provider.Type)
if dnsProvider == nil {
return nil, errors.New("provider type '" + provider.Type + "' is not supported yet")
}
params, err := provider.DecodeAPIParams()
if err != nil {
return nil, errors.New("decode params failed: " + err.Error())
}
err = dnsProvider.Auth(params)
if err != nil {
return nil, errors.New("auth failed: " + err.Error())
}
domainNames, err := dnsProvider.GetDomains()
if err != nil {
return nil, err
}
var hasChanges = false
// 创建或上线域名
for _, domainName := range domainNames {
domain, err := dns.SharedDNSDomainDAO.FindEnabledDomainWithName(tx, req.DnsProviderId, domainName)
if err != nil {
return nil, err
}
if domain == nil {
_, err = dns.SharedDNSDomainDAO.CreateDomain(tx, 0, 0, req.DnsProviderId, domainName)
if err != nil {
return nil, err
}
hasChanges = true
} else if domain.IsUp == 0 {
err = dns.SharedDNSDomainDAO.UpdateDomainIsUp(tx, int64(domain.Id), true)
if err != nil {
return nil, err
}
hasChanges = true
}
}
// 将老的域名置为下线
for _, oldDomain := range oldDomains {
var domainName = oldDomain.Name
if oldDomain.IsUp == 1 && !lists.ContainsString(domainNames, domainName) {
err = dns.SharedDNSDomainDAO.UpdateDomainIsUp(tx, int64(oldDomain.Id), false)
if err != nil {
return nil, err
}
hasChanges = true
}
}
return &pb.SyncDNSDomainsFromProviderResponse{
HasChanges: hasChanges,
}, nil
}

Some files were not shown because too many files have changed in this diff Show More