Compare commits

...

187 Commits

Author SHA1 Message Date
刘祥超
a987837904 版本修改为1.3.7 2024-04-29 23:12:29 +08:00
刘祥超
4bcad223ca 版本号变更为1.3.6 2024-04-22 10:56:04 +08:00
刘祥超
b9e5005d05 升级相关依赖 2024-04-21 20:54:16 +08:00
刘祥超
fd9bdee6be 修复单例安装程序编译问题 2024-04-21 20:12:29 +08:00
刘祥超
4460956de6 在IP名单中搜索IP时同时也搜索“所有IP”类型的IP 2024-04-21 10:48:25 +08:00
刘祥超
1923c2706a 版本号修改为1.3.5 2024-04-20 22:46:56 +08:00
刘祥超
1470ec2b65 修复ttlcache可能缺失回收数据的问题 2024-04-20 22:22:24 +08:00
刘祥超
62fc9bcc68 增加edgeMtricSumStats表中total长度 2024-04-19 16:13:32 +08:00
刘祥超
8a08df8593 更新相关依赖库 2024-04-17 22:23:46 +08:00
刘祥超
bc7a1f37b6 版本号修改为1.3.4.4 2024-04-16 14:14:59 +08:00
刘祥超
f8a01e4639 版本号修改为1.3.4.3 2024-04-15 09:26:18 +08:00
刘祥超
8e4c00ef31 删除一直未实现的Unix协议相关内容 2024-04-14 17:11:58 +08:00
刘祥超
dc894828e0 源站增加快速停用/启用功能 2024-04-14 16:27:27 +08:00
刘祥超
7f6d8ba7b4 修复4位版本号导致无法自动升级SQL的问题 2024-04-14 11:47:10 +08:00
刘祥超
b8b56db83c 进程重启时,自动保存未保存的带宽统计数据到本地文件,以便于在重启后恢复 2024-04-13 17:14:58 +08:00
刘祥超
17f0821945 简化IP名单中创建IP操作/支持IP以CIDR方式显示 2024-04-13 16:48:24 +08:00
刘祥超
b5436a6e57 网站看板数据中增加当日独立IP和当日流量 2024-04-12 18:51:42 +08:00
刘祥超
46fe2d8369 优化套餐变更后网站限流状态 2024-04-12 11:35:52 +08:00
刘祥超
ff429c270d 节点配置中增加集群的密钥信息 2024-04-11 14:19:08 +08:00
刘祥超
92de19e359 版本号修改为1.3.4.2 2024-04-09 10:05:40 +08:00
刘祥超
4121c81a0a 将版本号修改为1.3.4.1 2024-04-08 14:57:05 +08:00
刘祥超
f639ab8342 WAF策略增加JSCookie动作选项 2024-04-07 14:21:01 +08:00
刘祥超
f17a8ab1d0 标记edgeIPItems中两个字段为弃用 2024-04-07 11:26:56 +08:00
刘祥超
12f677eb12 IP名单中的“全局封锁名单”文字改为“系统黑名单” 2024-04-06 16:13:17 +08:00
刘祥超
7595bdeb6b 用户系统增加IP检查功能 2024-04-06 15:23:19 +08:00
刘祥超
fc223af3f0 IP检查也支持范围搜索 2024-04-06 15:15:33 +08:00
刘祥超
ebe3632f07 支持搜索IPv6范围 2024-04-06 14:55:51 +08:00
刘祥超
930babc010 IP名单搜索IP时同时搜索范围 2024-04-06 10:31:03 +08:00
刘祥超
255e3a61e6 更好地支持IPv6 2024-04-06 10:21:52 +08:00
刘祥超
52155a23ab 集群设置增加自动硬盘TRIM选项 2024-04-04 17:04:53 +08:00
刘祥超
200f244c0c “磁盘”文字改为“硬盘” 2024-04-04 16:49:17 +08:00
刘祥超
ab5d7539ce 节点上传指标数据时只上传变更的部分 2024-04-03 08:15:20 +08:00
刘祥超
3e79840fe6 使用MMAP提升缓存读取性能 2024-03-29 18:32:31 +08:00
刘祥超
d03455e3b0 将版本号修改为1.3.4 2024-03-24 20:08:27 +08:00
刘祥超
af1cb14110 提升登录SESSION安全性 2024-03-18 12:43:13 +08:00
刘祥超
0feffa755e 节点SSH密码和私钥均以掩码方式显示 2024-03-18 10:51:47 +08:00
刘祥超
7cfbe2e473 DNS服务商中的密钥数据以掩码方式显示 2024-03-18 10:20:22 +08:00
刘祥超
7f63dc4565 查找省份对应ID时,自动尝试省略省、区之类的后缀 2024-03-15 15:08:05 +08:00
刘祥超
e90424f80a 翻译部分英文地名 2024-03-15 15:07:07 +08:00
刘祥超
6271125296 省份表增加线路字段 2024-03-14 20:42:13 +08:00
刘祥超
44ac4b83c5 智能DNS中国家/地区线路下支持省/州的细分 2024-03-14 20:12:04 +08:00
刘祥超
c75e2c55c6 优化代码 2024-03-10 16:26:03 +08:00
刘祥超
51a3029c09 在缓存任务键值中增加集群信息,以便于调试问题 2024-03-10 11:26:28 +08:00
刘祥超
580341d397 优化systemd服务配置 2024-03-08 19:00:27 +08:00
刘祥超
fb4bad0731 单例应用设置数据库自动清理 2024-03-04 11:32:47 +08:00
刘祥超
70efff2e6b 优化实例安装脚本 2024-03-03 17:14:29 +08:00
刘祥超
97c76ef22f 优化单例应用安装程序 2024-03-02 20:51:13 +08:00
刘祥超
e763095756 修复部分API返回格式错误 2024-02-24 09:52:47 +08:00
刘祥超
b7dc2738e2 增加单体应用初始化标识 2024-01-29 18:56:37 +08:00
刘祥超
3db826b578 增加通过管理员用户名查找管理员信息的API 2024-01-29 18:55:04 +08:00
刘祥超
dc8975e374 版本号修改为1.3.3.1 2024-01-29 17:58:36 +08:00
刘祥超
c0cbd7c607 实现单体实例安装工具 2024-01-29 17:57:01 +08:00
刘祥超
4d9f404bb0 优化SQL升级代码 2024-01-29 10:22:27 +08:00
刘祥超
06bb61804b 优化编译脚本 2024-01-22 18:51:22 +08:00
刘祥超
32c1442878 增加修改节点停用/启用状态API 2024-01-21 17:43:20 +08:00
刘祥超
b99652801d 版本号修改为1.3.3 2024-01-21 16:57:37 +08:00
刘祥超
be565a98b9 查询集群列表API增加ID排序 2024-01-21 16:57:17 +08:00
刘祥超
5195a380db WAF策略增加显示页面动作默认设置 2024-01-20 16:19:11 +08:00
刘祥超
8dbbabb0e8 修改版本号为1.3.2.2 2024-01-16 20:59:18 +08:00
刘祥超
bec4500746 版本号修改为1.3.2.1 2024-01-15 08:40:23 +08:00
刘祥超
66a31f599d 网站设置增加HLS加密功能(商业版本 2024-01-14 20:36:47 +08:00
刘祥超
534cfb2180 套餐增加文件最大上传尺寸设置 2024-01-13 19:32:48 +08:00
刘祥超
a9dc20ffbd 优化API错误提示 2024-01-12 12:11:13 +08:00
刘祥超
7f20ad32b6 调用API时找不到服务或方法时也提示JSON,防止小白开发者不知道如何获取响应状态 2024-01-12 11:51:06 +08:00
刘祥超
a3c0b43bc4 添加快捷添加和删除网站源站API 2024-01-12 11:50:10 +08:00
刘祥超
1f2c9a6b3a 增加删除一组网站API 2024-01-11 19:06:25 +08:00
刘祥超
194b0ec184 套餐可以设置带宽限制 2024-01-11 15:21:00 +08:00
刘祥超
c94895a7c4 增加用户系统文章相关管理 2024-01-09 10:20:52 +08:00
刘祥超
22d15bcb27 华为云DNS线路增加一组"运营商_地区“线路 2023-12-25 09:05:06 +08:00
刘祥超
361fb9b868 升级程序中的1.3.1.x改为1.3.2 2023-12-24 17:40:40 +08:00
刘祥超
2d675f4281 源码编译版本增加节点数限制 2023-12-24 11:28:41 +08:00
刘祥超
e19bbdf891 版本号修改为1.3.2 2023-12-24 11:14:39 +08:00
刘祥超
d48c0a2328 增加列出IP名单中的IP ID列表的API 2023-12-24 10:51:29 +08:00
刘祥超
a70b20cf13 增加请求脚本审核机制 2023-12-23 20:56:11 +08:00
刘祥超
eb83017ed4 修复一处编译错误 2023-12-22 16:46:37 +08:00
刘祥超
98ba31174b 套餐增加简介信息 2023-12-21 15:09:50 +08:00
刘祥超
aa28e84507 增加若干功能代号 2023-12-20 17:34:54 +08:00
刘祥超
da8fe918fe 更新SQL 2023-12-20 15:54:36 +08:00
刘祥超
2b26bed97c 增加若干API 2023-12-20 15:08:05 +08:00
刘祥超
5e50518bd9 限制ACME错误消息长度 2023-12-19 20:05:34 +08:00
刘祥超
e49db916f8 套餐增加Websocket连接数限制 2023-12-19 14:56:44 +08:00
刘祥超
16083fd0d7 增加多个台湾地区区县地址 2023-12-18 09:43:49 +08:00
刘祥超
e0e2729fef 版本号修改为1.3.1.2 2023-12-18 08:51:04 +08:00
刘祥超
9b95042936 缓存设置中可以设置缓存主域名,用来复用多域名下的缓存 2023-12-13 18:34:57 +08:00
刘祥超
44d45c53a1 增加保存管理员语言选择的API 2023-12-12 22:40:06 +08:00
刘祥超
c5fb340eb7 自动升级WAF策略中SQL注入检测和XSS注入检测 2023-12-12 17:15:21 +08:00
刘祥超
cbb61d2f0e 读取用户信息时同时返回语言设置 2023-12-12 11:49:05 +08:00
刘祥超
a143714370 WebP策略变化时只更新相关配置 2023-12-11 11:08:19 +08:00
刘祥超
0e1a98c5d8 将部分MB、GB...改成MiB、GiB... 2023-12-03 11:32:09 +08:00
刘祥超
707a9f8caf 优化代码 2023-11-29 16:58:11 +08:00
刘祥超
da391f565b 创建集群时默认生成子域名 2023-11-27 11:28:31 +08:00
刘祥超
78f396129f 阿里云线路显示完整的线路名称 2023-11-26 20:16:13 +08:00
刘祥超
e8b620aa1e 提交SQL 2023-11-24 10:24:29 +08:00
刘祥超
1019370f37 提交go.sum 2023-11-24 10:21:50 +08:00
刘祥超
cd7cff4f9c 修复一处编译错误 2023-11-24 10:20:16 +08:00
刘祥超
2888634fb0 将版本号修改为1.3.1 2023-11-23 17:24:09 +08:00
刘祥超
94defc3e0c 优化SSH认证sudo设置 2023-11-23 16:12:52 +08:00
刘祥超
9089ed2657 DNSPod改名为腾讯云DNSPod/DNSPod 支持腾讯云API密钥 2023-11-23 15:15:11 +08:00
刘祥超
b60bb5f6da 提交SQL 2023-11-19 09:11:07 +08:00
刘祥超
ff4ea41963 节点配置中增加节点IP信息 2023-11-18 12:09:47 +08:00
刘祥超
b7dccad449 实现用户系统手机号码绑定和登录(商业版) 2023-11-17 11:51:29 +08:00
刘祥超
7fead214d4 更新SQL 2023-11-15 19:10:18 +08:00
刘祥超
d9590ec605 创建反向代理时默认不自动重试50X/源站支持404内容自动重试其他源站 2023-11-15 19:05:43 +08:00
刘祥超
20b936580f 版本号修改为1.3.0 2023-11-14 14:47:32 +08:00
刘祥超
b7b43bc31f 限制访问日志中域名能写入的最大长度 2023-11-13 17:12:11 +08:00
刘祥超
6fd4f26755 自定义页面增加例外URL和限制URL设置 2023-11-13 10:46:12 +08:00
刘祥超
f15d114708 自定义页面增加“跳转URL”功能 2023-11-10 16:36:09 +08:00
刘祥超
fc24195b55 增加访问日志中域名长度 2023-11-10 09:56:17 +08:00
刘祥超
ed5de57244 去除一处多余的日志 2023-11-07 17:34:09 +08:00
刘祥超
4ce347738f 修复无法将OSS源站修改为http/https源站的问题 2023-11-04 08:28:08 +08:00
刘祥超
f6e725781c 优化节点阈值设置 2023-11-03 11:20:47 +08:00
刘祥超
55d70418cc 节点健康检查失败时增加节点名称和节点IP提示 2023-11-03 09:54:42 +08:00
刘祥超
7f5b070e36 优化商业版验证 2023-11-02 17:20:12 +08:00
刘祥超
993c7ee822 上传域名统计数据时限制域名长度不能超过64位 2023-11-02 17:19:56 +08:00
刘祥超
b5bb4e0df9 更新数据库 2023-10-30 19:04:23 +08:00
刘祥超
9f120fd0e0 访问日志存储策略增加“停止默认数据库存储”选项 2023-10-30 19:03:39 +08:00
刘祥超
77d614c9ea 实现网络数据包相关统计(商业版本) 2023-10-26 17:17:43 +08:00
刘祥超
531ec3c55d 优化域名解析文字提示 2023-10-17 15:54:08 +08:00
刘祥超
0d6c064194 将版本号修改为1.2.11 2023-10-17 13:49:39 +08:00
刘祥超
180e86c643 修复消息通知不能指定集群的Bug 2023-10-17 13:49:23 +08:00
刘祥超
86b04b2b6b 将临时的1.2.9.1升级程序版本号修改为1.2.10 2023-10-15 15:10:36 +08:00
刘祥超
7a5ec79ace 将版本号修改为1.2.10 2023-10-15 13:34:18 +08:00
刘祥超
7290ffd2cd 取消默认反向代理默认的50X重试 2023-10-15 09:40:39 +08:00
刘祥超
2f361c5bcc 优化消息任务相关代码 2023-10-15 09:39:46 +08:00
刘祥超
500d72aaf3 WAF记录IP动作中IP名单如果为空时,默认为全局黑名单 2023-10-15 09:34:20 +08:00
刘祥超
9fc391d1e8 删除不必要的代码 2023-10-14 18:15:54 +08:00
刘祥超
c86e3e2047 优化消息通知相关代码 2023-10-14 17:16:08 +08:00
刘祥超
7e72a90f53 优化消息发送相关代码/删除监控相关代码 2023-10-12 20:11:21 +08:00
刘祥超
7692fed38d 支持批量复制WAF设置 2023-10-09 19:52:51 +08:00
刘祥超
bdd7d2a181 申请证书任务列表区分管理员和用户 2023-10-09 16:18:32 +08:00
刘祥超
118c3f79e4 证书列表区分管理员和用户证书 2023-10-09 15:54:00 +08:00
刘祥超
804a33a002 访问日志列表搜索增加请求来源查询语法:referer:example.com 2023-10-08 17:52:53 +08:00
刘祥超
fe00588039 集群设置中增加“自动调节系统参数”选项 2023-10-08 15:08:28 +08:00
刘祥超
67aac200a7 修复常用网站、常用集群查询可能因为updatedAt过大导致的SQL错误 2023-09-22 16:41:44 +08:00
刘祥超
3e01ad4b68 节点配置中对父级节点进行排序,以保证查找的稳定性 2023-09-22 11:55:47 +08:00
刘祥超
b39690484e 将升级程序中的1.2.10改成1.2.9.1,方便在测试版本中也能升级 2023-09-18 17:02:54 +08:00
刘祥超
31a69ecb12 将全局设置的TCP相关设置移到“集群设置--网站设置”中 2023-09-18 16:55:45 +08:00
刘祥超
94b95beadf 将全局的通用设置--域名审核设置移到“集群设置--网站设置”中 2023-09-18 16:09:11 +08:00
刘祥超
6143f08cf2 IP名单删除任务完成后删除任务 2023-09-14 09:12:19 +08:00
刘祥超
73a5814fd6 版本号修改为1.2.9 2023-09-13 17:37:41 +08:00
刘祥超
448152d5c2 优化删除IP名单时操作 2023-09-13 17:16:00 +08:00
刘祥超
eedb3fb338 将节点版本号修改为1.2.9 2023-09-12 15:03:00 +08:00
刘祥超
06f6f68f3a 增加自动升级一处WAF规则 2023-09-12 14:59:07 +08:00
刘祥超
903e524e80 优化访客IP地址设置 2023-09-07 18:03:28 +08:00
刘祥超
fa6b4fcaee 套餐增加请求数(日/月)限制 2023-09-07 11:46:03 +08:00
刘祥超
67cc8e515f 修复一个测试用例 2023-09-06 18:19:25 +08:00
刘祥超
fa29817920 统计带宽计算增加最小样本数 2023-09-06 18:14:08 +08:00
刘祥超
794c3bc132 优化套餐升级程序 2023-09-06 18:01:41 +08:00
刘祥超
9e481d31ac 重新实现套餐相关功能 2023-09-06 16:30:47 +08:00
刘祥超
4ebc03af75 调用自定义HTTP DNS时增加action(值为GetDomains) 2023-08-28 16:28:08 +08:00
刘祥超
80e2face67 更新Agent IP库 2023-08-27 11:58:14 +08:00
刘祥超
815a5187d5 反向代理增加是否重试50X选项,默认为启用 2023-08-20 15:49:34 +08:00
刘祥超
1d7bc42fba 修复节点状态监控中磁盘空间可能为0的问题 2023-08-18 16:01:24 +08:00
刘祥超
1eb9cca793 将WAF策略中的默认省份封禁提示内容长度从255修改为65535 2023-08-14 12:54:11 +08:00
刘祥超
8766f5b1a9 修改版本号为1.2.8 2023-08-14 12:24:29 +08:00
刘祥超
823e42626d DNS任务增加失败重试 2023-08-13 15:26:59 +08:00
刘祥超
c5308cf41c 生成节点时去除停用的WAF规则集 2023-08-13 10:51:52 +08:00
刘祥超
3053157c6e 将节点的api.yaml改为api_node.yaml 2023-08-12 15:27:09 +08:00
刘祥超
d1ba141c65 优化错误处理相关代码 2023-08-11 16:13:33 +08:00
刘祥超
034ababead 静态分发增加例外URL、限制URL、排除隐藏文件等选项 2023-08-10 11:27:05 +08:00
刘祥超
f5450e37be WAF策略可以自定义默认的区域/省份封禁提示 2023-08-10 10:30:50 +08:00
刘祥超
549fca93e6 将版本号修改为1.2.7 2023-08-09 14:24:16 +08:00
刘祥超
efa0f33256 Update .golangci.yaml 2023-08-09 08:11:53 +08:00
刘祥超
977a12843c 添加golangci-lint配置 2023-08-08 18:36:24 +08:00
刘祥超
6de2834a8c 优化代码 2023-08-08 16:46:17 +08:00
刘祥超
51f91e1603 优化代码 2023-08-08 12:09:20 +08:00
刘祥超
d27b7c8fa1 允许用户调用获取缓存策略信息API 2023-08-07 19:55:57 +08:00
刘祥超
c5098c66af 缓存策略增加预热超时时间设置(默认20分钟) 2023-08-06 17:07:48 +08:00
刘祥超
c2635b0d04 修复默认WAF策略模板中分组不能默认关闭的问题 2023-08-02 17:15:26 +08:00
刘祥超
41a1a6a2e5 更新SQL 2023-08-02 17:02:39 +08:00
刘祥超
e437117e69 WAF策略增加“最多检查内容尺寸“选项 2023-08-02 16:59:38 +08:00
刘祥超
fdc8f78229 优化CC配置 2023-08-01 19:50:01 +08:00
刘祥超
2f78d76a1a 修复系统服务相关代码可能不执行的问题 2023-08-01 16:19:05 +08:00
刘祥超
742f2f0216 启动时自动创建相关软链接 2023-08-01 10:47:13 +08:00
刘祥超
89a606329f 修复自定义页面无法保存的问题 2023-07-31 09:46:00 +08:00
刘祥超
3bba79d14c 优化统计 2023-07-31 09:45:48 +08:00
刘祥超
9f9787e30f 版本号更改为1.2.6 2023-07-28 09:27:08 +08:00
刘祥超
529016d4d5 版本号更改为1.2.5 2023-07-26 15:30:37 +08:00
刘祥超
63942bfb08 将版本号修改为1.2.4 2023-07-26 10:19:02 +08:00
刘祥超
f4e4f32f9c 修复SysLocker无法写入新Key的问题 2023-07-26 10:18:52 +08:00
刘祥超
0a3c740502 版本号修改为1.2.3 2023-07-25 13:17:59 +08:00
刘祥超
9a3438e066 优化IP名单使用IP搜索查询速度 2023-07-25 12:26:12 +08:00
刘祥超
814b82e1b6 优化TOA相关代码 2023-07-24 15:33:44 +08:00
刘祥超
89cfd175cd 优化TOA相关API 2023-07-24 09:56:43 +08:00
刘祥超
860816719e 单个节点所在多个集群共用一个缓存策略时只加载其中一个 2023-07-20 16:54:34 +08:00
刘祥超
caa936f0ac 大幅提升SysLocker自增性能 2023-07-20 14:25:42 +08:00
刘祥超
97836a89eb 优化代码 2023-07-19 18:49:23 +08:00
234 changed files with 28580 additions and 4278 deletions

75
.golangci.yaml Normal file
View File

@@ -0,0 +1,75 @@
# https://golangci-lint.run/usage/configuration/
linters:
enable-all: true
disable:
- ifshort
- exhaustivestruct
- golint
- nosnakecase
- scopelint
- varcheck
- structcheck
- interfacer
- maligned
- deadcode
- dogsled
- wrapcheck
- wastedassign
- varnamelen
- testpackage
- thelper
- nilerr
- sqlclosecheck
- paralleltest
- nonamedreturns
- nlreturn
- nakedret
- ireturn
- interfacebloat
- gosmopolitan
- gomnd
- goerr113
- gochecknoglobals
- exhaustruct
- errorlint
- depguard
- exhaustive
- containedctx
- wsl
- cyclop
- dupword
- errchkjson
- contextcheck
- tagalign
- dupl
- forbidigo
- funlen
- goconst
- godox
- gosec
- lll
- nestif
- revive
- unparam
- stylecheck
- gocritic
- gofumpt
- gomoddirectives
- godot
- gofmt
- gocognit
- mirror
- gocyclo
- gochecknoinits
- gci
- maintidx
- prealloc
- goimports
- errname
- musttag
- forcetypeassert
- whitespace
- noctx
- tagliatelle
- nilnil

View File

@@ -115,7 +115,11 @@ function build() {
fi fi
# building api node # building api node
env GOOS="$OS" GOARCH="$ARCH" go build -trimpath -tags $TAG --ldflags="-s -w" -o "$DIST"/bin/edge-api "$ROOT"/../cmd/edge-api/main.go env GOOS="$OS" GOARCH="$ARCH" go build -trimpath -tags $TAG --ldflags="-s -w" -o "$DIST/bin/$NAME" "$ROOT"/../cmd/edge-api/main.go
if [ ! -f "${DIST}/bin/${NAME}" ]; then
echo "build failed!"
exit
fi
# delete hidden files # delete hidden files
find "$DIST" -name ".DS_Store" -delete find "$DIST" -name ".DS_Store" -delete

View File

@@ -12,5 +12,5 @@ dbs:
fields: fields:
bool: [ "uamIsOn", "followPort", "requestHostExcludingPort", "autoRemoteStart", "autoInstallNftables", "enableIPLists", "detectAgents", "checkingPorts", "enableRecordHealthCheck", "offlineIsNotified", "http2Enabled", "http3Enabled" ] bool: [ "uamIsOn", "followPort", "requestHostExcludingPort", "autoRemoteStart", "autoInstallNftables", "enableIPLists", "detectAgents", "checkingPorts", "enableRecordHealthCheck", "offlineIsNotified", "http2Enabled", "http3Enabled", "enableHTTP2", "retry50X", "retry40X", "autoSystemTuning", "disableDefaultDB", "autoTrimDisks" ]

View File

@@ -1,3 +1,7 @@
#!/usr/bin/env bash #!/usr/bin/env bash
go run `dirname $0`/../cmd/sql-dump/main.go -dir=`dirname $0` # generate 'internal/setup/sql.json' file
CWD="$(dirname "$0")"
go run "${CWD}"/../cmd/sql-dump/main.go -dir="${CWD}"

View File

@@ -0,0 +1,2 @@
edge-instance-installer*
prepare.sh

View File

@@ -0,0 +1,45 @@
#!/usr/bin/env bash
function build() {
ROOT=$(dirname "$0")
OS="${1}"
ARCH="${2}"
TAG="${3}"
if [ -z "$OS" ]; then
echo "usage: build.sh OS ARCH"
exit
fi
if [ -z "$ARCH" ]; then
echo "usage: build.sh OS ARCH"
exit
fi
VERSION=$(lookup_version "${ROOT}/../../internal/const/const.go")
TARGET_NAME="edge-instance-installer-${OS}-${ARCH}-v${VERSION}"
env GOOS=linux GOARCH="${ARCH}" go build -tags="${TAG}" -trimpath -ldflags="-s -w" -o "${TARGET_NAME}" main.go
if [ -f "${TARGET_NAME}" ]; then
cp "${TARGET_NAME}" "${ROOT}/../../../EdgeAdmin/docker/instance/edge-instance/assets"
fi
echo "[done]"
}
function lookup_version() {
FILE=$1
VERSION_DATA=$(cat "$FILE")
re="Version[ ]+=[ ]+\"([0-9.]+)\""
if [[ $VERSION_DATA =~ $re ]]; then
VERSION=${BASH_REMATCH[1]}
echo "$VERSION"
else
echo "could not match version"
exit
fi
}
build "$1" "$2" "$3"

View File

@@ -0,0 +1,97 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package main
import (
"fmt"
"github.com/TeaOSLab/EdgeAPI/internal/instances"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/lists"
"log"
"os"
)
func main() {
var verbose = lists.ContainsString(os.Args, "-v")
var dbHost = "127.0.0.1"
var dbPassword = "123456"
var dbName = "edges"
envDBHost, _ := os.LookupEnv("EDGE_DB_HOST")
if len(envDBHost) > 0 {
dbHost = envDBHost
if verbose {
log.Println("env EDGE_DB_HOST=" + envDBHost)
}
}
envDBPassword, _ := os.LookupEnv("EDGE_DB_PASSWORD")
if len(envDBPassword) > 0 {
dbPassword = envDBPassword
if verbose {
log.Println("env EDGE_DB_PASSWORD=" + envDBPassword)
}
}
envDBName, _ := os.LookupEnv("EDGE_DB_NAME")
if len(envDBName) > 0 {
dbName = envDBName
if verbose {
log.Println("env EDGE_DB_NAME=" + envDBName)
}
}
var isTesting = lists.ContainsString(os.Args, "-test") || lists.ContainsString(os.Args, "--test")
if isTesting {
fmt.Println("testing mode ...")
}
var instance = instances.NewInstance(instances.Options{
IsTesting: isTesting,
Verbose: verbose,
Cacheable: false,
WorkDir: "",
SrcDir: "/usr/local/goedge/src",
DB: struct {
Host string
Port int
Username string
Password string
Name string
}{
Host: dbHost,
Port: 3306,
Username: "root",
Password: dbPassword,
Name: dbName,
},
AdminNode: struct {
Port int
}{
Port: 7788,
},
APINode: struct {
HTTPPort int
RestHTTPPort int
}{
HTTPPort: 8001,
RestHTTPPort: 8002,
},
Node: struct{ HTTPPort int }{
HTTPPort: 80,
},
UserNode: struct {
HTTPPort int
}{
HTTPPort: 7799,
},
})
err := instance.SetupAll()
if err != nil {
fmt.Println("[ERROR]setup failed: " + err.Error())
return
}
fmt.Println("ok")
}

45
go.mod
View File

@@ -1,57 +1,70 @@
module github.com/TeaOSLab/EdgeAPI module github.com/TeaOSLab/EdgeAPI
go 1.18 go 1.21
replace github.com/TeaOSLab/EdgeCommon => ../EdgeCommon replace github.com/TeaOSLab/EdgeCommon => ../EdgeCommon
require ( require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.8.0
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.4.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns v1.1.0
github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000 github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000
github.com/aliyun/alibaba-cloud-sdk-go v1.61.1755 github.com/aliyun/alibaba-cloud-sdk-go v1.62.587
github.com/andybalholm/brotli v1.0.4 github.com/andybalholm/brotli v1.0.4
github.com/aws/aws-sdk-go v1.40.45
github.com/cespare/xxhash v1.1.0 github.com/cespare/xxhash v1.1.0
github.com/cespare/xxhash/v2 v2.1.1
github.com/fsnotify/fsnotify v1.6.0 github.com/fsnotify/fsnotify v1.6.0
github.com/go-acme/lego/v4 v4.10.2 github.com/go-acme/lego/v4 v4.10.2
github.com/go-sql-driver/mysql v1.7.0 github.com/go-sql-driver/mysql v1.7.0
github.com/go-telegram-bot-api/telegram-bot-api v4.6.4+incompatible github.com/go-telegram-bot-api/telegram-bot-api v4.6.4+incompatible
github.com/iwind/TeaGo v0.0.0-20230704135818-4a5646ab1f5b github.com/iwind/TeaGo v0.0.0-20240312020455-6f20b5121caf
github.com/iwind/gosock v0.0.0-20220505115348-f88412125a62 github.com/iwind/gosock v0.0.0-20220505115348-f88412125a62
github.com/miekg/dns v1.1.50 github.com/miekg/dns v1.1.50
github.com/mozillazg/go-pinyin v0.18.0 github.com/mozillazg/go-pinyin v0.18.0
github.com/pkg/sftp v1.12.0 github.com/pkg/sftp v1.12.0
github.com/shirou/gopsutil/v3 v3.22.2 github.com/shirou/gopsutil/v3 v3.22.2
github.com/smartwalle/alipay/v3 v3.1.7 github.com/smartwalle/alipay/v3 v3.2.20
golang.org/x/crypto v0.5.0 github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.801
golang.org/x/net v0.8.0 github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/dnspod v1.0.801
golang.org/x/sys v0.8.0 github.com/volcengine/volc-sdk-golang v1.0.124
google.golang.org/grpc v1.45.0 golang.org/x/crypto v0.22.0
golang.org/x/net v0.24.0
golang.org/x/sys v0.19.0
google.golang.org/grpc v1.62.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
) )
require ( require (
github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v1.1.1 // indirect
github.com/cenkalti/backoff/v4 v4.2.0 // indirect github.com/cenkalti/backoff/v4 v4.2.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-ole/go-ole v1.2.6 // indirect
github.com/golang/protobuf v1.5.2 // indirect github.com/golang-jwt/jwt/v5 v5.0.0 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/kr/fs v0.1.0 // indirect github.com/kr/fs v0.1.0 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/smartwalle/crypto4go v1.0.2 // indirect github.com/smartwalle/ncrypto v1.0.4 // indirect
github.com/tdewolff/minify/v2 v2.12.7 // indirect github.com/smartwalle/ngx v1.0.9 // indirect
github.com/tdewolff/parse/v2 v2.6.6 // indirect github.com/smartwalle/nsign v1.0.9 // indirect
github.com/technoweenie/multipartstreamer v1.0.1 // indirect github.com/technoweenie/multipartstreamer v1.0.1 // indirect
github.com/tklauser/go-sysconf v0.3.9 // indirect github.com/tklauser/go-sysconf v0.3.9 // indirect
github.com/tklauser/numcpus v0.3.0 // indirect github.com/tklauser/numcpus v0.3.0 // indirect
github.com/yusufpapurcu/wmi v1.2.2 // indirect github.com/yusufpapurcu/wmi v1.2.2 // indirect
golang.org/x/mod v0.8.0 // indirect golang.org/x/mod v0.8.0 // indirect
golang.org/x/text v0.8.0 // indirect golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.6.0 // indirect golang.org/x/tools v0.6.0 // indirect
google.golang.org/genproto v0.0.0-20220317150908-0efb43f6373e // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240228224816-df926f6c8641 // indirect
google.golang.org/protobuf v1.28.0 // indirect google.golang.org/protobuf v1.33.0 // indirect
gopkg.in/ini.v1 v1.66.6 // indirect gopkg.in/ini.v1 v1.66.6 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect
) )

768
go.sum

File diff suppressed because it is too large Load Diff

View File

@@ -130,6 +130,9 @@ func TestGenerate_EAB(t *testing.T) {
} else { } else {
reg, err = client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) reg, err = client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
} }
if err != nil {
t.Fatal(err)
}
myUser.Registration = reg myUser.Registration = reg
request := certificate.ObtainRequest{ request := certificate.ObtainRequest{

View File

@@ -1,6 +1,7 @@
package acme package acme
import ( import (
"fmt"
"github.com/TeaOSLab/EdgeAPI/internal/dnsclients" "github.com/TeaOSLab/EdgeAPI/internal/dnsclients"
"github.com/TeaOSLab/EdgeAPI/internal/dnsclients/dnstypes" "github.com/TeaOSLab/EdgeAPI/internal/dnsclients/dnstypes"
"github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeAPI/internal/errors"
@@ -45,7 +46,7 @@ func (this *DNSProvider) Present(domain, token, keyAuth string) error {
if !wasDeleted { if !wasDeleted {
records, err := this.raw.QueryRecords(this.dnsDomain, recordName, dnstypes.RecordTypeTXT) records, err := this.raw.QueryRecords(this.dnsDomain, recordName, dnstypes.RecordTypeTXT)
if err != nil { if err != nil {
return errors.New("query DNS record failed: " + err.Error()) return fmt.Errorf("query DNS record failed: %w", err)
} }
for _, record := range records { for _, record := range records {
err = this.raw.DeleteRecord(this.dnsDomain, record) err = this.raw.DeleteRecord(this.dnsDomain, record)
@@ -67,7 +68,7 @@ func (this *DNSProvider) Present(domain, token, keyAuth string) error {
Route: this.raw.DefaultRoute(), Route: this.raw.DefaultRoute(),
}) })
if err != nil { if err != nil {
return errors.New("create DNS record failed: " + err.Error()) return fmt.Errorf("create DNS record failed: %w", err)
} }
return nil return nil

View File

@@ -1,6 +1,7 @@
package acme package acme
import ( import (
"fmt"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const" teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/go-acme/lego/v4/certcrypto" "github.com/go-acme/lego/v4/certcrypto"
@@ -92,26 +93,26 @@ func (this *Request) runDNS() (certData []byte, keyData []byte, err error) {
// 注册用户 // 注册用户
var resource = this.task.User.GetRegistration() var resource = this.task.User.GetRegistration()
if resource != nil { if resource != nil {
resource, err = client.Registration.QueryRegistration() _, err = client.Registration.QueryRegistration()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
} else { } else {
if this.task.Provider.RequireEAB { if this.task.Provider.RequireEAB {
resource, err := client.Registration.RegisterWithExternalAccountBinding(registration.RegisterEABOptions{ resource, err = client.Registration.RegisterWithExternalAccountBinding(registration.RegisterEABOptions{
TermsOfServiceAgreed: true, TermsOfServiceAgreed: true,
Kid: this.task.Account.EABKid, Kid: this.task.Account.EABKid,
HmacEncoded: this.task.Account.EABKey, HmacEncoded: this.task.Account.EABKey,
}) })
if err != nil { if err != nil {
return nil, nil, errors.New("register user failed: " + err.Error()) return nil, nil, fmt.Errorf("register user failed: %w", err)
} }
err = this.task.User.Register(resource) err = this.task.User.Register(resource)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
} else { } else {
resource, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) resource, err = client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -134,7 +135,7 @@ func (this *Request) runDNS() (certData []byte, keyData []byte, err error) {
} }
certResource, err := client.Certificate.Obtain(request) certResource, err := client.Certificate.Obtain(request)
if err != nil { if err != nil {
return nil, nil, errors.New("obtain cert failed: " + err.Error()) return nil, nil, fmt.Errorf("obtain cert failed: %w", err)
} }
return certResource.Certificate, certResource.PrivateKey, nil return certResource.Certificate, certResource.PrivateKey, nil
@@ -165,26 +166,26 @@ func (this *Request) runHTTP() (certData []byte, keyData []byte, err error) {
// 注册用户 // 注册用户
var resource = this.task.User.GetRegistration() var resource = this.task.User.GetRegistration()
if resource != nil { if resource != nil {
resource, err = client.Registration.QueryRegistration() _, err = client.Registration.QueryRegistration()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
} else { } else {
if this.task.Provider.RequireEAB { if this.task.Provider.RequireEAB {
resource, err := client.Registration.RegisterWithExternalAccountBinding(registration.RegisterEABOptions{ resource, err = client.Registration.RegisterWithExternalAccountBinding(registration.RegisterEABOptions{
TermsOfServiceAgreed: true, TermsOfServiceAgreed: true,
Kid: this.task.Account.EABKid, Kid: this.task.Account.EABKid,
HmacEncoded: this.task.Account.EABKey, HmacEncoded: this.task.Account.EABKey,
}) })
if err != nil { if err != nil {
return nil, nil, errors.New("register user failed: " + err.Error()) return nil, nil, fmt.Errorf("register user failed: %w", err)
} }
err = this.task.User.Register(resource) err = this.task.User.Register(resource)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
} else { } else {
resource, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) resource, err = client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@@ -1,6 +1,7 @@
package apps package apps
import ( import (
"errors"
"fmt" "fmt"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const" teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/iwind/TeaGo/logs" "github.com/iwind/TeaGo/logs"
@@ -9,8 +10,10 @@ import (
"github.com/iwind/gosock/pkg/gosock" "github.com/iwind/gosock/pkg/gosock"
"os" "os"
"os/exec" "os/exec"
"path/filepath"
"runtime" "runtime"
"strconv" "strconv"
"strings"
"time" "time"
) )
@@ -184,13 +187,16 @@ func (this *AppCmd) runStart() {
return return
} }
cmd := exec.Command(os.Args[0]) var cmd = exec.Command(this.exe())
err := cmd.Start() err := cmd.Start()
if err != nil { if err != nil {
fmt.Println(this.product+" start failed:", err.Error()) fmt.Println(this.product+" start failed:", err.Error())
return return
} }
// create symbolic links
_ = this.createSymLinks()
fmt.Println(this.product+" started ok, pid:", cmd.Process.Pid) fmt.Println(this.product+" started ok, pid:", cmd.Process.Pid)
} }
@@ -237,3 +243,58 @@ func (this *AppCmd) getPID() int {
} }
return maps.NewMap(reply.Params).GetInt("pid") return maps.NewMap(reply.Params).GetInt("pid")
} }
func (this *AppCmd) exe() string {
var exe, _ = os.Executable()
if len(exe) == 0 {
exe = os.Args[0]
}
return exe
}
// 创建软链接
func (this *AppCmd) createSymLinks() error {
if runtime.GOOS != "linux" {
return nil
}
var exe, _ = os.Executable()
if len(exe) == 0 {
return nil
}
var errorList = []string{}
// bin
{
var target = "/usr/bin/" + teaconst.ProcessName
old, _ := filepath.EvalSymlinks(target)
if old != exe {
_ = os.Remove(target)
err := os.Symlink(exe, target)
if err != nil {
errorList = append(errorList, err.Error())
}
}
}
// log
{
var realPath = filepath.Dir(filepath.Dir(exe)) + "/logs/run.log"
var target = "/var/log/" + teaconst.ProcessName + ".log"
old, _ := filepath.EvalSymlinks(target)
if old != realPath {
_ = os.Remove(target)
err := os.Symlink(realPath, target)
if err != nil {
errorList = append(errorList, err.Error())
}
}
}
if len(errorList) > 0 {
return errors.New(strings.Join(errorList, "\n"))
}
return nil
}

View File

@@ -1,12 +1,14 @@
package teaconst package teaconst
const ( const (
Version = "1.2.2" Version = "1.3.7"
ProductName = "Edge API" ProductName = "Edge API"
ProcessName = "edge-api" ProcessName = "edge-api"
ProductNameZH = "Edge" ProductNameZH = "Edge"
GlobalProductName = "GoEdge"
Role = "api" Role = "api"
EncryptKey = "8f983f4d69b83aaa0d74b21a212f6967" EncryptKey = "8f983f4d69b83aaa0d74b21a212f6967"
@@ -18,8 +20,5 @@ const (
// 其他节点版本号,用来检测是否有需要升级的节点 // 其他节点版本号,用来检测是否有需要升级的节点
NodeVersion = "1.2.2" NodeVersion = "1.3.7"
// SQLVersion SQL版本号
SQLVersion = "11"
) )

View File

@@ -0,0 +1,9 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package teaconst
const (
// DefaultMaxNodes 节点数限制
DefaultMaxNodes int32 = 50
)

View File

@@ -2,6 +2,7 @@ package acme
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
acmeutils "github.com/TeaOSLab/EdgeAPI/internal/acme" acmeutils "github.com/TeaOSLab/EdgeAPI/internal/acme"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const" teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
@@ -106,9 +107,17 @@ func (this *ACMETaskDAO) DisableAllTasksWithCertId(tx *dbs.Tx, certId int64) err
} }
// CountAllEnabledACMETasks 计算所有任务数量 // CountAllEnabledACMETasks 计算所有任务数量
func (this *ACMETaskDAO) CountAllEnabledACMETasks(tx *dbs.Tx, userId int64, isAvailable bool, isExpired bool, expiringDays int64, keyword string) (int64, error) { func (this *ACMETaskDAO) CountAllEnabledACMETasks(tx *dbs.Tx, userId int64, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userOnly bool) (int64, error) {
var query = this.Query(tx) var query = this.Query(tx)
query.Attr("userId", userId) // 这个条件必须加上 if userId > 0 {
query.Attr("userId", userId)
} else {
if userOnly {
query.Gt("userId", 0)
} else {
query.Attr("userId", 0)
}
}
if isAvailable || isExpired || expiringDays > 0 { if isAvailable || isExpired || expiringDays > 0 {
query.Gt("certId", 0) query.Gt("certId", 0)
@@ -138,9 +147,17 @@ func (this *ACMETaskDAO) CountAllEnabledACMETasks(tx *dbs.Tx, userId int64, isAv
} }
// ListEnabledACMETasks 列出单页任务 // ListEnabledACMETasks 列出单页任务
func (this *ACMETaskDAO) ListEnabledACMETasks(tx *dbs.Tx, userId int64, isAvailable bool, isExpired bool, expiringDays int64, keyword string, offset int64, size int64) (result []*ACMETask, err error) { func (this *ACMETaskDAO) ListEnabledACMETasks(tx *dbs.Tx, userId int64, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userOnly bool, offset int64, size int64) (result []*ACMETask, err error) {
var query = this.Query(tx) var query = this.Query(tx)
query.Attr("userId", userId) // 这个条件必须加上 if userId > 0 {
query.Attr("userId", userId)
} else {
if userOnly {
query.Gt("userId", 0)
} else {
query.Attr("userId", 0)
}
}
if isAvailable || isExpired || expiringDays > 0 { if isAvailable || isExpired || expiringDays > 0 {
query.Gt("certId", 0) query.Gt("certId", 0)
@@ -228,8 +245,8 @@ func (this *ACMETaskDAO) UpdateACMETask(tx *dbs.Tx, acmeTaskId int64, acmeUserId
return err return err
} }
// CheckACMETask 检查权限 // CheckUserACMETask 检查用户权限
func (this *ACMETaskDAO) CheckACMETask(tx *dbs.Tx, userId int64, acmeTaskId int64) (bool, error) { func (this *ACMETaskDAO) CheckUserACMETask(tx *dbs.Tx, userId int64, acmeTaskId int64) (bool, error) {
var query = this.Query(tx) var query = this.Query(tx)
if userId > 0 { if userId > 0 {
query.Attr("userId", userId) query.Attr("userId", userId)
@@ -241,6 +258,15 @@ func (this *ACMETaskDAO) CheckACMETask(tx *dbs.Tx, userId int64, acmeTaskId int6
Exist() Exist()
} }
// FindACMETaskUserId 查找任务所属用户ID
func (this *ACMETaskDAO) FindACMETaskUserId(tx *dbs.Tx, taskId int64) (userId int64, err error) {
return this.Query(tx).
Pk(taskId).
Result("userId").
FindInt64Col(0)
}
// UpdateACMETaskCert 设置任务关联的证书 // UpdateACMETaskCert 设置任务关联的证书
func (this *ACMETaskDAO) UpdateACMETaskCert(tx *dbs.Tx, taskId int64, certId int64) error { func (this *ACMETaskDAO) UpdateACMETaskCert(tx *dbs.Tx, taskId int64, certId int64) error {
if taskId <= 0 { if taskId <= 0 {
@@ -434,7 +460,7 @@ func (this *ACMETaskDAO) runTaskWithoutLog(tx *dbs.Tx, taskId int64) (isOk bool,
CertData: certData, CertData: certData,
KeyData: keyData, KeyData: keyData,
} }
err = sslConfig.Init(nil) err = sslConfig.Init(context.Background())
if err != nil { if err != nil {
errMsg = "证书生成成功,但是分析证书信息时发生错误:" + err.Error() errMsg = "证书生成成功,但是分析证书信息时发生错误:" + err.Error()
return return

View File

@@ -1,6 +1,7 @@
package acme package acme
import ( import (
"github.com/TeaOSLab/EdgeAPI/internal/utils"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
@@ -31,7 +32,7 @@ func init() {
func (this *ACMETaskLogDAO) CreateACMETaskLog(tx *dbs.Tx, taskId int64, isOk bool, errMsg string) error { func (this *ACMETaskLogDAO) CreateACMETaskLog(tx *dbs.Tx, taskId int64, isOk bool, errMsg string) error {
var op = NewACMETaskLogOperator() var op = NewACMETaskLogOperator()
op.TaskId = taskId op.TaskId = taskId
op.Error = errMsg op.Error = utils.LimitString(errMsg, 1024)
op.IsOk = isOk op.IsOk = isOk
err := this.Save(tx, op) err := this.Save(tx, op)
return err return err

View File

@@ -130,6 +130,19 @@ func (this *AdminDAO) FindAdminIdWithUsername(tx *dbs.Tx, username string) (int6
return int64(one.(*Admin).Id), nil return int64(one.(*Admin).Id), nil
} }
// FindAdminWithUsername 根据用户名查询管理员信息
func (this *AdminDAO) FindAdminWithUsername(tx *dbs.Tx, username string) (*Admin, error) {
one, err := this.Query(tx).
Attr("username", username).
State(AdminStateEnabled).
ResultPk().
Find()
if err != nil || one == nil {
return nil, err
}
return one.(*Admin), nil
}
// UpdateAdminPassword 更改管理员密码 // UpdateAdminPassword 更改管理员密码
func (this *AdminDAO) UpdateAdminPassword(tx *dbs.Tx, adminId int64, password string) error { func (this *AdminDAO) UpdateAdminPassword(tx *dbs.Tx, adminId int64, password string) error {
if adminId <= 0 { if adminId <= 0 {
@@ -212,7 +225,7 @@ func (this *AdminDAO) UpdateAdmin(tx *dbs.Tx, adminId int64, username string, ca
return nil return nil
} }
// CheckAdminUsername 检查用户名是否存在 // CheckAdminUsername 检查管理员用户名是否存在
func (this *AdminDAO) CheckAdminUsername(tx *dbs.Tx, adminId int64, username string) (bool, error) { func (this *AdminDAO) CheckAdminUsername(tx *dbs.Tx, adminId int64, username string) (bool, error) {
query := this.Query(tx). query := this.Query(tx).
State(AdminStateEnabled). State(AdminStateEnabled).
@@ -260,7 +273,7 @@ func (this *AdminDAO) FindAllAdminModules(tx *dbs.Tx) (result []*Admin, err erro
_, err = this.Query(tx). _, err = this.Query(tx).
State(AdminStateEnabled). State(AdminStateEnabled).
Attr("isOn", true). Attr("isOn", true).
Result("id", "modules", "isSuper", "fullname", "theme"). Result("id", "modules", "isSuper", "fullname", "theme", "lang").
Slice(&result). Slice(&result).
FindAll() FindAll()
return return
@@ -313,6 +326,14 @@ func (this *AdminDAO) UpdateAdminTheme(tx *dbs.Tx, adminId int64, theme string)
UpdateQuickly() UpdateQuickly()
} }
// UpdateAdminLang 设置管理员语言
func (this *AdminDAO) UpdateAdminLang(tx *dbs.Tx, adminId int64, langCode string) error {
return this.Query(tx).
Pk(adminId).
Set("lang", langCode).
UpdateQuickly()
}
// CheckSuperAdmin 检查管理员是否为超级管理员 // CheckSuperAdmin 检查管理员是否为超级管理员
func (this *AdminDAO) CheckSuperAdmin(tx *dbs.Tx, adminId int64) (bool, error) { func (this *AdminDAO) CheckSuperAdmin(tx *dbs.Tx, adminId int64) (bool, error) {
if adminId <= 0 { if adminId <= 0 {

View File

@@ -1,6 +1,7 @@
package models package models
import ( import (
"context"
"encoding/json" "encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
@@ -37,7 +38,7 @@ func (this *APINode) DecodeHTTPS(tx *dbs.Tx, cacheMap *utils.CacheMap) (*serverc
return nil, err return nil, err
} }
err = config.Init(nil) err = config.Init(context.TODO())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -55,7 +56,7 @@ func (this *APINode) DecodeHTTPS(tx *dbs.Tx, cacheMap *utils.CacheMap) (*serverc
} }
} }
err = config.Init(nil) err = config.Init(context.TODO())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -135,7 +136,7 @@ func (this *APINode) DecodeRestHTTPS(tx *dbs.Tx, cacheMap *utils.CacheMap) (*ser
return nil, err return nil, err
} }
err = config.Init(nil) err = config.Init(context.TODO())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -153,7 +154,7 @@ func (this *APINode) DecodeRestHTTPS(tx *dbs.Tx, cacheMap *utils.CacheMap) (*ser
} }
} }
err = config.Init(nil) err = config.Init(context.TODO())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -0,0 +1,6 @@
package authority_test
import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/iwind/TeaGo/bootstrap"
)

View File

@@ -2,6 +2,18 @@ package authority
import "github.com/iwind/TeaGo/dbs" import "github.com/iwind/TeaGo/dbs"
const (
AuthorityKeyField_Id dbs.FieldName = "id" // ID
AuthorityKeyField_Value dbs.FieldName = "value" // Key值
AuthorityKeyField_DayFrom dbs.FieldName = "dayFrom" // 开始日期
AuthorityKeyField_DayTo dbs.FieldName = "dayTo" // 结束日期
AuthorityKeyField_Hostname dbs.FieldName = "hostname" // Hostname
AuthorityKeyField_MacAddresses dbs.FieldName = "macAddresses" // MAC地址
AuthorityKeyField_UpdatedAt dbs.FieldName = "updatedAt" // 创建/修改时间
AuthorityKeyField_Company dbs.FieldName = "company" // 公司组织
AuthorityKeyField_RequestCode dbs.FieldName = "requestCode" // 申请码
)
// AuthorityKey 企业版认证信息 // AuthorityKey 企业版认证信息
type AuthorityKey struct { type AuthorityKey struct {
Id uint32 `field:"id"` // ID Id uint32 `field:"id"` // ID
@@ -12,17 +24,19 @@ type AuthorityKey struct {
MacAddresses dbs.JSON `field:"macAddresses"` // MAC地址 MacAddresses dbs.JSON `field:"macAddresses"` // MAC地址
UpdatedAt uint64 `field:"updatedAt"` // 创建/修改时间 UpdatedAt uint64 `field:"updatedAt"` // 创建/修改时间
Company string `field:"company"` // 公司组织 Company string `field:"company"` // 公司组织
RequestCode string `field:"requestCode"` // 申请码
} }
type AuthorityKeyOperator struct { type AuthorityKeyOperator struct {
Id interface{} // ID Id any // ID
Value interface{} // Key值 Value any // Key值
DayFrom interface{} // 开始日期 DayFrom any // 开始日期
DayTo interface{} // 结束日期 DayTo any // 结束日期
Hostname interface{} // Hostname Hostname any // Hostname
MacAddresses interface{} // MAC地址 MacAddresses any // MAC地址
UpdatedAt interface{} // 创建/修改时间 UpdatedAt any // 创建/修改时间
Company interface{} // 公司组织 Company any // 公司组织
RequestCode any // 申请码
} }
func NewAuthorityKeyOperator() *AuthorityKeyOperator { func NewAuthorityKeyOperator() *AuthorityKeyOperator {

View File

@@ -61,11 +61,12 @@ func (this *DNSTaskDAO) CreateDNSTask(tx *dbs.Tx, clusterId int64, serverId int6
"error": "", "error": "",
"version": time.Now().UnixNano(), "version": time.Now().UnixNano(),
}, maps.Map{ }, maps.Map{
"updatedAt": time.Now().Unix(), "updatedAt": time.Now().Unix(),
"isDone": false, "isDone": false,
"isOk": false, "isOk": false,
"error": "", "error": "",
"version": time.Now().UnixNano(), "version": time.Now().UnixNano(),
"countFails": 0,
}) })
if err != nil { if err != nil {
return err return err
@@ -108,7 +109,7 @@ func (this *DNSTaskDAO) CreateDomainTask(tx *dbs.Tx, domainId int64, taskType DN
// FindAllDoingTasks 查找所有正在执行的任务 // FindAllDoingTasks 查找所有正在执行的任务
func (this *DNSTaskDAO) FindAllDoingTasks(tx *dbs.Tx) (result []*DNSTask, err error) { func (this *DNSTaskDAO) FindAllDoingTasks(tx *dbs.Tx) (result []*DNSTask, err error) {
_, err = this.Query(tx). _, err = this.Query(tx).
Attr("isDone", 0). Where("(isDone=0 OR (isDone=1 AND isOk=0 AND countFails<3))"). // 3 = retry times
Asc("version"). Asc("version").
AscPk(). AscPk().
Slice(&result). Slice(&result).
@@ -171,6 +172,7 @@ func (this *DNSTaskDAO) UpdateDNSTaskError(tx *dbs.Tx, taskId int64, err string)
op.IsDone = true op.IsDone = true
op.Error = err op.Error = err
op.IsOk = false op.IsOk = false
op.CountFails = dbs.SQL("countFails+1")
return this.Save(tx, op) return this.Save(tx, op)
} }
@@ -197,6 +199,7 @@ func (this *DNSTaskDAO) UpdateDNSTaskDone(tx *dbs.Tx, taskId int64, taskVersion
op.Id = taskId op.Id = taskId
op.IsDone = true op.IsDone = true
op.IsOk = true op.IsOk = true
op.CountFails = 0
op.Error = "" op.Error = ""
return this.Save(tx, op) return this.Save(tx, op)
} }
@@ -219,6 +222,7 @@ func (this *DNSTaskDAO) UpdateClusterDNSTasksDone(tx *dbs.Tx, clusterId int64, m
Set("isDone", true). Set("isDone", true).
Set("isOk", true). Set("isOk", true).
Set("error", ""). Set("error", "").
Set("countFails", 0).
UpdateQuickly() UpdateQuickly()
} }

View File

@@ -1,5 +1,23 @@
package dns package dns
import "github.com/iwind/TeaGo/dbs"
const (
DNSTaskField_Id dbs.FieldName = "id" // ID
DNSTaskField_ClusterId dbs.FieldName = "clusterId" // 集群ID
DNSTaskField_ServerId dbs.FieldName = "serverId" // 服务ID
DNSTaskField_NodeId dbs.FieldName = "nodeId" // 节点ID
DNSTaskField_DomainId dbs.FieldName = "domainId" // 域名ID
DNSTaskField_RecordName dbs.FieldName = "recordName" // 记录名
DNSTaskField_Type dbs.FieldName = "type" // 任务类型
DNSTaskField_UpdatedAt dbs.FieldName = "updatedAt" // 更新时间
DNSTaskField_IsDone dbs.FieldName = "isDone" // 是否已完成
DNSTaskField_IsOk dbs.FieldName = "isOk" // 是否成功
DNSTaskField_Error dbs.FieldName = "error" // 错误信息
DNSTaskField_Version dbs.FieldName = "version" // 版本
DNSTaskField_CountFails dbs.FieldName = "countFails" // 尝试失败次数
)
// DNSTask DNS更新任务 // DNSTask DNS更新任务
type DNSTask struct { type DNSTask struct {
Id uint64 `field:"id"` // ID Id uint64 `field:"id"` // ID
@@ -14,6 +32,7 @@ type DNSTask struct {
IsOk bool `field:"isOk"` // 是否成功 IsOk bool `field:"isOk"` // 是否成功
Error string `field:"error"` // 错误信息 Error string `field:"error"` // 错误信息
Version uint64 `field:"version"` // 版本 Version uint64 `field:"version"` // 版本
CountFails uint32 `field:"countFails"` // 尝试失败次数
} }
type DNSTaskOperator struct { type DNSTaskOperator struct {
@@ -29,6 +48,7 @@ type DNSTaskOperator struct {
IsOk any // 是否成功 IsOk any // 是否成功
Error any // 错误信息 Error any // 错误信息
Version any // 版本 Version any // 版本
CountFails any // 尝试失败次数
} }
func NewDNSTaskOperator() *DNSTaskOperator { func NewDNSTaskOperator() *DNSTaskOperator {

View File

@@ -3,6 +3,7 @@
package dnsutils package dnsutils
import ( import (
"fmt"
"github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns" "github.com/TeaOSLab/EdgeAPI/internal/db/models/dns"
"github.com/TeaOSLab/EdgeAPI/internal/dnsclients" "github.com/TeaOSLab/EdgeAPI/internal/dnsclients"
@@ -217,7 +218,7 @@ func FindDefaultDomainRoute(tx *dbs.Tx, domain *dns.DNSDomain) (string, error) {
} }
paramsMap, err := provider.DecodeAPIParams() paramsMap, err := provider.DecodeAPIParams()
if err != nil { if err != nil {
return "", errors.New("decode provider params failed: " + err.Error()) return "", fmt.Errorf("decode provider params failed: %w", err)
} }
var dnsProvider = dnsclients.FindProvider(provider.Type, int64(provider.Id)) var dnsProvider = dnsclients.FindProvider(provider.Type, int64(provider.Id))
if dnsProvider == nil { if dnsProvider == nil {

View File

@@ -7,8 +7,8 @@ import (
"github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/goman" "github.com/TeaOSLab/EdgeAPI/internal/goman"
"github.com/TeaOSLab/EdgeAPI/internal/remotelogs" "github.com/TeaOSLab/EdgeAPI/internal/remotelogs"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeAPI/internal/zero" "github.com/TeaOSLab/EdgeAPI/internal/zero"
"github.com/TeaOSLab/EdgeCommon/pkg/iputils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
@@ -232,7 +232,7 @@ Loop:
// CreateHTTPAccessLog 写入单条访问日志 // CreateHTTPAccessLog 写入单条访问日志
func (this *HTTPAccessLogDAO) CreateHTTPAccessLog(tx *dbs.Tx, dao *HTTPAccessLogDAO, accessLog *pb.HTTPAccessLog) error { func (this *HTTPAccessLogDAO) CreateHTTPAccessLog(tx *dbs.Tx, dao *HTTPAccessLogDAO, accessLog *pb.HTTPAccessLog) error {
var day = "" var day string
// 注意:如果你修改了 TimeISO8601 的逻辑,这里也需要同步修改 // 注意:如果你修改了 TimeISO8601 的逻辑,这里也需要同步修改
if len(accessLog.TimeISO8601) > 10 { if len(accessLog.TimeISO8601) > 10 {
day = strings.ReplaceAll(accessLog.TimeISO8601[:10], "-", "") day = strings.ReplaceAll(accessLog.TimeISO8601[:10], "-", "")
@@ -245,7 +245,7 @@ func (this *HTTPAccessLogDAO) CreateHTTPAccessLog(tx *dbs.Tx, dao *HTTPAccessLog
return err return err
} }
fields := map[string]interface{}{} var fields = map[string]any{}
fields["serverId"] = accessLog.ServerId fields["serverId"] = accessLog.ServerId
fields["nodeId"] = accessLog.NodeId fields["nodeId"] = accessLog.NodeId
fields["status"] = accessLog.Status fields["status"] = accessLog.Status
@@ -265,7 +265,11 @@ func (this *HTTPAccessLogDAO) CreateHTTPAccessLog(tx *dbs.Tx, dao *HTTPAccessLog
fields["remoteAddr"] = accessLog.RemoteAddr fields["remoteAddr"] = accessLog.RemoteAddr
} }
if tableDef.HasDomain { if tableDef.HasDomain {
fields["domain"] = accessLog.Host if len(accessLog.Host) > 128 {
fields["domain"] = accessLog.Host[:128]
} else {
fields["domain"] = accessLog.Host
}
} }
content, err := json.Marshal(accessLog) content, err := json.Marshal(accessLog)
@@ -461,6 +465,7 @@ func (this *HTTPAccessLogDAO) listAccessLogs(tx *dbs.Tx,
var protoReg = regexp.MustCompile(`proto:(\S+)`) var protoReg = regexp.MustCompile(`proto:(\S+)`)
var schemeReg = regexp.MustCompile(`scheme:(\S+)`) var schemeReg = regexp.MustCompile(`scheme:(\S+)`)
var methodReg = regexp.MustCompile(`(?:method|requestMethod):(\S+)`) var methodReg = regexp.MustCompile(`(?:method|requestMethod):(\S+)`)
var refererReg = regexp.MustCompile(`referer:(\S+)`)
var count = len(tableQueries) var count = len(tableQueries)
var wg = &sync.WaitGroup{} var wg = &sync.WaitGroup{}
@@ -515,14 +520,14 @@ func (this *HTTPAccessLogDAO) listAccessLogs(tx *dbs.Tx,
// keyword // keyword
if len(ip) > 0 { if len(ip) > 0 {
// TODO 支持IP范围 // TODO 支持IPv6范围
if tableQuery.hasRemoteAddrField { if tableQuery.hasRemoteAddrField {
// IP格式 // IP格式
if strings.Contains(ip, ",") || strings.Contains(ip, "-") { if strings.Contains(ip, ",") || strings.Contains(ip, "-") {
rangeConfig, err := shared.ParseIPRange(ip) rangeConfig, err := shared.ParseIPRange(ip)
if err == nil { if err == nil {
if len(rangeConfig.IPFrom) > 0 && len(rangeConfig.IPTo) > 0 { if len(rangeConfig.IPFrom) > 0 && len(rangeConfig.IPTo) > 0 {
query.Between("INET_ATON(remoteAddr)", utils.IP2Long(rangeConfig.IPFrom), utils.IP2Long(rangeConfig.IPTo)) query.Between("INET_ATON(remoteAddr)", iputils.ToLong(rangeConfig.IPFrom), iputils.ToLong(rangeConfig.IPTo))
} }
} }
} else { } else {
@@ -575,7 +580,7 @@ func (this *HTTPAccessLogDAO) listAccessLogs(tx *dbs.Tx,
if len(pieces) == 1 || len(pieces[1]) == 0 || pieces[0] == pieces[1] { if len(pieces) == 1 || len(pieces[1]) == 0 || pieces[0] == pieces[1] {
query.Attr("remoteAddr", pieces[0]) query.Attr("remoteAddr", pieces[0])
} else { } else {
query.Between("INET_ATON(remoteAddr)", utils.IP2Long(pieces[0]), utils.IP2Long(pieces[1])) query.Between("INET_ATON(remoteAddr)", iputils.ToLong(pieces[0]), iputils.ToLong(pieces[1]))
} }
} else if statusRangeReg.MatchString(keyword) { // status:200-400 } else if statusRangeReg.MatchString(keyword) { // status:200-400
isSpecialKeyword = true isSpecialKeyword = true
@@ -613,6 +618,11 @@ func (this *HTTPAccessLogDAO) listAccessLogs(tx *dbs.Tx,
var matches = methodReg.FindStringSubmatch(keyword) var matches = methodReg.FindStringSubmatch(keyword)
query.Where("JSON_EXTRACT(content, '$.requestMethod')=:keyword"). query.Where("JSON_EXTRACT(content, '$.requestMethod')=:keyword").
Param("keyword", strings.ToUpper(matches[1])) Param("keyword", strings.ToUpper(matches[1]))
} else if refererReg.MatchString(keyword) {
isSpecialKeyword = true
var matches = refererReg.FindStringSubmatch(keyword)
query.Where("JSON_EXTRACT(content, '$.referer') LIKE :keyword").
Param("keyword", dbutils.QuoteLike(matches[1]))
} }
if !isSpecialKeyword { if !isSpecialKeyword {
if regexp.MustCompile(`^ip:.+`).MatchString(keyword) { if regexp.MustCompile(`^ip:.+`).MatchString(keyword) {
@@ -857,8 +867,4 @@ func (this *HTTPAccessLogDAO) SetupQueue() {
oldAccessLogQueue = accessLogQueue oldAccessLogQueue = accessLogQueue
accessLogQueue = make(chan *pb.HTTPAccessLog, config.MaxLength) accessLogQueue = make(chan *pb.HTTPAccessLog, config.MaxLength)
} }
if Tea.IsTesting() {
remotelogs.Println("HTTP_ACCESS_LOG_QUEUE", "change queue "+string(configJSON))
}
} }

View File

@@ -41,7 +41,7 @@ func (this *HTTPAccessLogManager) FindTableNames(db *dbs.DB, day string) ([]stri
for _, prefix := range []string{"edgeHTTPAccessLogs_" + day + "%", "edgehttpaccesslogs_" + day + "%"} { for _, prefix := range []string{"edgeHTTPAccessLogs_" + day + "%", "edgehttpaccesslogs_" + day + "%"} {
ones, columnNames, err := db.FindPreparedOnes(`SHOW TABLES LIKE '` + prefix + `'`) ones, columnNames, err := db.FindPreparedOnes(`SHOW TABLES LIKE '` + prefix + `'`)
if err != nil { if err != nil {
return nil, errors.New("query table names error: " + err.Error()) return nil, fmt.Errorf("query table names error: %w", err)
} }
var columnName = columnNames[0] var columnName = columnNames[0]
@@ -88,7 +88,7 @@ func (this *HTTPAccessLogManager) FindTables(db *dbs.DB, day string) ([]*httpAcc
for _, prefix := range []string{"edgeHTTPAccessLogs_" + day + "%", "edgehttpaccesslogs_" + day + "%"} { for _, prefix := range []string{"edgeHTTPAccessLogs_" + day + "%", "edgehttpaccesslogs_" + day + "%"} {
ones, columnNames, err := db.FindPreparedOnes(`SHOW TABLES LIKE '` + prefix + `'`) ones, columnNames, err := db.FindPreparedOnes(`SHOW TABLES LIKE '` + prefix + `'`)
if err != nil { if err != nil {
return nil, errors.New("query table names error: " + err.Error()) return nil, fmt.Errorf("query table names error: %w", err)
} }
var columnName = columnNames[0] var columnName = columnNames[0]
@@ -239,7 +239,7 @@ func (this *HTTPAccessLogManager) FindLastTable(db *dbs.DB, day string, force bo
// CreateTable 创建访问日志表格 // CreateTable 创建访问日志表格
func (this *HTTPAccessLogManager) CreateTable(db *dbs.DB, tableName string) error { func (this *HTTPAccessLogManager) CreateTable(db *dbs.DB, tableName string) error {
_, err := db.Exec("CREATE TABLE `" + tableName + "` (\n `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT 'ID',\n `serverId` int(11) unsigned DEFAULT '0' COMMENT '服务ID',\n `nodeId` int(11) unsigned DEFAULT '0' COMMENT '节点ID',\n `status` int(3) unsigned DEFAULT '0' COMMENT '状态码',\n `createdAt` bigint(11) unsigned DEFAULT '0' COMMENT '创建时间',\n `content` json DEFAULT NULL COMMENT '日志内容',\n `requestId` varchar(128) DEFAULT NULL COMMENT '请求ID',\n `firewallPolicyId` int(11) unsigned DEFAULT '0' COMMENT 'WAF策略ID',\n `firewallRuleGroupId` int(11) unsigned DEFAULT '0' COMMENT 'WAF分组ID',\n `firewallRuleSetId` int(11) unsigned DEFAULT '0' COMMENT 'WAF集ID',\n `firewallRuleId` int(11) unsigned DEFAULT '0' COMMENT 'WAF规则ID',\n `remoteAddr` varchar(64) DEFAULT NULL COMMENT 'IP地址',\n `domain` varchar(128) DEFAULT NULL COMMENT '域名',\n `requestBody` mediumblob COMMENT '请求内容',\n `responseBody` mediumblob COMMENT '响应内容',\n PRIMARY KEY (`id`),\n KEY `serverId` (`serverId`),\n KEY `nodeId` (`nodeId`),\n KEY `serverId_status` (`serverId`,`status`),\n KEY `requestId` (`requestId`),\n KEY `firewallPolicyId` (`firewallPolicyId`),\n KEY `firewallRuleGroupId` (`firewallRuleGroupId`),\n KEY `firewallRuleSetId` (`firewallRuleSetId`),\n KEY `firewallRuleId` (`firewallRuleId`),\n KEY `remoteAddr` (`remoteAddr`),\n KEY `domain` (`domain`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='访问日志';") _, err := db.Exec("CREATE TABLE `" + tableName + "` (\n `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT 'ID',\n `serverId` int(11) unsigned DEFAULT '0' COMMENT '服务ID',\n `nodeId` int(11) unsigned DEFAULT '0' COMMENT '节点ID',\n `status` int(3) unsigned DEFAULT '0' COMMENT '状态码',\n `createdAt` bigint(11) unsigned DEFAULT '0' COMMENT '创建时间',\n `content` json DEFAULT NULL COMMENT '日志内容',\n `requestId` varchar(128) DEFAULT NULL COMMENT '请求ID',\n `firewallPolicyId` int(11) unsigned DEFAULT '0' COMMENT 'WAF策略ID',\n `firewallRuleGroupId` int(11) unsigned DEFAULT '0' COMMENT 'WAF分组ID',\n `firewallRuleSetId` int(11) unsigned DEFAULT '0' COMMENT 'WAF集ID',\n `firewallRuleId` int(11) unsigned DEFAULT '0' COMMENT 'WAF规则ID',\n `remoteAddr` varchar(64) DEFAULT NULL COMMENT 'IP地址',\n `domain` varchar(255) DEFAULT NULL COMMENT '域名',\n `requestBody` mediumblob COMMENT '请求内容',\n `responseBody` mediumblob COMMENT '响应内容',\n PRIMARY KEY (`id`),\n KEY `serverId` (`serverId`),\n KEY `nodeId` (`nodeId`),\n KEY `serverId_status` (`serverId`,`status`),\n KEY `requestId` (`requestId`),\n KEY `firewallPolicyId` (`firewallPolicyId`),\n KEY `firewallRuleGroupId` (`firewallRuleGroupId`),\n KEY `firewallRuleSetId` (`firewallRuleSetId`),\n KEY `firewallRuleId` (`firewallRuleId`),\n KEY `remoteAddr` (`remoteAddr`),\n KEY `domain` (`domain`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='访问日志';")
if err != nil { if err != nil {
if CheckSQLErrCode(err, 1050) { // Error 1050: Table 'xxx' already exists if CheckSQLErrCode(err, 1050) { // Error 1050: Table 'xxx' already exists
return nil return nil
@@ -373,7 +373,7 @@ func (this *HTTPAccessLogManager) findTableWithoutCache(db *dbs.DB, day string,
var lastInt64Id = types.Int64(lastId) var lastInt64Id = types.Int64(lastId)
if accessLogRowsPerTable > 0 && lastInt64Id >= accessLogRowsPerTable { if accessLogRowsPerTable > 0 && lastInt64Id >= accessLogRowsPerTable {
// create next partial table // create next partial table
var nextTableName = "" var nextTableName string
if accessLogTableMainReg.MatchString(lastTableName) { if accessLogTableMainReg.MatchString(lastTableName) {
nextTableName = prefix + "_0001" nextTableName = prefix + "_0001"
} else if accessLogTablePartialReg.MatchString(lastTableName) { } else if accessLogTablePartialReg.MatchString(lastTableName) {

View File

@@ -107,7 +107,7 @@ func (this *HTTPAccessLogPolicyDAO) FindAllEnabledAndOnPolicies(tx *dbs.Tx) (res
} }
// CreatePolicy 创建策略 // CreatePolicy 创建策略
func (this *HTTPAccessLogPolicyDAO) CreatePolicy(tx *dbs.Tx, name string, policyType string, optionsJSON []byte, condsJSON []byte, isPublic bool, firewallOnly bool) (policyId int64, err error) { func (this *HTTPAccessLogPolicyDAO) CreatePolicy(tx *dbs.Tx, name string, policyType string, optionsJSON []byte, condsJSON []byte, isPublic bool, firewallOnly bool, disableDefaultDB bool) (policyId int64, err error) {
var op = NewHTTPAccessLogPolicyOperator() var op = NewHTTPAccessLogPolicyOperator()
op.Name = name op.Name = name
op.Type = policyType op.Type = policyType
@@ -120,12 +120,13 @@ func (this *HTTPAccessLogPolicyDAO) CreatePolicy(tx *dbs.Tx, name string, policy
op.IsPublic = isPublic op.IsPublic = isPublic
op.IsOn = true op.IsOn = true
op.FirewallOnly = firewallOnly op.FirewallOnly = firewallOnly
op.DisableDefaultDB = disableDefaultDB
op.State = HTTPAccessLogPolicyStateEnabled op.State = HTTPAccessLogPolicyStateEnabled
return this.SaveInt64(tx, op) return this.SaveInt64(tx, op)
} }
// UpdatePolicy 修改策略 // UpdatePolicy 修改策略
func (this *HTTPAccessLogPolicyDAO) UpdatePolicy(tx *dbs.Tx, policyId int64, name string, optionsJSON []byte, condsJSON []byte, isPublic bool, firewallOnly bool, isOn bool) error { func (this *HTTPAccessLogPolicyDAO) UpdatePolicy(tx *dbs.Tx, policyId int64, name string, optionsJSON []byte, condsJSON []byte, isPublic bool, firewallOnly bool, disableDefaultDB bool, isOn bool) error {
if policyId <= 0 { if policyId <= 0 {
return errors.New("invalid policyId") return errors.New("invalid policyId")
} }
@@ -159,6 +160,7 @@ func (this *HTTPAccessLogPolicyDAO) UpdatePolicy(tx *dbs.Tx, policyId int64, nam
op.IsPublic = isPublic op.IsPublic = isPublic
op.FirewallOnly = firewallOnly op.FirewallOnly = firewallOnly
op.DisableDefaultDB = disableDefaultDB
op.IsOn = isOn op.IsOn = isOn
return this.Save(tx, op) return this.Save(tx, op)
} }

View File

@@ -2,39 +2,59 @@ package models
import "github.com/iwind/TeaGo/dbs" import "github.com/iwind/TeaGo/dbs"
const (
HTTPAccessLogPolicyField_Id dbs.FieldName = "id" // ID
HTTPAccessLogPolicyField_TemplateId dbs.FieldName = "templateId" // 模版ID
HTTPAccessLogPolicyField_AdminId dbs.FieldName = "adminId" // 管理员ID
HTTPAccessLogPolicyField_UserId dbs.FieldName = "userId" // 用户ID
HTTPAccessLogPolicyField_State dbs.FieldName = "state" // 状态
HTTPAccessLogPolicyField_CreatedAt dbs.FieldName = "createdAt" // 创建时间
HTTPAccessLogPolicyField_Name dbs.FieldName = "name" // 名称
HTTPAccessLogPolicyField_IsOn dbs.FieldName = "isOn" // 是否启用
HTTPAccessLogPolicyField_Type dbs.FieldName = "type" // 存储类型
HTTPAccessLogPolicyField_Options dbs.FieldName = "options" // 存储选项
HTTPAccessLogPolicyField_Conds dbs.FieldName = "conds" // 请求条件
HTTPAccessLogPolicyField_IsPublic dbs.FieldName = "isPublic" // 是否为公用
HTTPAccessLogPolicyField_FirewallOnly dbs.FieldName = "firewallOnly" // 是否只记录防火墙相关
HTTPAccessLogPolicyField_Version dbs.FieldName = "version" // 版本号
HTTPAccessLogPolicyField_DisableDefaultDB dbs.FieldName = "disableDefaultDB" // 是否停止默认数据库存储
)
// HTTPAccessLogPolicy 访问日志策略 // HTTPAccessLogPolicy 访问日志策略
type HTTPAccessLogPolicy struct { type HTTPAccessLogPolicy struct {
Id uint32 `field:"id"` // ID Id uint32 `field:"id"` // ID
TemplateId uint32 `field:"templateId"` // 模版ID TemplateId uint32 `field:"templateId"` // 模版ID
AdminId uint32 `field:"adminId"` // 管理员ID AdminId uint32 `field:"adminId"` // 管理员ID
UserId uint32 `field:"userId"` // 用户ID UserId uint32 `field:"userId"` // 用户ID
State uint8 `field:"state"` // 状态 State uint8 `field:"state"` // 状态
CreatedAt uint64 `field:"createdAt"` // 创建时间 CreatedAt uint64 `field:"createdAt"` // 创建时间
Name string `field:"name"` // 名称 Name string `field:"name"` // 名称
IsOn bool `field:"isOn"` // 是否启用 IsOn bool `field:"isOn"` // 是否启用
Type string `field:"type"` // 存储类型 Type string `field:"type"` // 存储类型
Options dbs.JSON `field:"options"` // 存储选项 Options dbs.JSON `field:"options"` // 存储选项
Conds dbs.JSON `field:"conds"` // 请求条件 Conds dbs.JSON `field:"conds"` // 请求条件
IsPublic bool `field:"isPublic"` // 是否为公用 IsPublic bool `field:"isPublic"` // 是否为公用
FirewallOnly uint8 `field:"firewallOnly"` // 是否只记录防火墙相关 FirewallOnly uint8 `field:"firewallOnly"` // 是否只记录防火墙相关
Version uint32 `field:"version"` // 版本号 Version uint32 `field:"version"` // 版本号
DisableDefaultDB bool `field:"disableDefaultDB"` // 是否停止默认数据库存储
} }
type HTTPAccessLogPolicyOperator struct { type HTTPAccessLogPolicyOperator struct {
Id interface{} // ID Id any // ID
TemplateId interface{} // 模版ID TemplateId any // 模版ID
AdminId interface{} // 管理员ID AdminId any // 管理员ID
UserId interface{} // 用户ID UserId any // 用户ID
State interface{} // 状态 State any // 状态
CreatedAt interface{} // 创建时间 CreatedAt any // 创建时间
Name interface{} // 名称 Name any // 名称
IsOn interface{} // 是否启用 IsOn any // 是否启用
Type interface{} // 存储类型 Type any // 存储类型
Options interface{} // 存储选项 Options any // 存储选项
Conds interface{} // 请求条件 Conds any // 请求条件
IsPublic interface{} // 是否为公用 IsPublic any // 是否为公用
FirewallOnly interface{} // 是否只记录防火墙相关 FirewallOnly any // 是否只记录防火墙相关
Version interface{} // 版本号 Version any // 版本号
DisableDefaultDB any // 是否停止默认数据库存储
} }
func NewHTTPAccessLogPolicyOperator() *HTTPAccessLogPolicyOperator { func NewHTTPAccessLogPolicyOperator() *HTTPAccessLogPolicyOperator {

View File

@@ -96,7 +96,7 @@ func (this *HTTPCachePolicyDAO) FindAllEnabledCachePolicies(tx *dbs.Tx) (result
} }
// CreateCachePolicy 创建缓存策略 // CreateCachePolicy 创建缓存策略
func (this *HTTPCachePolicyDAO) CreateCachePolicy(tx *dbs.Tx, isOn bool, name string, description string, capacityJSON []byte, maxSizeJSON []byte, storageType string, storageOptionsJSON []byte, syncCompressionCache bool) (int64, error) { func (this *HTTPCachePolicyDAO) CreateCachePolicy(tx *dbs.Tx, isOn bool, name string, description string, capacityJSON []byte, maxSizeJSON []byte, storageType string, storageOptionsJSON []byte, syncCompressionCache bool, fetchTimeoutJSON []byte) (int64, error) {
var op = NewHTTPCachePolicyOperator() var op = NewHTTPCachePolicyOperator()
op.State = HTTPCachePolicyStateEnabled op.State = HTTPCachePolicyStateEnabled
op.IsOn = isOn op.IsOn = isOn
@@ -114,6 +114,10 @@ func (this *HTTPCachePolicyDAO) CreateCachePolicy(tx *dbs.Tx, isOn bool, name st
} }
op.SyncCompressionCache = syncCompressionCache op.SyncCompressionCache = syncCompressionCache
if len(fetchTimeoutJSON) > 0 {
op.FetchTimeout = fetchTimeoutJSON
}
// 默认的缓存条件 // 默认的缓存条件
cacheRef := &serverconfigs.HTTPCacheRef{ cacheRef := &serverconfigs.HTTPCacheRef{
IsOn: true, IsOn: true,
@@ -170,7 +174,8 @@ func (this *HTTPCachePolicyDAO) CreateDefaultCachePolicy(tx *dbs.Tx, name string
} }
var storageOptions = &serverconfigs.HTTPFileCacheStorage{ var storageOptions = &serverconfigs.HTTPFileCacheStorage{
Dir: "/opt/cache", Dir: "/opt/cache",
EnableMMAP: true,
MemoryPolicy: &serverconfigs.HTTPCachePolicy{ MemoryPolicy: &serverconfigs.HTTPCachePolicy{
Capacity: &shared.SizeCapacity{ Capacity: &shared.SizeCapacity{
Count: 1, Count: 1,
@@ -183,7 +188,7 @@ func (this *HTTPCachePolicyDAO) CreateDefaultCachePolicy(tx *dbs.Tx, name string
return 0, err return 0, err
} }
policyId, err := this.CreateCachePolicy(tx, true, "\""+name+"\"缓存策略", "默认创建的缓存策略", capacityJSON, maxSizeJSON, serverconfigs.CachePolicyStorageFile, storageOptionsJSON, false) policyId, err := this.CreateCachePolicy(tx, true, "\""+name+"\"缓存策略", "默认创建的缓存策略", capacityJSON, maxSizeJSON, serverconfigs.CachePolicyStorageFile, storageOptionsJSON, false, nil)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -191,7 +196,7 @@ func (this *HTTPCachePolicyDAO) CreateDefaultCachePolicy(tx *dbs.Tx, name string
} }
// UpdateCachePolicy 修改缓存策略 // UpdateCachePolicy 修改缓存策略
func (this *HTTPCachePolicyDAO) UpdateCachePolicy(tx *dbs.Tx, policyId int64, isOn bool, name string, description string, capacityJSON []byte, maxSizeJSON []byte, storageType string, storageOptionsJSON []byte, syncCompressionCache bool) error { func (this *HTTPCachePolicyDAO) UpdateCachePolicy(tx *dbs.Tx, policyId int64, isOn bool, name string, description string, capacityJSON []byte, maxSizeJSON []byte, storageType string, storageOptionsJSON []byte, syncCompressionCache bool, fetchTimeoutJSON []byte) error {
if policyId <= 0 { if policyId <= 0 {
return errors.New("invalid policyId") return errors.New("invalid policyId")
} }
@@ -212,6 +217,9 @@ func (this *HTTPCachePolicyDAO) UpdateCachePolicy(tx *dbs.Tx, policyId int64, is
op.Options = storageOptionsJSON op.Options = storageOptionsJSON
} }
op.SyncCompressionCache = syncCompressionCache op.SyncCompressionCache = syncCompressionCache
if len(fetchTimeoutJSON) > 0 {
op.FetchTimeout = fetchTimeoutJSON
}
err := this.Save(tx, op) err := this.Save(tx, op)
if err != nil { if err != nil {
return err return err
@@ -237,7 +245,7 @@ func (this *HTTPCachePolicyDAO) ComposeCachePolicy(tx *dbs.Tx, policyId int64, c
if policy == nil { if policy == nil {
return nil, nil return nil, nil
} }
config := &serverconfigs.HTTPCachePolicy{} var config = &serverconfigs.HTTPCachePolicy{}
config.Id = int64(policy.Id) config.Id = int64(policy.Id)
config.IsOn = policy.IsOn config.IsOn = policy.IsOn
config.Name = policy.Name config.Name = policy.Name
@@ -246,7 +254,7 @@ func (this *HTTPCachePolicyDAO) ComposeCachePolicy(tx *dbs.Tx, policyId int64, c
// capacity // capacity
if IsNotNull(policy.Capacity) { if IsNotNull(policy.Capacity) {
capacityConfig := &shared.SizeCapacity{} var capacityConfig = &shared.SizeCapacity{}
err = json.Unmarshal(policy.Capacity, capacityConfig) err = json.Unmarshal(policy.Capacity, capacityConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -256,7 +264,7 @@ func (this *HTTPCachePolicyDAO) ComposeCachePolicy(tx *dbs.Tx, policyId int64, c
// max size // max size
if IsNotNull(policy.MaxSize) { if IsNotNull(policy.MaxSize) {
maxSizeConfig := &shared.SizeCapacity{} var maxSizeConfig = &shared.SizeCapacity{}
err = json.Unmarshal(policy.MaxSize, maxSizeConfig) err = json.Unmarshal(policy.MaxSize, maxSizeConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -268,7 +276,7 @@ func (this *HTTPCachePolicyDAO) ComposeCachePolicy(tx *dbs.Tx, policyId int64, c
// options // options
if IsNotNull(policy.Options) { if IsNotNull(policy.Options) {
m := map[string]interface{}{} var m = map[string]any{}
err = json.Unmarshal(policy.Options, &m) err = json.Unmarshal(policy.Options, &m)
if err != nil { if err != nil {
return nil, errors.Wrap(err) return nil, errors.Wrap(err)
@@ -278,7 +286,7 @@ func (this *HTTPCachePolicyDAO) ComposeCachePolicy(tx *dbs.Tx, policyId int64, c
// refs // refs
if IsNotNull(policy.Refs) { if IsNotNull(policy.Refs) {
refs := []*serverconfigs.HTTPCacheRef{} var refs = []*serverconfigs.HTTPCacheRef{}
err = json.Unmarshal(policy.Refs, &refs) err = json.Unmarshal(policy.Refs, &refs)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -286,6 +294,16 @@ func (this *HTTPCachePolicyDAO) ComposeCachePolicy(tx *dbs.Tx, policyId int64, c
config.CacheRefs = refs config.CacheRefs = refs
} }
// fetch timeout
if IsNotNull(policy.FetchTimeout) {
var timeoutDuration = &shared.TimeDuration{}
err = json.Unmarshal(policy.FetchTimeout, timeoutDuration)
if err != nil {
return nil, err
}
config.FetchTimeout = timeoutDuration
}
if cacheMap != nil { if cacheMap != nil {
cacheMap.Put(cacheKey, config) cacheMap.Put(cacheKey, config)
} }

View File

@@ -2,6 +2,26 @@ package models
import "github.com/iwind/TeaGo/dbs" import "github.com/iwind/TeaGo/dbs"
const (
HTTPCachePolicyField_Id dbs.FieldName = "id" // ID
HTTPCachePolicyField_AdminId dbs.FieldName = "adminId" // 管理员ID
HTTPCachePolicyField_UserId dbs.FieldName = "userId" // 用户ID
HTTPCachePolicyField_TemplateId dbs.FieldName = "templateId" // 模版ID
HTTPCachePolicyField_IsOn dbs.FieldName = "isOn" // 是否启用
HTTPCachePolicyField_Name dbs.FieldName = "name" // 名称
HTTPCachePolicyField_Capacity dbs.FieldName = "capacity" // 容量数据
HTTPCachePolicyField_MaxKeys dbs.FieldName = "maxKeys" // 最多Key值
HTTPCachePolicyField_MaxSize dbs.FieldName = "maxSize" // 最大缓存内容尺寸
HTTPCachePolicyField_Type dbs.FieldName = "type" // 存储类型
HTTPCachePolicyField_Options dbs.FieldName = "options" // 存储选项
HTTPCachePolicyField_CreatedAt dbs.FieldName = "createdAt" // 创建时间
HTTPCachePolicyField_State dbs.FieldName = "state" // 状态
HTTPCachePolicyField_Description dbs.FieldName = "description" // 描述
HTTPCachePolicyField_Refs dbs.FieldName = "refs" // 默认的缓存设置
HTTPCachePolicyField_SyncCompressionCache dbs.FieldName = "syncCompressionCache" // 是否同步写入压缩缓存
HTTPCachePolicyField_FetchTimeout dbs.FieldName = "fetchTimeout" // 预热超时时间
)
// HTTPCachePolicy HTTP缓存策略 // HTTPCachePolicy HTTP缓存策略
type HTTPCachePolicy struct { type HTTPCachePolicy struct {
Id uint32 `field:"id"` // ID Id uint32 `field:"id"` // ID
@@ -20,25 +40,27 @@ type HTTPCachePolicy struct {
Description string `field:"description"` // 描述 Description string `field:"description"` // 描述
Refs dbs.JSON `field:"refs"` // 默认的缓存设置 Refs dbs.JSON `field:"refs"` // 默认的缓存设置
SyncCompressionCache uint8 `field:"syncCompressionCache"` // 是否同步写入压缩缓存 SyncCompressionCache uint8 `field:"syncCompressionCache"` // 是否同步写入压缩缓存
FetchTimeout dbs.JSON `field:"fetchTimeout"` // 预热超时时间
} }
type HTTPCachePolicyOperator struct { type HTTPCachePolicyOperator struct {
Id interface{} // ID Id any // ID
AdminId interface{} // 管理员ID AdminId any // 管理员ID
UserId interface{} // 用户ID UserId any // 用户ID
TemplateId interface{} // 模版ID TemplateId any // 模版ID
IsOn interface{} // 是否启用 IsOn any // 是否启用
Name interface{} // 名称 Name any // 名称
Capacity interface{} // 容量数据 Capacity any // 容量数据
MaxKeys interface{} // 最多Key值 MaxKeys any // 最多Key值
MaxSize interface{} // 最大缓存内容尺寸 MaxSize any // 最大缓存内容尺寸
Type interface{} // 存储类型 Type any // 存储类型
Options interface{} // 存储选项 Options any // 存储选项
CreatedAt interface{} // 创建时间 CreatedAt any // 创建时间
State interface{} // 状态 State any // 状态
Description interface{} // 描述 Description any // 描述
Refs interface{} // 默认的缓存设置 Refs any // 默认的缓存设置
SyncCompressionCache interface{} // 是否同步写入压缩缓存 SyncCompressionCache any // 是否同步写入压缩缓存
FetchTimeout any // 预热超时时间
} }
func NewHTTPCachePolicyOperator() *HTTPCachePolicyOperator { func NewHTTPCachePolicyOperator() *HTTPCachePolicyOperator {

View File

@@ -134,7 +134,7 @@ func (this *HTTPFirewallPolicyDAO) CreateFirewallPolicy(tx *dbs.Tx, userId int64
if userId <= 0 && serverGroupId <= 0 && serverId <= 0 { if userId <= 0 && serverGroupId <= 0 && serverId <= 0 {
// synFlood // synFlood
var synFloodConfig = firewallconfigs.DefaultSYNFloodConfig() var synFloodConfig = firewallconfigs.NewSYNFloodConfig()
synFloodJSON, err := json.Marshal(synFloodConfig) synFloodJSON, err := json.Marshal(synFloodConfig)
if err != nil { if err != nil {
return 0, err return 0, err
@@ -142,20 +142,36 @@ func (this *HTTPFirewallPolicyDAO) CreateFirewallPolicy(tx *dbs.Tx, userId int64
op.SynFlood = synFloodJSON op.SynFlood = synFloodJSON
// block options // block options
var blockOptions = firewallconfigs.DefaultHTTPFirewallBlockAction() var blockOptions = firewallconfigs.NewHTTPFirewallBlockAction()
blockOptionsJSON, err := json.Marshal(blockOptions) blockOptionsJSON, err := json.Marshal(blockOptions)
if err != nil { if err != nil {
return 0, err return 0, err
} }
op.BlockOptions = blockOptionsJSON op.BlockOptions = blockOptionsJSON
// page options
var pageOptions = firewallconfigs.NewHTTPFirewallPageAction()
pageOptionsJSON, err := json.Marshal(pageOptions)
if err != nil {
return 0, err
}
op.PageOptions = pageOptionsJSON
// captcha options // captcha options
var captchaOptions = firewallconfigs.DefaultHTTPFirewallCaptchaAction() var captchaOptions = firewallconfigs.NewHTTPFirewallCaptchaAction()
captchaOptionsJSON, err := json.Marshal(captchaOptions) captchaOptionsJSON, err := json.Marshal(captchaOptions)
if err != nil { if err != nil {
return 0, err return 0, err
} }
op.CaptchaOptions = captchaOptionsJSON op.CaptchaOptions = captchaOptionsJSON
// jscookie options
var jsCookieOptions = firewallconfigs.NewHTTPFirewallJavascriptCookieAction()
jsCookieOptionsJSON, err := json.Marshal(jsCookieOptions)
if err != nil {
return 0, err
}
op.JsCookieOptions = jsCookieOptionsJSON
} }
err := this.Save(tx, op) err := this.Save(tx, op)
@@ -172,16 +188,18 @@ func (this *HTTPFirewallPolicyDAO) CreateDefaultFirewallPolicy(tx *dbs.Tx, name
// 初始化 // 初始化
var groupCodes = []string{} var groupCodes = []string{}
templatePolicy := firewallconfigs.HTTPFirewallTemplate() var templatePolicy = firewallconfigs.HTTPFirewallTemplate()
for _, group := range templatePolicy.AllRuleGroups() { for _, group := range templatePolicy.AllRuleGroups() {
groupCodes = append(groupCodes, group.Code) if group.IsOn {
groupCodes = append(groupCodes, group.Code)
}
} }
var inboundConfig = &firewallconfigs.HTTPFirewallInboundConfig{IsOn: true} var inboundConfig = &firewallconfigs.HTTPFirewallInboundConfig{IsOn: true}
var outboundConfig = &firewallconfigs.HTTPFirewallOutboundConfig{IsOn: true} var outboundConfig = &firewallconfigs.HTTPFirewallOutboundConfig{IsOn: true}
if templatePolicy.Inbound != nil { if templatePolicy.Inbound != nil {
for _, group := range templatePolicy.Inbound.Groups { for _, group := range templatePolicy.Inbound.Groups {
isOn := lists.ContainsString(groupCodes, group.Code) var isOn = lists.ContainsString(groupCodes, group.Code)
group.IsOn = isOn group.IsOn = isOn
groupId, err := SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, group) groupId, err := SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, group)
@@ -196,7 +214,7 @@ func (this *HTTPFirewallPolicyDAO) CreateDefaultFirewallPolicy(tx *dbs.Tx, name
} }
if templatePolicy.Outbound != nil { if templatePolicy.Outbound != nil {
for _, group := range templatePolicy.Outbound.Groups { for _, group := range templatePolicy.Outbound.Groups {
isOn := lists.ContainsString(groupCodes, group.Code) var isOn = lists.ContainsString(groupCodes, group.Code)
group.IsOn = isOn group.IsOn = isOn
groupId, err := SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, group) groupId, err := SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, group)
@@ -277,6 +295,31 @@ func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicyInbound(tx *dbs.Tx, polic
return this.NotifyUpdate(tx, policyId) return this.NotifyUpdate(tx, policyId)
} }
// UpdateFirewallPolicyInboundRegion 修改入站封禁区域设置
func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicyInboundRegion(tx *dbs.Tx, policyId int64, regionConfig *firewallconfigs.HTTPFirewallRegionConfig) error {
var inboundConfig = &firewallconfigs.HTTPFirewallInboundConfig{IsOn: true}
inboundJSON, err := this.Query(tx).
Pk(policyId).
Result("inbound").
FindJSONCol()
if err != nil {
return err
}
if IsNotNull(inboundJSON) {
err = json.Unmarshal(inboundJSON, inboundConfig)
if err != nil {
return err
}
}
inboundConfig.Region = regionConfig
newInboundJSON, err := json.Marshal(inboundConfig)
if err != nil {
return err
}
return this.UpdateFirewallPolicyInbound(tx, policyId, newInboundJSON)
}
// UpdateFirewallPolicy 修改策略 // UpdateFirewallPolicy 修改策略
func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicy(tx *dbs.Tx, func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicy(tx *dbs.Tx,
policyId int64, policyId int64,
@@ -286,11 +329,16 @@ func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicy(tx *dbs.Tx,
inboundJSON []byte, inboundJSON []byte,
outboundJSON []byte, outboundJSON []byte,
blockOptionsJSON []byte, blockOptionsJSON []byte,
pageOptionsJSON []byte,
captchaOptionsJSON []byte, captchaOptionsJSON []byte,
jsCookieOptionsJSON []byte,
mode firewallconfigs.FirewallMode, mode firewallconfigs.FirewallMode,
useLocalFirewall bool, useLocalFirewall bool,
synFloodConfig *firewallconfigs.SYNFloodConfig, synFloodConfig *firewallconfigs.SYNFloodConfig,
logConfig *firewallconfigs.HTTPFirewallPolicyLogConfig) error { logConfig *firewallconfigs.HTTPFirewallPolicyLogConfig,
maxRequestBodySize int64,
denyCountryHTML string,
denyProvinceHTML string) error {
if policyId <= 0 { if policyId <= 0 {
return errors.New("invalid policyId") return errors.New("invalid policyId")
} }
@@ -313,9 +361,15 @@ func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicy(tx *dbs.Tx,
if IsNotNull(blockOptionsJSON) { if IsNotNull(blockOptionsJSON) {
op.BlockOptions = blockOptionsJSON op.BlockOptions = blockOptionsJSON
} }
if IsNotNull(pageOptionsJSON) {
op.PageOptions = pageOptionsJSON
}
if IsNotNull(captchaOptionsJSON) { if IsNotNull(captchaOptionsJSON) {
op.CaptchaOptions = captchaOptionsJSON op.CaptchaOptions = captchaOptionsJSON
} }
if IsNotNull(jsCookieOptionsJSON) {
op.JsCookieOptions = jsCookieOptionsJSON
}
if synFloodConfig != nil { if synFloodConfig != nil {
synFloodConfigJSON, err := json.Marshal(synFloodConfig) synFloodConfigJSON, err := json.Marshal(synFloodConfig)
@@ -338,6 +392,10 @@ func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicy(tx *dbs.Tx,
} }
op.UseLocalFirewall = useLocalFirewall op.UseLocalFirewall = useLocalFirewall
op.MaxRequestBodySize = maxRequestBodySize
op.DenyCountryHTML = denyCountryHTML
op.DenyProvinceHTML = denyProvinceHTML
err := this.Save(tx, op) err := this.Save(tx, op)
if err != nil { if err != nil {
return err return err
@@ -390,7 +448,7 @@ func (this *HTTPFirewallPolicyDAO) ListEnabledFirewallPolicies(tx *dbs.Tx, clust
} }
// ComposeFirewallPolicy 组合策略配置 // ComposeFirewallPolicy 组合策略配置
func (this *HTTPFirewallPolicyDAO) ComposeFirewallPolicy(tx *dbs.Tx, policyId int64, cacheMap *utils.CacheMap) (*firewallconfigs.HTTPFirewallPolicy, error) { func (this *HTTPFirewallPolicyDAO) ComposeFirewallPolicy(tx *dbs.Tx, policyId int64, forNode bool, cacheMap *utils.CacheMap) (*firewallconfigs.HTTPFirewallPolicy, error) {
if cacheMap == nil { if cacheMap == nil {
cacheMap = utils.NewCacheMap() cacheMap = utils.NewCacheMap()
} }
@@ -410,10 +468,14 @@ func (this *HTTPFirewallPolicyDAO) ComposeFirewallPolicy(tx *dbs.Tx, policyId in
var config = &firewallconfigs.HTTPFirewallPolicy{} var config = &firewallconfigs.HTTPFirewallPolicy{}
config.Id = int64(policy.Id) config.Id = int64(policy.Id)
config.ServerId = int64(policy.ServerId)
config.IsOn = policy.IsOn config.IsOn = policy.IsOn
config.Name = policy.Name config.Name = policy.Name
config.Description = policy.Description config.Description = policy.Description
config.UseLocalFirewall = policy.UseLocalFirewall == 1 config.UseLocalFirewall = policy.UseLocalFirewall == 1
config.MaxRequestBodySize = int64(policy.MaxRequestBodySize)
config.DenyCountryHTML = policy.DenyCountryHTML
config.DenyProvinceHTML = policy.DenyProvinceHTML
if len(policy.Mode) == 0 { if len(policy.Mode) == 0 {
policy.Mode = firewallconfigs.FirewallModeDefend policy.Mode = firewallconfigs.FirewallModeDefend
@@ -421,18 +483,18 @@ func (this *HTTPFirewallPolicyDAO) ComposeFirewallPolicy(tx *dbs.Tx, policyId in
config.Mode = policy.Mode config.Mode = policy.Mode
// Inbound // Inbound
inbound := &firewallconfigs.HTTPFirewallInboundConfig{} var inbound = &firewallconfigs.HTTPFirewallInboundConfig{}
if IsNotNull(policy.Inbound) { if IsNotNull(policy.Inbound) {
err = json.Unmarshal(policy.Inbound, inbound) err = json.Unmarshal(policy.Inbound, inbound)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(inbound.GroupRefs) > 0 { if len(inbound.GroupRefs) > 0 {
resultGroupRefs := []*firewallconfigs.HTTPFirewallRuleGroupRef{} var resultGroupRefs = []*firewallconfigs.HTTPFirewallRuleGroupRef{}
resultGroups := []*firewallconfigs.HTTPFirewallRuleGroup{} var resultGroups = []*firewallconfigs.HTTPFirewallRuleGroup{}
for _, groupRef := range inbound.GroupRefs { for _, groupRef := range inbound.GroupRefs {
groupConfig, err := SharedHTTPFirewallRuleGroupDAO.ComposeFirewallRuleGroup(tx, groupRef.GroupId) groupConfig, err := SharedHTTPFirewallRuleGroupDAO.ComposeFirewallRuleGroup(tx, groupRef.GroupId, forNode)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -449,18 +511,18 @@ func (this *HTTPFirewallPolicyDAO) ComposeFirewallPolicy(tx *dbs.Tx, policyId in
config.Inbound = inbound config.Inbound = inbound
// Outbound // Outbound
outbound := &firewallconfigs.HTTPFirewallOutboundConfig{} var outbound = &firewallconfigs.HTTPFirewallOutboundConfig{}
if IsNotNull(policy.Outbound) { if IsNotNull(policy.Outbound) {
err = json.Unmarshal(policy.Outbound, outbound) err = json.Unmarshal(policy.Outbound, outbound)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(outbound.GroupRefs) > 0 { if len(outbound.GroupRefs) > 0 {
resultGroupRefs := []*firewallconfigs.HTTPFirewallRuleGroupRef{} var resultGroupRefs = []*firewallconfigs.HTTPFirewallRuleGroupRef{}
resultGroups := []*firewallconfigs.HTTPFirewallRuleGroup{} var resultGroups = []*firewallconfigs.HTTPFirewallRuleGroup{}
for _, groupRef := range outbound.GroupRefs { for _, groupRef := range outbound.GroupRefs {
groupConfig, err := SharedHTTPFirewallRuleGroupDAO.ComposeFirewallRuleGroup(tx, groupRef.GroupId) groupConfig, err := SharedHTTPFirewallRuleGroupDAO.ComposeFirewallRuleGroup(tx, groupRef.GroupId, forNode)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -478,7 +540,7 @@ func (this *HTTPFirewallPolicyDAO) ComposeFirewallPolicy(tx *dbs.Tx, policyId in
// Block动作配置 // Block动作配置
if IsNotNull(policy.BlockOptions) { if IsNotNull(policy.BlockOptions) {
var blockAction = &firewallconfigs.HTTPFirewallBlockAction{} var blockAction = firewallconfigs.NewHTTPFirewallBlockAction()
err = json.Unmarshal(policy.BlockOptions, blockAction) err = json.Unmarshal(policy.BlockOptions, blockAction)
if err != nil { if err != nil {
return config, err return config, err
@@ -486,9 +548,19 @@ func (this *HTTPFirewallPolicyDAO) ComposeFirewallPolicy(tx *dbs.Tx, policyId in
config.BlockOptions = blockAction config.BlockOptions = blockAction
} }
// Page动作配置
if IsNotNull(policy.PageOptions) {
var pageAction = firewallconfigs.NewHTTPFirewallPageAction()
err = json.Unmarshal(policy.PageOptions, pageAction)
if err != nil {
return config, err
}
config.PageOptions = pageAction
}
// Captcha动作配置 // Captcha动作配置
if IsNotNull(policy.CaptchaOptions) { if IsNotNull(policy.CaptchaOptions) {
var captchaAction = &firewallconfigs.HTTPFirewallCaptchaAction{} var captchaAction = firewallconfigs.NewHTTPFirewallCaptchaAction()
err = json.Unmarshal(policy.CaptchaOptions, captchaAction) err = json.Unmarshal(policy.CaptchaOptions, captchaAction)
if err != nil { if err != nil {
return config, err return config, err
@@ -496,6 +568,16 @@ func (this *HTTPFirewallPolicyDAO) ComposeFirewallPolicy(tx *dbs.Tx, policyId in
config.CaptchaOptions = captchaAction config.CaptchaOptions = captchaAction
} }
// JSCookie动作配置
if IsNotNull(policy.JsCookieOptions) {
var jsCookieAction = firewallconfigs.NewHTTPFirewallJavascriptCookieAction()
err = json.Unmarshal(policy.JsCookieOptions, jsCookieAction)
if err != nil {
return config, err
}
config.JSCookieOptions = jsCookieAction
}
// syn flood // syn flood
if IsNotNull(policy.SynFlood) { if IsNotNull(policy.SynFlood) {
var synFloodConfig = &firewallconfigs.SYNFloodConfig{} var synFloodConfig = &firewallconfigs.SYNFloodConfig{}
@@ -630,6 +712,19 @@ func (this *HTTPFirewallPolicyDAO) FindFirewallPolicyIdsWithServerId(tx *dbs.Tx,
return result, nil return result, nil
} }
// FindServerIdWithFirewallPolicyId 根据策略查找网站ID
func (this *HTTPFirewallPolicyDAO) FindServerIdWithFirewallPolicyId(tx *dbs.Tx, policyId int64) (serverId int64, err error) {
if policyId <= 0 {
return
}
serverId, err = this.Query(tx).
Pk(policyId).
Result("serverId").
FindInt64Col(0)
return
}
// NotifyUpdate 通知更新 // NotifyUpdate 通知更新
func (this *HTTPFirewallPolicyDAO) NotifyUpdate(tx *dbs.Tx, policyId int64) error { func (this *HTTPFirewallPolicyDAO) NotifyUpdate(tx *dbs.Tx, policyId int64) error {
webIds, err := SharedHTTPWebDAO.FindAllWebIdsWithHTTPFirewallPolicyId(tx, policyId) webIds, err := SharedHTTPWebDAO.FindAllWebIdsWithHTTPFirewallPolicyId(tx, policyId)

View File

@@ -2,49 +2,86 @@ package models
import "github.com/iwind/TeaGo/dbs" import "github.com/iwind/TeaGo/dbs"
const (
HTTPFirewallPolicyField_Id dbs.FieldName = "id" // ID
HTTPFirewallPolicyField_TemplateId dbs.FieldName = "templateId" // 模版ID
HTTPFirewallPolicyField_AdminId dbs.FieldName = "adminId" // 管理员ID
HTTPFirewallPolicyField_UserId dbs.FieldName = "userId" // 用户ID
HTTPFirewallPolicyField_ServerId dbs.FieldName = "serverId" // 服务ID
HTTPFirewallPolicyField_GroupId dbs.FieldName = "groupId" // 服务分组ID
HTTPFirewallPolicyField_State dbs.FieldName = "state" // 状态
HTTPFirewallPolicyField_CreatedAt dbs.FieldName = "createdAt" // 创建时间
HTTPFirewallPolicyField_IsOn dbs.FieldName = "isOn" // 是否启用
HTTPFirewallPolicyField_Name dbs.FieldName = "name" // 名称
HTTPFirewallPolicyField_Description dbs.FieldName = "description" // 描述
HTTPFirewallPolicyField_Inbound dbs.FieldName = "inbound" // 入站规则
HTTPFirewallPolicyField_Outbound dbs.FieldName = "outbound" // 出站规则
HTTPFirewallPolicyField_BlockOptions dbs.FieldName = "blockOptions" // BLOCK动作选项
HTTPFirewallPolicyField_PageOptions dbs.FieldName = "pageOptions" // PAGE动作选项
HTTPFirewallPolicyField_CaptchaOptions dbs.FieldName = "captchaOptions" // 验证码动作选项
HTTPFirewallPolicyField_JsCookieOptions dbs.FieldName = "jsCookieOptions" // JSCookie动作选项
HTTPFirewallPolicyField_Mode dbs.FieldName = "mode" // 模式
HTTPFirewallPolicyField_UseLocalFirewall dbs.FieldName = "useLocalFirewall" // 是否自动使用本地防火墙
HTTPFirewallPolicyField_SynFlood dbs.FieldName = "synFlood" // SynFlood防御设置
HTTPFirewallPolicyField_Log dbs.FieldName = "log" // 日志配置
HTTPFirewallPolicyField_MaxRequestBodySize dbs.FieldName = "maxRequestBodySize" // 可以检查的最大请求内容尺寸
HTTPFirewallPolicyField_DenyCountryHTML dbs.FieldName = "denyCountryHTML" // 区域封禁提示
HTTPFirewallPolicyField_DenyProvinceHTML dbs.FieldName = "denyProvinceHTML" // 省份封禁提示
)
// HTTPFirewallPolicy HTTP防火墙 // HTTPFirewallPolicy HTTP防火墙
type HTTPFirewallPolicy struct { type HTTPFirewallPolicy struct {
Id uint32 `field:"id"` // ID Id uint32 `field:"id"` // ID
TemplateId uint32 `field:"templateId"` // 模版ID TemplateId uint32 `field:"templateId"` // 模版ID
AdminId uint32 `field:"adminId"` // 管理员ID AdminId uint32 `field:"adminId"` // 管理员ID
UserId uint32 `field:"userId"` // 用户ID UserId uint32 `field:"userId"` // 用户ID
ServerId uint32 `field:"serverId"` // 服务ID ServerId uint32 `field:"serverId"` // 服务ID
GroupId uint32 `field:"groupId"` // 服务分组ID GroupId uint32 `field:"groupId"` // 服务分组ID
State uint8 `field:"state"` // 状态 State uint8 `field:"state"` // 状态
CreatedAt uint64 `field:"createdAt"` // 创建时间 CreatedAt uint64 `field:"createdAt"` // 创建时间
IsOn bool `field:"isOn"` // 是否启用 IsOn bool `field:"isOn"` // 是否启用
Name string `field:"name"` // 名称 Name string `field:"name"` // 名称
Description string `field:"description"` // 描述 Description string `field:"description"` // 描述
Inbound dbs.JSON `field:"inbound"` // 入站规则 Inbound dbs.JSON `field:"inbound"` // 入站规则
Outbound dbs.JSON `field:"outbound"` // 出站规则 Outbound dbs.JSON `field:"outbound"` // 出站规则
BlockOptions dbs.JSON `field:"blockOptions"` // BLOCK选项 BlockOptions dbs.JSON `field:"blockOptions"` // BLOCK动作选项
CaptchaOptions dbs.JSON `field:"captchaOptions"` // 验证码选项 PageOptions dbs.JSON `field:"pageOptions"` // PAGE动作选项
Mode string `field:"mode"` // 模式 CaptchaOptions dbs.JSON `field:"captchaOptions"` // 验证码动作选项
UseLocalFirewall uint8 `field:"useLocalFirewall"` // 是否自动使用本地防火墙 JsCookieOptions dbs.JSON `field:"jsCookieOptions"` // JSCookie动作选项
SynFlood dbs.JSON `field:"synFlood"` // SynFlood防御设置 Mode string `field:"mode"` // 模式
Log dbs.JSON `field:"log"` // 日志配置 UseLocalFirewall uint8 `field:"useLocalFirewall"` // 是否自动使用本地防火墙
SynFlood dbs.JSON `field:"synFlood"` // SynFlood防御设置
Log dbs.JSON `field:"log"` // 日志配置
MaxRequestBodySize uint32 `field:"maxRequestBodySize"` // 可以检查的最大请求内容尺寸
DenyCountryHTML string `field:"denyCountryHTML"` // 区域封禁提示
DenyProvinceHTML string `field:"denyProvinceHTML"` // 省份封禁提示
} }
type HTTPFirewallPolicyOperator struct { type HTTPFirewallPolicyOperator struct {
Id interface{} // ID Id any // ID
TemplateId interface{} // 模版ID TemplateId any // 模版ID
AdminId interface{} // 管理员ID AdminId any // 管理员ID
UserId interface{} // 用户ID UserId any // 用户ID
ServerId interface{} // 服务ID ServerId any // 服务ID
GroupId interface{} // 服务分组ID GroupId any // 服务分组ID
State interface{} // 状态 State any // 状态
CreatedAt interface{} // 创建时间 CreatedAt any // 创建时间
IsOn interface{} // 是否启用 IsOn any // 是否启用
Name interface{} // 名称 Name any // 名称
Description interface{} // 描述 Description any // 描述
Inbound interface{} // 入站规则 Inbound any // 入站规则
Outbound interface{} // 出站规则 Outbound any // 出站规则
BlockOptions interface{} // BLOCK选项 BlockOptions any // BLOCK动作选项
CaptchaOptions interface{} // 验证码选项 PageOptions any // PAGE动作选项
Mode interface{} // 模式 CaptchaOptions any // 验证码动作选项
UseLocalFirewall interface{} // 是否自动使用本地防火墙 JsCookieOptions any // JSCookie动作选项
SynFlood interface{} // SynFlood防御设置 Mode any // 模式
Log interface{} // 日志配置 UseLocalFirewall any // 是否自动使用本地防火墙
SynFlood any // SynFlood防御设置
Log any // 日志配置
MaxRequestBodySize any // 可以检查的最大请求内容尺寸
DenyCountryHTML any // 区域封禁提示
DenyProvinceHTML any // 省份封禁提示
} }
func NewHTTPFirewallPolicyOperator() *HTTPFirewallPolicyOperator { func NewHTTPFirewallPolicyOperator() *HTTPFirewallPolicyOperator {

View File

@@ -81,7 +81,7 @@ func (this *HTTPFirewallRuleGroupDAO) FindHTTPFirewallRuleGroupName(tx *dbs.Tx,
} }
// ComposeFirewallRuleGroup 组合配置 // ComposeFirewallRuleGroup 组合配置
func (this *HTTPFirewallRuleGroupDAO) ComposeFirewallRuleGroup(tx *dbs.Tx, groupId int64) (*firewallconfigs.HTTPFirewallRuleGroup, error) { func (this *HTTPFirewallRuleGroupDAO) ComposeFirewallRuleGroup(tx *dbs.Tx, groupId int64, forNode bool) (*firewallconfigs.HTTPFirewallRuleGroup, error) {
group, err := this.FindEnabledHTTPFirewallRuleGroup(tx, groupId) group, err := this.FindEnabledHTTPFirewallRuleGroup(tx, groupId)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -89,7 +89,7 @@ func (this *HTTPFirewallRuleGroupDAO) ComposeFirewallRuleGroup(tx *dbs.Tx, group
if group == nil { if group == nil {
return nil, nil return nil, nil
} }
config := &firewallconfigs.HTTPFirewallRuleGroup{} var config = &firewallconfigs.HTTPFirewallRuleGroup{}
config.Id = int64(group.Id) config.Id = int64(group.Id)
config.IsOn = group.IsOn config.IsOn = group.IsOn
config.Name = group.Name config.Name = group.Name
@@ -98,17 +98,17 @@ func (this *HTTPFirewallRuleGroupDAO) ComposeFirewallRuleGroup(tx *dbs.Tx, group
config.IsTemplate = group.IsTemplate config.IsTemplate = group.IsTemplate
if IsNotNull(group.Sets) { if IsNotNull(group.Sets) {
setRefs := []*firewallconfigs.HTTPFirewallRuleSetRef{} var setRefs = []*firewallconfigs.HTTPFirewallRuleSetRef{}
err = json.Unmarshal(group.Sets, &setRefs) err = json.Unmarshal(group.Sets, &setRefs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, setRef := range setRefs { for _, setRef := range setRefs {
setConfig, err := SharedHTTPFirewallRuleSetDAO.ComposeFirewallRuleSet(tx, setRef.SetId) setConfig, err := SharedHTTPFirewallRuleSetDAO.ComposeFirewallRuleSet(tx, setRef.SetId, forNode)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if setConfig != nil { if setConfig != nil && (!forNode || setConfig.IsOn) {
config.SetRefs = append(config.SetRefs, setRef) config.SetRefs = append(config.SetRefs, setRef)
config.Sets = append(config.Sets, setConfig) config.Sets = append(config.Sets, setConfig)
} }

View File

@@ -84,7 +84,7 @@ func (this *HTTPFirewallRuleSetDAO) FindHTTPFirewallRuleSetName(tx *dbs.Tx, id i
} }
// ComposeFirewallRuleSet 组合配置 // ComposeFirewallRuleSet 组合配置
func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int64) (*firewallconfigs.HTTPFirewallRuleSet, error) { func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int64, forNode bool) (*firewallconfigs.HTTPFirewallRuleSet, error) {
set, err := this.FindEnabledHTTPFirewallRuleSet(tx, setId) set, err := this.FindEnabledHTTPFirewallRuleSet(tx, setId)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -92,7 +92,7 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int
if set == nil { if set == nil {
return nil, nil return nil, nil
} }
config := &firewallconfigs.HTTPFirewallRuleSet{} var config = &firewallconfigs.HTTPFirewallRuleSet{}
config.Id = int64(set.Id) config.Id = int64(set.Id)
config.IsOn = set.IsOn config.IsOn = set.IsOn
config.Name = set.Name config.Name = set.Name
@@ -102,7 +102,7 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int
config.IgnoreLocal = set.IgnoreLocal == 1 config.IgnoreLocal = set.IgnoreLocal == 1
if IsNotNull(set.Rules) { if IsNotNull(set.Rules) {
ruleRefs := []*firewallconfigs.HTTPFirewallRuleRef{} var ruleRefs = []*firewallconfigs.HTTPFirewallRuleRef{}
err = json.Unmarshal(set.Rules, &ruleRefs) err = json.Unmarshal(set.Rules, &ruleRefs)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -128,6 +128,29 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int
config.Actions = actionConfigs config.Actions = actionConfigs
} }
// 检查各个选项
for _, actionConfig := range actionConfigs {
if actionConfig.Code == firewallconfigs.HTTPFirewallActionRecordIP { // 记录IP动作
if actionConfig.Options != nil {
var ipListId = actionConfig.Options.GetInt64("ipListId")
if ipListId <= 0 { // default list id
if forNode {
actionConfig.Options["ipListId"] = firewallconfigs.GlobalListId
}
actionConfig.Options["ipListIsDeleted"] = false
} else {
exists, err := SharedIPListDAO.ExistsEnabledIPList(tx, ipListId)
if err != nil {
return nil, err
}
if !exists {
actionConfig.Options["ipListIsDeleted"] = true
}
}
}
}
}
return config, nil return config, nil
} }
@@ -212,6 +235,28 @@ func (this *HTTPFirewallRuleSetDAO) FindEnabledRuleSetIdWithRuleId(tx *dbs.Tx, r
FindInt64Col(0) FindInt64Col(0)
} }
// FindAllEnabledRuleSetIdsWithIPListId 根据IP名单ID查找对应动作的WAF规则集
func (this *HTTPFirewallRuleSetDAO) FindAllEnabledRuleSetIdsWithIPListId(tx *dbs.Tx, ipListId int64) (setIds []int64, err error) {
ones, err := this.Query(tx).
State(HTTPFirewallRuleStateEnabled).
Where("JSON_CONTAINS(actions, :jsonQuery)").
Param("jsonQuery", maps.Map{
"code": firewallconfigs.HTTPFirewallActionRecordIP,
"options": maps.Map{
"ipListId": ipListId,
},
}.AsJSON()).
ResultPk().
FindAll()
if err != nil {
return nil, err
}
for _, one := range ones {
setIds = append(setIds, int64(one.(*HTTPFirewallRuleSet).Id))
}
return
}
// CheckUserRuleSet 检查用户 // CheckUserRuleSet 检查用户
func (this *HTTPFirewallRuleSetDAO) CheckUserRuleSet(tx *dbs.Tx, userId int64, setId int64) error { func (this *HTTPFirewallRuleSetDAO) CheckUserRuleSet(tx *dbs.Tx, userId int64, setId int64) error {
groupId, err := SharedHTTPFirewallRuleGroupDAO.FindRuleGroupIdWithRuleSetId(tx, setId) groupId, err := SharedHTTPFirewallRuleGroupDAO.FindRuleGroupIdWithRuleSetId(tx, setId)

View File

@@ -77,7 +77,7 @@ func (this *HTTPPageDAO) FindEnabledHTTPPage(tx *dbs.Tx, id int64) (*HTTPPage, e
} }
// CreatePage 创建Page // CreatePage 创建Page
func (this *HTTPPageDAO) CreatePage(tx *dbs.Tx, userId int64, statusList []string, bodyType shared.BodyType, url string, body string, newStatus int) (pageId int64, err error) { func (this *HTTPPageDAO) CreatePage(tx *dbs.Tx, userId int64, statusList []string, bodyType serverconfigs.HTTPPageBodyType, url string, body string, newStatus int, exceptURLPatterns []*shared.URLPattern, onlyURLPatterns []*shared.URLPattern) (pageId int64, err error) {
var op = NewHTTPPageOperator() var op = NewHTTPPageOperator()
op.UserId = userId op.UserId = userId
op.IsOn = true op.IsOn = true
@@ -94,6 +94,29 @@ func (this *HTTPPageDAO) CreatePage(tx *dbs.Tx, userId int64, statusList []strin
op.Url = url op.Url = url
op.Body = body op.Body = body
op.NewStatus = newStatus op.NewStatus = newStatus
{
if exceptURLPatterns == nil {
exceptURLPatterns = []*shared.URLPattern{}
}
exceptURLPatternsJSON, err := json.Marshal(exceptURLPatterns)
if err != nil {
return 0, err
}
op.ExceptURLPatterns = exceptURLPatternsJSON
}
{
if onlyURLPatterns == nil {
onlyURLPatterns = []*shared.URLPattern{}
}
onlyURLPatternsJSON, err := json.Marshal(onlyURLPatterns)
if err != nil {
return 0, err
}
op.OnlyURLPatterns = onlyURLPatternsJSON
}
err = this.Save(tx, op) err = this.Save(tx, op)
if err != nil { if err != nil {
return 0, err return 0, err
@@ -103,7 +126,7 @@ func (this *HTTPPageDAO) CreatePage(tx *dbs.Tx, userId int64, statusList []strin
} }
// UpdatePage 修改Page // UpdatePage 修改Page
func (this *HTTPPageDAO) UpdatePage(tx *dbs.Tx, pageId int64, statusList []string, bodyType shared.BodyType, url string, body string, newStatus int) error { func (this *HTTPPageDAO) UpdatePage(tx *dbs.Tx, pageId int64, statusList []string, bodyType serverconfigs.HTTPPageBodyType, url string, body string, newStatus int, exceptURLPatterns []*shared.URLPattern, onlyURLPatterns []*shared.URLPattern) error {
if pageId <= 0 { if pageId <= 0 {
return errors.New("invalid pageId") return errors.New("invalid pageId")
} }
@@ -126,6 +149,29 @@ func (this *HTTPPageDAO) UpdatePage(tx *dbs.Tx, pageId int64, statusList []strin
op.Url = url op.Url = url
op.Body = body op.Body = body
op.NewStatus = newStatus op.NewStatus = newStatus
{
if exceptURLPatterns == nil {
exceptURLPatterns = []*shared.URLPattern{}
}
exceptURLPatternsJSON, err := json.Marshal(exceptURLPatterns)
if err != nil {
return err
}
op.ExceptURLPatterns = exceptURLPatternsJSON
}
{
if onlyURLPatterns == nil {
onlyURLPatterns = []*shared.URLPattern{}
}
onlyURLPatternsJSON, err := json.Marshal(onlyURLPatterns)
if err != nil {
return err
}
op.OnlyURLPatterns = onlyURLPatternsJSON
}
err = this.Save(tx, op) err = this.Save(tx, op)
if err != nil { if err != nil {
return err return err
@@ -156,6 +202,14 @@ func (this *HTTPPageDAO) ClonePage(tx *dbs.Tx, fromPageId int64) (newPageId int6
op.Body = page.Body op.Body = page.Body
op.BodyType = page.BodyType op.BodyType = page.BodyType
op.State = page.State op.State = page.State
if len(page.ExceptURLPatterns) > 0 {
op.ExceptURLPatterns = page.ExceptURLPatterns
}
if len(page.OnlyURLPatterns) > 0 {
op.OnlyURLPatterns = page.OnlyURLPatterns
}
return this.SaveInt64(tx, op) return this.SaveInt64(tx, op)
} }
@@ -179,7 +233,7 @@ func (this *HTTPPageDAO) ComposePageConfig(tx *dbs.Tx, pageId int64, cacheMap *u
return nil, nil return nil, nil
} }
config := &serverconfigs.HTTPPageConfig{} var config = &serverconfigs.HTTPPageConfig{}
config.Id = int64(page.Id) config.Id = int64(page.Id)
config.IsOn = page.IsOn config.IsOn = page.IsOn
config.NewStatus = int(page.NewStatus) config.NewStatus = int(page.NewStatus)
@@ -188,7 +242,7 @@ func (this *HTTPPageDAO) ComposePageConfig(tx *dbs.Tx, pageId int64, cacheMap *u
config.BodyType = page.BodyType config.BodyType = page.BodyType
if len(page.BodyType) == 0 { if len(page.BodyType) == 0 {
page.BodyType = shared.BodyTypeURL page.BodyType = serverconfigs.HTTPPageBodyTypeURL
} }
if len(page.StatusList) > 0 { if len(page.StatusList) > 0 {
@@ -202,6 +256,28 @@ func (this *HTTPPageDAO) ComposePageConfig(tx *dbs.Tx, pageId int64, cacheMap *u
} }
} }
if len(page.ExceptURLPatterns) > 0 {
var exceptURLPatterns = []*shared.URLPattern{}
err = json.Unmarshal(page.ExceptURLPatterns, &exceptURLPatterns)
if err != nil {
return nil, err
}
if len(exceptURLPatterns) > 0 {
config.ExceptURLPatterns = exceptURLPatterns
}
}
if len(page.OnlyURLPatterns) > 0 {
var onlyURLPatterns = []*shared.URLPattern{}
err = json.Unmarshal(page.OnlyURLPatterns, &onlyURLPatterns)
if err != nil {
return nil, err
}
if len(onlyURLPatterns) > 0 {
config.OnlyURLPatterns = onlyURLPatterns
}
}
if cacheMap != nil { if cacheMap != nil {
cacheMap.Put(cacheKey, config) cacheMap.Put(cacheKey, config)
} }

View File

@@ -2,33 +2,53 @@ package models
import "github.com/iwind/TeaGo/dbs" import "github.com/iwind/TeaGo/dbs"
const (
HTTPPageField_Id dbs.FieldName = "id" // ID
HTTPPageField_AdminId dbs.FieldName = "adminId" // 管理员ID
HTTPPageField_UserId dbs.FieldName = "userId" // 用户ID
HTTPPageField_IsOn dbs.FieldName = "isOn" // 是否启用
HTTPPageField_StatusList dbs.FieldName = "statusList" // 状态列表
HTTPPageField_Url dbs.FieldName = "url" // 页面URL
HTTPPageField_NewStatus dbs.FieldName = "newStatus" // 新状态码
HTTPPageField_State dbs.FieldName = "state" // 状态
HTTPPageField_CreatedAt dbs.FieldName = "createdAt" // 创建时间
HTTPPageField_Body dbs.FieldName = "body" // 页面内容
HTTPPageField_BodyType dbs.FieldName = "bodyType" // 内容类型
HTTPPageField_ExceptURLPatterns dbs.FieldName = "exceptURLPatterns" // 例外URL
HTTPPageField_OnlyURLPatterns dbs.FieldName = "onlyURLPatterns" // 限制URL
)
// HTTPPage 特殊页面 // HTTPPage 特殊页面
type HTTPPage struct { type HTTPPage struct {
Id uint32 `field:"id"` // ID Id uint32 `field:"id"` // ID
AdminId uint32 `field:"adminId"` // 管理员ID AdminId uint32 `field:"adminId"` // 管理员ID
UserId uint32 `field:"userId"` // 用户ID UserId uint32 `field:"userId"` // 用户ID
IsOn bool `field:"isOn"` // 是否启用 IsOn bool `field:"isOn"` // 是否启用
StatusList dbs.JSON `field:"statusList"` // 状态列表 StatusList dbs.JSON `field:"statusList"` // 状态列表
Url string `field:"url"` // 页面URL Url string `field:"url"` // 页面URL
NewStatus int32 `field:"newStatus"` // 新状态码 NewStatus int32 `field:"newStatus"` // 新状态码
State uint8 `field:"state"` // 状态 State uint8 `field:"state"` // 状态
CreatedAt uint64 `field:"createdAt"` // 创建时间 CreatedAt uint64 `field:"createdAt"` // 创建时间
Body string `field:"body"` // 页面内容 Body string `field:"body"` // 页面内容
BodyType string `field:"bodyType"` // 内容类型 BodyType string `field:"bodyType"` // 内容类型
ExceptURLPatterns dbs.JSON `field:"exceptURLPatterns"` // 例外URL
OnlyURLPatterns dbs.JSON `field:"onlyURLPatterns"` // 限制URL
} }
type HTTPPageOperator struct { type HTTPPageOperator struct {
Id interface{} // ID Id any // ID
AdminId interface{} // 管理员ID AdminId any // 管理员ID
UserId interface{} // 用户ID UserId any // 用户ID
IsOn interface{} // 是否启用 IsOn any // 是否启用
StatusList interface{} // 状态列表 StatusList any // 状态列表
Url interface{} // 页面URL Url any // 页面URL
NewStatus interface{} // 新状态码 NewStatus any // 新状态码
State interface{} // 状态 State any // 状态
CreatedAt interface{} // 创建时间 CreatedAt any // 创建时间
Body interface{} // 页面内容 Body any // 页面内容
BodyType interface{} // 内容类型 BodyType any // 内容类型
ExceptURLPatterns any // 例外URL
OnlyURLPatterns any // 限制URL
} }
func NewHTTPPageOperator() *HTTPPageOperator { func NewHTTPPageOperator() *HTTPPageOperator {

View File

@@ -101,7 +101,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64, isLocationOrGr
// root // root
if IsNotNull(web.Root) { if IsNotNull(web.Root) {
var rootConfig = &serverconfigs.HTTPRootConfig{} var rootConfig = serverconfigs.NewHTTPRootConfig()
err = json.Unmarshal(web.Root, rootConfig) err = json.Unmarshal(web.Root, rootConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -301,7 +301,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64, isLocationOrGr
// 自定义防火墙设置 // 自定义防火墙设置
if firewallRef.FirewallPolicyId > 0 { if firewallRef.FirewallPolicyId > 0 {
firewallPolicy, err := SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, firewallRef.FirewallPolicyId, cacheMap) firewallPolicy, err := SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, firewallRef.FirewallPolicyId, forNode, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -519,6 +519,14 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64, isLocationOrGr
} }
if this.shouldCompose(isLocationOrGroup, forNode, ccConfig.IsPrior, ccConfig.IsOn) { if this.shouldCompose(isLocationOrGroup, forNode, ccConfig.IsPrior, ccConfig.IsOn) {
config.CC = ccConfig config.CC = ccConfig
if forNode {
for index, threshold := range ccConfig.Thresholds {
if index < len(serverconfigs.DefaultHTTPCCThresholds) {
threshold.MergeIfEmpty(serverconfigs.DefaultHTTPCCThresholds[index])
}
}
}
} }
} }
@@ -546,6 +554,18 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64, isLocationOrGr
} }
} }
// hls
if IsNotNull(web.Hls) {
var hlsConfig = &serverconfigs.HLSConfig{}
err = json.Unmarshal(web.Hls, hlsConfig)
if err != nil {
return nil, err
}
if this.shouldCompose(isLocationOrGroup, forNode, hlsConfig.IsPrior, true) {
config.HLS = hlsConfig
}
}
if cacheMap != nil { if cacheMap != nil {
cacheMap.Put(cacheKey, config) cacheMap.Put(cacheKey, config)
} }
@@ -568,6 +588,7 @@ func (this *HTTPWebDAO) CreateWeb(tx *dbs.Tx, adminId int64, userId int64, rootJ
var remoteAddrConfig = &serverconfigs.HTTPRemoteAddrConfig{ var remoteAddrConfig = &serverconfigs.HTTPRemoteAddrConfig{
IsOn: true, IsOn: true,
Value: "${rawRemoteAddr}", Value: "${rawRemoteAddr}",
Type: serverconfigs.HTTPRemoteAddrTypeDefault,
} }
remoteAddrConfigJSON, err := json.Marshal(remoteAddrConfig) remoteAddrConfigJSON, err := json.Marshal(remoteAddrConfig)
if err != nil { if err != nil {
@@ -1290,6 +1311,61 @@ func (this *HTTPWebDAO) UpdateWebRequestScripts(tx *dbs.Tx, webId int64, config
return this.NotifyUpdate(tx, webId) return this.NotifyUpdate(tx, webId)
} }
// UpdateWebRequestScriptsAsPassed 设置请求脚本为审核通过
func (this *HTTPWebDAO) UpdateWebRequestScriptsAsPassed(tx *dbs.Tx, webId int64, codeMD5 string) error {
if webId <= 0 || len(codeMD5) == 0 {
return nil
}
configString, err := this.Query(tx).
Pk(webId).
Result("requestScripts").
FindStringCol("")
if err != nil {
return nil
}
var config = &serverconfigs.HTTPRequestScriptsConfig{}
if len(configString) == 0 {
return nil
}
err = json.Unmarshal([]byte(configString), config)
if err != nil {
return err
}
var found bool
for _, group := range config.AllGroups() {
for _, script := range group.Scripts {
if script.AuditingCodeMD5 == codeMD5 {
script.Code = script.AuditingCode
script.AuditingCode = ""
script.AuditingCodeMD5 = ""
found = true
}
}
}
if found {
configJSON, err := json.Marshal(config)
if err != nil {
return err
}
err = this.Query(tx).
Pk(webId).
Set("requestScripts", configJSON).
UpdateQuickly()
if err != nil {
return err
}
return this.NotifyUpdate(tx, webId)
}
return nil
}
// FindWebRequestScripts 查找服务的脚本设置 // FindWebRequestScripts 查找服务的脚本设置
func (this *HTTPWebDAO) FindWebRequestScripts(tx *dbs.Tx, webId int64) (*serverconfigs.HTTPRequestScriptsConfig, error) { func (this *HTTPWebDAO) FindWebRequestScripts(tx *dbs.Tx, webId int64) (*serverconfigs.HTTPRequestScriptsConfig, error) {
configString, err := this.Query(tx). configString, err := this.Query(tx).
@@ -1390,7 +1466,7 @@ func (this *HTTPWebDAO) UpdateWebReferers(tx *dbs.Tx, webId int64, referersConfi
return this.NotifyUpdate(tx, webId) return this.NotifyUpdate(tx, webId)
} }
// FindWebReferers 查找服务的防盗链配置 // FindWebReferers 查找网站的防盗链配置
func (this *HTTPWebDAO) FindWebReferers(tx *dbs.Tx, webId int64) ([]byte, error) { func (this *HTTPWebDAO) FindWebReferers(tx *dbs.Tx, webId int64) ([]byte, error) {
return this.Query(tx). return this.Query(tx).
Pk(webId). Pk(webId).
@@ -1400,6 +1476,10 @@ func (this *HTTPWebDAO) FindWebReferers(tx *dbs.Tx, webId int64) ([]byte, error)
// UpdateWebUserAgent 修改User-Agent设置 // UpdateWebUserAgent 修改User-Agent设置
func (this *HTTPWebDAO) UpdateWebUserAgent(tx *dbs.Tx, webId int64, userAgentConfig *serverconfigs.UserAgentConfig) error { func (this *HTTPWebDAO) UpdateWebUserAgent(tx *dbs.Tx, webId int64, userAgentConfig *serverconfigs.UserAgentConfig) error {
if webId <= 0 {
return errors.New("require 'webId'")
}
if userAgentConfig == nil { if userAgentConfig == nil {
return nil return nil
} }

View File

@@ -41,6 +41,7 @@ const (
HTTPWebField_Referers dbs.FieldName = "referers" // 防盗链设置 HTTPWebField_Referers dbs.FieldName = "referers" // 防盗链设置
HTTPWebField_UserAgent dbs.FieldName = "userAgent" // UserAgent设置 HTTPWebField_UserAgent dbs.FieldName = "userAgent" // UserAgent设置
HTTPWebField_Optimization dbs.FieldName = "optimization" // 页面优化配置 HTTPWebField_Optimization dbs.FieldName = "optimization" // 页面优化配置
HTTPWebField_Hls dbs.FieldName = "hls" // HLS设置
) )
// HTTPWeb HTTP Web // HTTPWeb HTTP Web
@@ -83,6 +84,7 @@ type HTTPWeb struct {
Referers dbs.JSON `field:"referers"` // 防盗链设置 Referers dbs.JSON `field:"referers"` // 防盗链设置
UserAgent dbs.JSON `field:"userAgent"` // UserAgent设置 UserAgent dbs.JSON `field:"userAgent"` // UserAgent设置
Optimization dbs.JSON `field:"optimization"` // 页面优化配置 Optimization dbs.JSON `field:"optimization"` // 页面优化配置
Hls dbs.JSON `field:"hls"` // HLS设置
} }
type HTTPWebOperator struct { type HTTPWebOperator struct {
@@ -124,6 +126,7 @@ type HTTPWebOperator struct {
Referers any // 防盗链设置 Referers any // 防盗链设置
UserAgent any // UserAgent设置 UserAgent any // UserAgent设置
Optimization any // 页面优化配置 Optimization any // 页面优化配置
Hls any // HLS设置
} }
func NewHTTPWebOperator() *HTTPWebOperator { func NewHTTPWebOperator() *HTTPWebOperator {

View File

@@ -5,7 +5,7 @@ import (
"github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/goman" "github.com/TeaOSLab/EdgeAPI/internal/goman"
"github.com/TeaOSLab/EdgeAPI/internal/remotelogs" "github.com/TeaOSLab/EdgeAPI/internal/remotelogs"
"github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeCommon/pkg/iputils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
@@ -13,7 +13,8 @@ import (
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
"math" "net"
"strings"
"time" "time"
) )
@@ -155,6 +156,59 @@ func (this *IPItemDAO) DisableIPItemsWithIP(tx *dbs.Tx, ipFrom string, ipTo stri
return nil return nil
} }
// DisableIPItemsWithIPValue 禁用某个IP相关条目
func (this *IPItemDAO) DisableIPItemsWithIPValue(tx *dbs.Tx, value string, sourceUserId int64, listId int64) error {
if len(value) == 0 {
return errors.New("invalid 'value'")
}
var query = this.Query(tx).
Result("id", "listId").
Attr("value", value).
State(IPItemStateEnabled)
if listId > 0 {
query.Attr("listId", listId)
}
if sourceUserId > 0 {
query.Attr("sourceUserId", sourceUserId)
}
ones, err := query.FindAll()
if err != nil {
return err
}
var itemIds = []int64{}
for _, one := range ones {
var item = one.(*IPItem)
var itemId = int64(item.Id)
itemIds = append(itemIds, itemId)
}
for _, itemId := range itemIds {
version, err := SharedIPListDAO.IncreaseVersion(tx)
if err != nil {
return err
}
_, err = this.Query(tx).
Pk(itemId).
Set("state", IPItemStateDisabled).
Set("version", version).
Update()
if err != nil {
return err
}
}
if len(itemIds) > 0 {
return this.NotifyUpdate(tx, itemIds[len(itemIds)-1])
}
return nil
}
// DisableIPItemsWithListId 禁用某个IP名单内的所有IP // DisableIPItemsWithListId 禁用某个IP名单内的所有IP
func (this *IPItemDAO) DisableIPItemsWithListId(tx *dbs.Tx, listId int64) error { func (this *IPItemDAO) DisableIPItemsWithListId(tx *dbs.Tx, listId int64) error {
for { for {
@@ -236,9 +290,46 @@ func (this *IPItemDAO) DeleteOldItem(tx *dbs.Tx, listId int64, ipFrom string, ip
return nil return nil
} }
// DeleteOldItemWithValue 根据IP删除以前的旧记录
func (this *IPItemDAO) DeleteOldItemWithValue(tx *dbs.Tx, listId int64, value string) error {
if len(value) == 0 {
return nil
}
ones, err := this.Query(tx).
ResultPk().
UseIndex("ipFrom").
Attr("listId", listId).
Attr("value", value).
Attr("state", IPItemStateEnabled).
FindAll()
if err != nil {
return err
}
for _, one := range ones {
var itemId = int64(one.(*IPItem).Id)
version, err := SharedIPListDAO.IncreaseVersion(tx)
if err != nil {
return err
}
err = this.Query(tx).
Pk(itemId).
Set("version", version).
Set("state", IPItemStateDisabled).
UpdateQuickly()
if err != nil {
return err
}
}
return nil
}
// CreateIPItem 创建IP // CreateIPItem 创建IP
func (this *IPItemDAO) CreateIPItem(tx *dbs.Tx, func (this *IPItemDAO) CreateIPItem(tx *dbs.Tx,
listId int64, listId int64,
value string,
ipFrom string, ipFrom string,
ipTo string, ipTo string,
expiredAt int64, expiredAt int64,
@@ -253,6 +344,15 @@ func (this *IPItemDAO) CreateIPItem(tx *dbs.Tx,
sourceHTTPFirewallRuleGroupId int64, sourceHTTPFirewallRuleGroupId int64,
sourceHTTPFirewallRuleSetId int64, sourceHTTPFirewallRuleSetId int64,
shouldNotify bool) (int64, error) { shouldNotify bool) (int64, error) {
// generate 'itemType'
if itemType != IPItemTypeAll && len(ipFrom) > 0 {
if iputils.IsIPv4(ipFrom) {
itemType = IPItemTypeIPv4
} else if iputils.IsIPv6(ipFrom) {
itemType = IPItemTypeIPv6
}
}
version, err := SharedIPListDAO.IncreaseVersion(tx) version, err := SharedIPListDAO.IncreaseVersion(tx)
if err != nil { if err != nil {
return 0, err return 0, err
@@ -260,10 +360,10 @@ func (this *IPItemDAO) CreateIPItem(tx *dbs.Tx,
var op = NewIPItemOperator() var op = NewIPItemOperator()
op.ListId = listId op.ListId = listId
op.Value = value
op.IpFrom = ipFrom op.IpFrom = ipFrom
op.IpTo = ipTo op.IpTo = ipTo
op.IpFromLong = utils.IP2Long(ipFrom)
op.IpToLong = utils.IP2Long(ipTo)
op.Reason = reason op.Reason = reason
op.Type = itemType op.Type = itemType
op.EventLevel = eventLevel op.EventLevel = eventLevel
@@ -319,11 +419,20 @@ func (this *IPItemDAO) CreateIPItem(tx *dbs.Tx,
} }
// UpdateIPItem 修改IP // UpdateIPItem 修改IP
func (this *IPItemDAO) UpdateIPItem(tx *dbs.Tx, itemId int64, ipFrom string, ipTo string, expiredAt int64, reason string, itemType IPItemType, eventLevel string) error { func (this *IPItemDAO) UpdateIPItem(tx *dbs.Tx, itemId int64, value string, ipFrom string, ipTo string, expiredAt int64, reason string, itemType IPItemType, eventLevel string) error {
if itemId <= 0 { if itemId <= 0 {
return errors.New("invalid itemId") return errors.New("invalid itemId")
} }
// generate 'itemType'
if itemType != IPItemTypeAll && len(ipFrom) > 0 {
if iputils.IsIPv4(ipFrom) {
itemType = IPItemTypeIPv4
} else if iputils.IsIPv6(ipFrom) {
itemType = IPItemTypeIPv6
}
}
listId, err := this.Query(tx). listId, err := this.Query(tx).
Pk(itemId). Pk(itemId).
Result("listId"). Result("listId").
@@ -342,10 +451,10 @@ func (this *IPItemDAO) UpdateIPItem(tx *dbs.Tx, itemId int64, ipFrom string, ipT
var op = NewIPItemOperator() var op = NewIPItemOperator()
op.Id = itemId op.Id = itemId
op.Value = value
op.IpFrom = ipFrom op.IpFrom = ipFrom
op.IpTo = ipTo op.IpTo = ipTo
op.IpFromLong = utils.IP2Long(ipFrom)
op.IpToLong = utils.IP2Long(ipTo)
op.Reason = reason op.Reason = reason
op.Type = itemType op.Type = itemType
op.EventLevel = eventLevel op.EventLevel = eventLevel
@@ -442,16 +551,21 @@ func (this *IPItemDAO) FindItemListId(tx *dbs.Tx, itemId int64) (int64, error) {
} }
// FindEnabledItemContainsIP 查找包含某个IP的Item // FindEnabledItemContainsIP 查找包含某个IP的Item
func (this *IPItemDAO) FindEnabledItemContainsIP(tx *dbs.Tx, listId int64, ip uint64) (*IPItem, error) { func (this *IPItemDAO) FindEnabledItemContainsIP(tx *dbs.Tx, listId int64, ip string) (*IPItem, error) {
query := this.Query(tx). var query = this.Query(tx).
Attr("listId", listId). Attr("listId", listId).
State(IPItemStateEnabled) State(IPItemStateEnabled)
if ip > math.MaxUint32 {
query.Where("(type='all' OR ipFromLong=:ip)") if iputils.IsIPv4(ip) {
} else { query.Where("(type='all' OR ipFrom =:ip OR INET_ATON(:ip) BETWEEN INET_ATON(ipFrom) AND INET_ATON(ipTo))").
query.Where("(type='all' OR ipFromLong=:ip OR (ipToLong>0 AND ipFromLong<=:ip AND ipToLong>=:ip))").
Param("ip", ip) Param("ip", ip)
} else if iputils.IsIPv6(ip) {
query.Where("(type='all' OR ipFrom =:ip OR HEX(INET6_ATON(:ip)) BETWEEN HEX(INET6_ATON(ipFrom)) AND HEX(INET6_ATON(ipTo)))").
Param("ip", ip)
} else {
return nil, nil
} }
one, err := query.Find() one, err := query.Find()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -498,7 +612,17 @@ func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, sourceUserId int64, ke
} }
} }
if len(keyword) > 0 { if len(keyword) > 0 {
query.Like("ipFrom", dbutils.QuoteLike(keyword)) if net.ParseIP(keyword) != nil { // 是一个IP地址
if iputils.IsIPv4(keyword) {
query.Where("(type='all' OR ipFrom =:ipKeyword OR INET_ATON(:ipKeyword) BETWEEN INET_ATON(ipFrom) AND INET_ATON(ipTo))").
Param("ipKeyword", keyword)
} else if iputils.IsIPv6(keyword) {
query.Where("(type='all' OR ipFrom =:ipKeyword OR HEX(INET6_ATON(:ipKeyword)) BETWEEN HEX(INET6_ATON(ipFrom)) AND HEX(INET6_ATON(ipTo)))").
Param("ipKeyword", keyword)
}
} else {
query.Like("ipFrom", dbutils.QuoteLike(keyword))
}
} }
if len(ip) > 0 { if len(ip) > 0 {
query.Attr("ipFrom", ip) query.Attr("ipFrom", ip)
@@ -540,7 +664,17 @@ func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, sourceUserId int64, key
} }
} }
if len(keyword) > 0 { if len(keyword) > 0 {
query.Like("ipFrom", dbutils.QuoteLike(keyword)) if net.ParseIP(keyword) != nil { // 是一个IP地址
if iputils.IsIPv4(keyword) {
query.Where("(type='all' OR ipFrom =:ipKeyword OR INET_ATON(:ipKeyword) BETWEEN INET_ATON(ipFrom) AND INET_ATON(ipTo))").
Param("ipKeyword", keyword)
} else if iputils.IsIPv6(keyword) {
query.Where("(type='all' OR ipFrom =:ipKeyword OR HEX(INET6_ATON(:ipKeyword)) BETWEEN HEX(INET6_ATON(ipFrom)) AND HEX(INET6_ATON(ipTo)))").
Param("ipKeyword", keyword)
}
} else {
query.Like("ipFrom", dbutils.QuoteLike(keyword))
}
} }
if len(ip) > 0 { if len(ip) > 0 {
query.Attr("ipFrom", ip) query.Attr("ipFrom", ip)
@@ -573,6 +707,62 @@ func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, sourceUserId int64, key
return return
} }
// ListAllIPItemIds 搜索所有IP Id列表
func (this *IPItemDAO) ListAllIPItemIds(tx *dbs.Tx, sourceUserId int64, keyword string, ip string, listId int64, unread bool, eventLevel string, listType string, offset int64, size int64) (itemIds []int64, err error) {
var query = this.Query(tx)
if sourceUserId > 0 {
if listId <= 0 {
query.Where("((listId=" + types.String(firewallconfigs.GlobalListId) + " AND sourceUserId=:sourceUserId) OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE userId=:sourceUserId AND state=1))")
query.Param("sourceUserId", sourceUserId)
} else if listId == firewallconfigs.GlobalListId {
query.Attr("sourceUserId", sourceUserId)
query.UseIndex("sourceUserId")
}
}
if len(keyword) > 0 {
if net.ParseIP(keyword) != nil { // 是一个IP地址
query.Attr("ipFrom", keyword)
} else {
query.Like("ipFrom", dbutils.QuoteLike(keyword))
}
}
if len(ip) > 0 {
query.Attr("ipFrom", ip)
}
if listId > 0 {
query.Attr("listId", listId)
} else {
if len(listType) > 0 {
query.Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1 AND type=:listType))")
query.Param("listType", listType)
} else {
query.Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))")
}
}
if unread {
query.Attr("isRead", 0)
}
if len(eventLevel) > 0 {
query.Attr("eventLevel", eventLevel)
}
result, err := query.
ResultPk().
State(IPItemStateEnabled).
Where("(expiredAt=0 OR expiredAt>:expiredAt)").
Param("expiredAt", time.Now().Unix()).
DescPk().
Offset(offset).
Size(size).
FindAll()
if err != nil {
return nil, err
}
for _, itemOne := range result {
itemIds = append(itemIds, int64(itemOne.(*IPItem).Id))
}
return
}
// UpdateItemsRead 设置所有未已读 // UpdateItemsRead 设置所有未已读
func (this *IPItemDAO) UpdateItemsRead(tx *dbs.Tx, sourceUserId int64) error { func (this *IPItemDAO) UpdateItemsRead(tx *dbs.Tx, sourceUserId int64) error {
var query = this.Query(tx). var query = this.Query(tx).
@@ -632,6 +822,60 @@ func (this *IPItemDAO) CleanExpiredIPItems(tx *dbs.Tx) error {
return nil return nil
} }
// ParseIPValue 解析IP值
func (this *IPItemDAO) ParseIPValue(value string) (newValue string, ipFrom string, ipTo string, ok bool) {
if len(value) == 0 {
return
}
newValue = value
// ip1-ip2
if strings.Contains(value, "-") {
var pieces = strings.Split(value, "-")
if len(pieces) != 2 {
return
}
ipFrom = strings.TrimSpace(pieces[0])
ipTo = strings.TrimSpace(pieces[1])
if !iputils.IsValid(ipFrom) || !iputils.IsValid(ipTo) {
return
}
if !iputils.IsSameVersion(ipFrom, ipTo) {
return
}
if iputils.CompareIP(ipFrom, ipTo) > 0 {
ipFrom, ipTo = ipTo, ipFrom
newValue = ipFrom + "-" + ipTo
}
ok = true
return
}
// ip/mask
if strings.Contains(value, "/") {
cidr, err := iputils.ParseCIDR(value)
if err != nil {
return
}
return newValue, cidr.From().String(), cidr.To().String(), true
}
// single value
if iputils.IsValid(value) {
ipFrom = value
ok = true
return
}
return
}
// NotifyUpdate 通知更新 // NotifyUpdate 通知更新
func (this *IPItemDAO) NotifyUpdate(tx *dbs.Tx, itemId int64) error { func (this *IPItemDAO) NotifyUpdate(tx *dbs.Tx, itemId int64) error {
// 获取ListId // 获取ListId
@@ -665,6 +909,9 @@ func (this *IPItemDAO) NotifyUpdate(tx *dbs.Tx, itemId int64) error {
} }
} else { } else {
clusterIds, err := SharedNodeClusterDAO.FindAllEnabledNodeClusterIds(tx) clusterIds, err := SharedNodeClusterDAO.FindAllEnabledNodeClusterIds(tx)
if err != nil {
return err
}
for _, clusterId := range clusterIds { for _, clusterId := range clusterIds {
err = SharedNodeTaskDAO.CreateClusterTask(tx, nodeconfigs.NodeRoleNode, clusterId, 0, 0, NodeTaskTypeIPItemChanged) err = SharedNodeTaskDAO.CreateClusterTask(tx, nodeconfigs.NodeRoleNode, clusterId, 0, 0, NodeTaskTypeIPItemChanged)
if err != nil { if err != nil {

View File

@@ -51,7 +51,8 @@ func TestIPItemDAO_CreateManyIPs(t *testing.T) {
var dao = models.NewIPItemDAO() var dao = models.NewIPItemDAO()
var n = 10 var n = 10
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
itemId, err := dao.CreateIPItem(tx, firewallconfigs.GlobalListId, "192."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255)), "", time.Now().Unix()+86400, "test", models.IPItemTypeIPv4, "warning", 0, 0, 0, 0, 0, 0, 0, false) var ip = "192." + types.String(rands.Int(0, 255)) + "." + types.String(rands.Int(0, 255)) + "." + types.String(rands.Int(0, 255))
itemId, err := dao.CreateIPItem(tx, firewallconfigs.GlobalListId, ip, ip, "", time.Now().Unix()+86400, "test", models.IPItemTypeIPv4, "warning", 0, 0, 0, 0, 0, 0, 0, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -74,3 +75,16 @@ func TestIPItemDAO_DisableIPItemsWithIP(t *testing.T) {
} }
t.Log("ok") t.Log("ok")
} }
func TestIPItemDAO_ParseIPValue(t *testing.T) {
var dao = models.NewIPItemDAO()
t.Log(dao.ParseIPValue("192.168.1.100"))
t.Log(dao.ParseIPValue("192.168.1.100-192.168.1.200"))
t.Log(dao.ParseIPValue("192.168.1.200-192.168.1.100"))
t.Log(dao.ParseIPValue("192.168.1.100/24"))
t.Log(dao.ParseIPValue("::1"))
t.Log(dao.ParseIPValue("192.168.1.100-::2"))
t.Log(dao.ParseIPValue("192"))
t.Log(dao.ParseIPValue("192.168.1.200/256"))
t.Log(dao.ParseIPValue("192.168.1.200-"))
}

View File

@@ -1,14 +1,44 @@
package models package models
import "github.com/iwind/TeaGo/dbs"
const (
IPItemField_Id dbs.FieldName = "id" // ID
IPItemField_ListId dbs.FieldName = "listId" // 所属名单ID
IPItemField_Value dbs.FieldName = "value" // 原始值
IPItemField_Type dbs.FieldName = "type" // 类型
IPItemField_IpFrom dbs.FieldName = "ipFrom" // 开始IP
IPItemField_IpTo dbs.FieldName = "ipTo" // 结束IP
IPItemField_IpFromLong dbs.FieldName = "ipFromLong" // 开始IP整型弃用
IPItemField_IpToLong dbs.FieldName = "ipToLong" // 结束IP整型弃用
IPItemField_Version dbs.FieldName = "version" // 版本
IPItemField_CreatedAt dbs.FieldName = "createdAt" // 创建时间
IPItemField_UpdatedAt dbs.FieldName = "updatedAt" // 修改时间
IPItemField_Reason dbs.FieldName = "reason" // 加入说明
IPItemField_EventLevel dbs.FieldName = "eventLevel" // 事件级别
IPItemField_State dbs.FieldName = "state" // 状态
IPItemField_ExpiredAt dbs.FieldName = "expiredAt" // 过期时间
IPItemField_ServerId dbs.FieldName = "serverId" // 有效范围服务ID
IPItemField_NodeId dbs.FieldName = "nodeId" // 有效范围节点ID
IPItemField_SourceNodeId dbs.FieldName = "sourceNodeId" // 来源节点ID
IPItemField_SourceServerId dbs.FieldName = "sourceServerId" // 来源服务ID
IPItemField_SourceHTTPFirewallPolicyId dbs.FieldName = "sourceHTTPFirewallPolicyId" // 来源策略ID
IPItemField_SourceHTTPFirewallRuleGroupId dbs.FieldName = "sourceHTTPFirewallRuleGroupId" // 来源规则集分组ID
IPItemField_SourceHTTPFirewallRuleSetId dbs.FieldName = "sourceHTTPFirewallRuleSetId" // 来源规则集ID
IPItemField_SourceUserId dbs.FieldName = "sourceUserId" // 用户ID
IPItemField_IsRead dbs.FieldName = "isRead" // 是否已读
)
// IPItem IP // IPItem IP
type IPItem struct { type IPItem struct {
Id uint64 `field:"id"` // ID Id uint64 `field:"id"` // ID
ListId uint32 `field:"listId"` // 所属名单ID ListId uint32 `field:"listId"` // 所属名单ID
Value string `field:"value"` // 原始值
Type string `field:"type"` // 类型 Type string `field:"type"` // 类型
IpFrom string `field:"ipFrom"` // 开始IP IpFrom string `field:"ipFrom"` // 开始IP
IpTo string `field:"ipTo"` // 结束IP IpTo string `field:"ipTo"` // 结束IP
IpFromLong uint64 `field:"ipFromLong"` // 开始IP整型 IpFromLong uint64 `field:"ipFromLong"` // 开始IP整型(弃用)
IpToLong uint64 `field:"ipToLong"` // 结束IP整型 IpToLong uint64 `field:"ipToLong"` // 结束IP整型(弃用)
Version uint64 `field:"version"` // 版本 Version uint64 `field:"version"` // 版本
CreatedAt uint64 `field:"createdAt"` // 创建时间 CreatedAt uint64 `field:"createdAt"` // 创建时间
UpdatedAt uint64 `field:"updatedAt"` // 修改时间 UpdatedAt uint64 `field:"updatedAt"` // 修改时间
@@ -30,11 +60,12 @@ type IPItem struct {
type IPItemOperator struct { type IPItemOperator struct {
Id any // ID Id any // ID
ListId any // 所属名单ID ListId any // 所属名单ID
Value any // 原始值
Type any // 类型 Type any // 类型
IpFrom any // 开始IP IpFrom any // 开始IP
IpTo any // 结束IP IpTo any // 结束IP
IpFromLong any // 开始IP整型 IpFromLong any // 开始IP整型(弃用)
IpToLong any // 结束IP整型 IpToLong any // 结束IP整型(弃用)
Version any // 版本 Version any // 版本
CreatedAt any // 创建时间 CreatedAt any // 创建时间
UpdatedAt any // 修改时间 UpdatedAt any // 修改时间

View File

@@ -1 +1,15 @@
package models package models
// ComposeValue 组合原始值
func (this *IPItem) ComposeValue() string {
if len(this.Value) > 0 {
return this.Value
}
// 兼容以往版本
if len(this.IpTo) > 0 {
return this.IpFrom + "-" + this.IpTo
}
return this.IpFrom
}

View File

@@ -3,6 +3,7 @@ package models
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/regions" "github.com/TeaOSLab/EdgeAPI/internal/db/models/regions"
"github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary" "github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
@@ -12,6 +13,7 @@ import (
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
"io" "io"
"os" "os"
"strings"
"time" "time"
) )
@@ -299,7 +301,7 @@ func (this *IPLibraryFileDAO) GenerateIPLibrary(tx *dbs.Tx, libraryFileId int64)
var libraryFile = one.(*IPLibraryFile) var libraryFile = one.(*IPLibraryFile)
template, err := iplibrary.NewTemplate(libraryFile.Template) template, err := iplibrary.NewTemplate(libraryFile.Template)
if err != nil { if err != nil {
return errors.New("create template from '" + libraryFile.Template + "' failed: " + err.Error()) return fmt.Errorf("create template from '%s' failed: %w", libraryFile.Template, err)
} }
var fileId = int64(libraryFile.FileId) var fileId = int64(libraryFile.FileId)
@@ -314,17 +316,17 @@ func (this *IPLibraryFileDAO) GenerateIPLibrary(tx *dbs.Tx, libraryFileId int64)
if os.IsNotExist(err) { if os.IsNotExist(err) {
err = os.Mkdir(dir, 0777) err = os.Mkdir(dir, 0777)
if err != nil { if err != nil {
return errors.New("can not open dir '" + dir + "' to write: " + err.Error()) return fmt.Errorf("can not open dir '%s' to write: %w", dir, err)
} }
} else { } else {
return errors.New("can not open dir '" + dir + "' to write: " + err.Error()) return fmt.Errorf("can not open dir '%s' to write: %w", dir, err)
} }
} else if !stat.IsDir() { } else if !stat.IsDir() {
_ = os.Remove(dir) _ = os.Remove(dir)
err = os.Mkdir(dir, 0777) err = os.Mkdir(dir, 0777)
if err != nil { if err != nil {
return errors.New("can not open dir '" + dir + "' to write: " + err.Error()) return fmt.Errorf("can not open dir '%s' to write: %w", dir, err)
} }
} }
@@ -428,7 +430,7 @@ func (this *IPLibraryFileDAO) GenerateIPLibrary(tx *dbs.Tx, libraryFileId int64)
err = writer.WriteMeta() err = writer.WriteMeta()
if err != nil { if err != nil {
return errors.New("write meta failed: " + err.Error()) return fmt.Errorf("write meta failed: %w", err)
} }
chunkIds, err := SharedFileChunkDAO.FindAllFileChunkIds(tx, fileId) chunkIds, err := SharedFileChunkDAO.FindAllFileChunkIds(tx, fileId)
@@ -448,6 +450,14 @@ func (this *IPLibraryFileDAO) GenerateIPLibrary(tx *dbs.Tx, libraryFileId int64)
for _, province := range dbProvinces { for _, province := range dbProvinces {
for _, code := range province.AllCodes() { for _, code := range province.AllCodes() {
provinceMap[types.String(province.CountryId)+"_"+code] = int64(province.ValueId) provinceMap[types.String(province.CountryId)+"_"+code] = int64(province.ValueId)
for _, suffix := range regions.RegionProvinceSuffixes {
if strings.HasSuffix(code, suffix) {
provinceMap[types.String(province.CountryId)+"_"+strings.TrimSuffix(code, suffix)] = int64(province.ValueId)
} else {
provinceMap[types.String(province.CountryId)+"_"+(code+suffix)] = int64(province.ValueId)
}
}
} }
} }
@@ -503,7 +513,7 @@ func (this *IPLibraryFileDAO) GenerateIPLibrary(tx *dbs.Tx, libraryFileId int64)
err = writer.Write(ipFrom, ipTo, countryId, provinceId, cityId, townId, providerId) err = writer.Write(ipFrom, ipTo, countryId, provinceId, cityId, townId, providerId)
if err != nil { if err != nil {
return errors.New("write failed: " + err.Error()) return fmt.Errorf("write failed: %w", err)
} }
return nil return nil
@@ -536,7 +546,7 @@ func (this *IPLibraryFileDAO) GenerateIPLibrary(tx *dbs.Tx, libraryFileId int64)
// 将生成的内容写入到文件 // 将生成的内容写入到文件
stat, err = os.Stat(filePath) stat, err = os.Stat(filePath)
if err != nil { if err != nil {
return errors.New("stat generated file failed: " + err.Error()) return fmt.Errorf("stat generated file failed: %w", err)
} }
generatedFileId, err := SharedFileDAO.CreateFile(tx, 0, 0, "ipLibraryFile", "", libraryCode+".db", stat.Size(), "", false) generatedFileId, err := SharedFileDAO.CreateFile(tx, 0, 0, "ipLibraryFile", "", libraryCode+".db", stat.Size(), "", false)
if err != nil { if err != nil {
@@ -545,7 +555,7 @@ func (this *IPLibraryFileDAO) GenerateIPLibrary(tx *dbs.Tx, libraryFileId int64)
fp, err := os.Open(filePath) fp, err := os.Open(filePath)
if err != nil { if err != nil {
return errors.New("open generated file failed: " + err.Error()) return fmt.Errorf("open generated file failed: %w", err)
} }
var buf = make([]byte, 256*1024) var buf = make([]byte, 256*1024)
for { for {

View File

@@ -11,6 +11,7 @@ import (
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
) )
@@ -22,7 +23,7 @@ const (
var listTypeCacheMap = map[int64]*IPList{} // listId => *IPList var listTypeCacheMap = map[int64]*IPList{} // listId => *IPList
var DefaultGlobalIPList = &IPList{ var DefaultGlobalIPList = &IPList{
Id: uint32(firewallconfigs.GlobalListId), Id: uint32(firewallconfigs.GlobalListId),
Name: "全局封锁名单", Name: "系统黑名单",
IsPublic: true, IsPublic: true,
IsGlobal: true, IsGlobal: true,
Type: "black", Type: "black",
@@ -61,12 +62,16 @@ func (this *IPListDAO) EnableIPList(tx *dbs.Tx, id int64) error {
} }
// DisableIPList 禁用条目 // DisableIPList 禁用条目
func (this *IPListDAO) DisableIPList(tx *dbs.Tx, id int64) error { func (this *IPListDAO) DisableIPList(tx *dbs.Tx, listId int64) error {
_, err := this.Query(tx). _, err := this.Query(tx).
Pk(id). Pk(listId).
Set("state", IPListStateDisabled). Set("state", IPListStateDisabled).
Update() Update()
return err if err != nil {
return err
}
return this.NotifyUpdate(tx, listId, NodeTaskTypeIPListDeleted+"@"+string(maps.Map{"listId": listId}.AsJSON()))
} }
// FindEnabledIPList 查找启用中的条目 // FindEnabledIPList 查找启用中的条目
@@ -258,11 +263,35 @@ func (this *IPListDAO) ExistsEnabledIPList(tx *dbs.Tx, listId int64) (bool, erro
// NotifyUpdate 通知更新 // NotifyUpdate 通知更新
func (this *IPListDAO) NotifyUpdate(tx *dbs.Tx, listId int64, taskType NodeTaskType) error { func (this *IPListDAO) NotifyUpdate(tx *dbs.Tx, listId int64, taskType NodeTaskType) error {
// WAF策略中的
httpFirewallPolicyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(tx, listId) httpFirewallPolicyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(tx, listId)
if err != nil { if err != nil {
return err return err
} }
resultClusterIds := []int64{}
// 规则集动作中使用此名单的策略
ruleSetIds, err := SharedHTTPFirewallRuleSetDAO.FindAllEnabledRuleSetIdsWithIPListId(tx, listId)
if err != nil {
return err
}
for _, ruleSetId := range ruleSetIds {
ruleGroupId, err := SharedHTTPFirewallRuleGroupDAO.FindRuleGroupIdWithRuleSetId(tx, ruleSetId)
if err != nil {
return err
}
if ruleGroupId > 0 {
policyId, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdWithRuleGroupId(tx, ruleGroupId)
if err != nil {
return err
}
if policyId > 0 && !lists.ContainsInt64(httpFirewallPolicyIds, policyId) {
httpFirewallPolicyIds = append(httpFirewallPolicyIds, policyId)
}
}
}
// 查找集群
var resultClusterIds = []int64{}
for _, policyId := range httpFirewallPolicyIds { for _, policyId := range httpFirewallPolicyIds {
// 集群 // 集群
clusterIds, err := SharedNodeClusterDAO.FindAllEnabledNodeClusterIdsWithHTTPFirewallPolicyId(tx, policyId) clusterIds, err := SharedNodeClusterDAO.FindAllEnabledNodeClusterIdsWithHTTPFirewallPolicyId(tx, policyId)
@@ -310,3 +339,16 @@ func (this *IPListDAO) NotifyUpdate(tx *dbs.Tx, listId int64, taskType NodeTaskT
return nil return nil
} }
// FindServerIdWithListId 查找IP名单对应的网站ID
func (this *IPListDAO) FindServerIdWithListId(tx *dbs.Tx, listId int64) (serverId int64, err error) {
if listId <= 0 {
return
}
serverId, err = this.Query(tx).
Pk(listId).
Result("serverId").
FindInt64Col(0)
return
}

View File

@@ -1,6 +1,7 @@
package models package models
import ( import (
"errors"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"runtime" "runtime"
@@ -27,7 +28,7 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
{ {
err := NewIPListDAO().CheckUserIPList(tx, 1, 100) err := NewIPListDAO().CheckUserIPList(tx, 1, 100)
if err == ErrNotFound { if err != nil && errors.Is(err, ErrNotFound) {
t.Log("not found") t.Log("not found")
} else { } else {
t.Log(err) t.Log(err)
@@ -36,7 +37,7 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
{ {
err := NewIPListDAO().CheckUserIPList(tx, 1, 85) err := NewIPListDAO().CheckUserIPList(tx, 1, 85)
if err == ErrNotFound { if err != nil && errors.Is(err, ErrNotFound) {
t.Log("not found") t.Log("not found")
} else { } else {
t.Log(err) t.Log(err)
@@ -45,7 +46,7 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
{ {
err := NewIPListDAO().CheckUserIPList(tx, 1, 17) err := NewIPListDAO().CheckUserIPList(tx, 1, 17)
if err == ErrNotFound { if err != nil && errors.Is(err, ErrNotFound) {
t.Log("not found") t.Log("not found")
} else { } else {
t.Log(err) t.Log(err)
@@ -53,6 +54,17 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
} }
} }
func TestIPListDAO_NotifyUpdate(t *testing.T) {
dbs.NotifyReady()
var dao = NewIPListDAO()
var tx *dbs.Tx
err := dao.NotifyUpdate(tx, 104, NodeTaskTypeIPListDeleted)
if err != nil {
t.Fatal(err)
}
}
func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) { func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
runtime.GOMAXPROCS(1) runtime.GOMAXPROCS(1)
@@ -65,4 +77,3 @@ func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
_, _ = dao.IncreaseVersion(tx) _, _ = dao.IncreaseVersion(tx)
} }
} }

View File

@@ -135,40 +135,16 @@ func (this *LoginSessionDAO) WriteSessionValue(tx *dbs.Tx, sid string, key strin
sessionOp.UserId = userId sessionOp.UserId = userId
if isNewSession { if isNewSession {
// 删除此用户之前创建的SESSION防止单个用户SESSION过多 // 删除此用户之前创建的SESSION不再保存以往的SESSION避免安全问题
// TODO 将来改成按照活跃时间排序 err = this.Query(tx).
const maxSessionsPerUser = 10
oldOnes, err := this.Query(tx).
ResultPk(). ResultPk().
Attr("adminId", adminId). Attr("adminId", adminId).
Attr("userId", userId). Attr("userId", userId).
Asc("createdAt"). Neq("sid", sid).
FindAll() DeleteQuickly()
if err != nil { if err != nil {
return err return err
} }
var countOldOnes = len(oldOnes)
if countOldOnes > maxSessionsPerUser {
var countDeleted int
for _, oldOne := range oldOnes {
var oldSessionId = int64(oldOne.(*LoginSession).Id)
if oldSessionId == sessionId {
continue
}
if countDeleted < countOldOnes-maxSessionsPerUser {
err = this.Query(tx).
Pk(oldSessionId).
DeleteQuickly()
if err != nil {
return err
}
countDeleted++
} else {
break
}
}
}
} }
} }

View File

@@ -27,6 +27,8 @@ const (
type MessageType = string type MessageType = string
const ( const (
MessageTypeAll MessageType = "*"
// 这里的命名问题(首字母大写)为历史遗留问题,暂不修改 // 这里的命名问题(首字母大写)为历史遗留问题,暂不修改
MessageTypeHealthCheckFailed MessageType = "HealthCheckFailed" // 节点健康检查失败 MessageTypeHealthCheckFailed MessageType = "HealthCheckFailed" // 节点健康检查失败
@@ -109,14 +111,17 @@ func (this *MessageDAO) FindEnabledMessage(tx *dbs.Tx, id int64) (*Message, erro
} }
// CreateClusterMessage 创建集群消息 // CreateClusterMessage 创建集群消息
func (this *MessageDAO) CreateClusterMessage(tx *dbs.Tx, role string, clusterId int64, messageType MessageType, level string, subject string, body string, paramsJSON []byte) error { func (this *MessageDAO) CreateClusterMessage(tx *dbs.Tx, role string, clusterId int64, messageType MessageType, level string, subject string, shortBody string, body string, paramsJSON []byte) error {
_, err := this.createMessage(tx, role, clusterId, 0, messageType, level, subject, body, paramsJSON) if len(shortBody) == 0 {
shortBody = body
}
_, err := this.createMessage(tx, role, clusterId, 0, messageType, level, subject, shortBody, paramsJSON)
if err != nil { if err != nil {
return err return err
} }
// 发送给媒介接收人 // 发送给媒介接收人
err = SharedMessageTaskDAO.CreateMessageTasks(tx, role, 0, 0, 0, messageType, subject, body) err = SharedMessageTaskDAO.CreateMessageTasks(tx, role, clusterId, 0, 0, messageType, subject, body)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -12,7 +12,7 @@ func TestMessageDAO_CreateClusterMessage(t *testing.T) {
var tx *dbs.Tx var tx *dbs.Tx
dao := NewMessageDAO() dao := NewMessageDAO()
err := dao.CreateClusterMessage(tx, nodeconfigs.NodeRoleNode, 1, "test", "error", "123", "123", []byte("456")) err := dao.CreateClusterMessage(tx, nodeconfigs.NodeRoleNode, 1, "test", "error", "123", "123", "123", []byte("456"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -4,8 +4,6 @@ import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
) )
const ( const (
@@ -34,7 +32,7 @@ func init() {
}) })
} }
// 启用条目 // EnableMessageMedia 启用条目
func (this *MessageMediaDAO) EnableMessageMedia(tx *dbs.Tx, id int64) error { func (this *MessageMediaDAO) EnableMessageMedia(tx *dbs.Tx, id int64) error {
_, err := this.Query(tx). _, err := this.Query(tx).
Pk(id). Pk(id).
@@ -43,7 +41,7 @@ func (this *MessageMediaDAO) EnableMessageMedia(tx *dbs.Tx, id int64) error {
return err return err
} }
// 禁用条目 // DisableMessageMedia 禁用条目
func (this *MessageMediaDAO) DisableMessageMedia(tx *dbs.Tx, id int64) error { func (this *MessageMediaDAO) DisableMessageMedia(tx *dbs.Tx, id int64) error {
_, err := this.Query(tx). _, err := this.Query(tx).
Pk(id). Pk(id).
@@ -52,7 +50,7 @@ func (this *MessageMediaDAO) DisableMessageMedia(tx *dbs.Tx, id int64) error {
return err return err
} }
// 查找启用中的条目 // FindEnabledMessageMedia 查找启用中的条目
func (this *MessageMediaDAO) FindEnabledMessageMedia(tx *dbs.Tx, id int64) (*MessageMedia, error) { func (this *MessageMediaDAO) FindEnabledMessageMedia(tx *dbs.Tx, id int64) (*MessageMedia, error) {
result, err := this.Query(tx). result, err := this.Query(tx).
Pk(id). Pk(id).
@@ -64,7 +62,7 @@ func (this *MessageMediaDAO) FindEnabledMessageMedia(tx *dbs.Tx, id int64) (*Mes
return result.(*MessageMedia), err return result.(*MessageMedia), err
} }
// 根据主键查找名称 // FindMessageMediaName 根据主键查找名称
func (this *MessageMediaDAO) FindMessageMediaName(tx *dbs.Tx, id int64) (string, error) { func (this *MessageMediaDAO) FindMessageMediaName(tx *dbs.Tx, id int64) (string, error) {
return this.Query(tx). return this.Query(tx).
Pk(id). Pk(id).
@@ -72,7 +70,7 @@ func (this *MessageMediaDAO) FindMessageMediaName(tx *dbs.Tx, id int64) (string,
FindStringCol("") FindStringCol("")
} }
// 查询所有可用媒介 // FindAllEnabledMessageMedias 查询所有可用媒介
func (this *MessageMediaDAO) FindAllEnabledMessageMedias(tx *dbs.Tx) (result []*MessageMedia, err error) { func (this *MessageMediaDAO) FindAllEnabledMessageMedias(tx *dbs.Tx) (result []*MessageMedia, err error) {
_, err = this.Query(tx). _, err = this.Query(tx).
State(MessageMediaStateEnabled). State(MessageMediaStateEnabled).
@@ -82,74 +80,3 @@ func (this *MessageMediaDAO) FindAllEnabledMessageMedias(tx *dbs.Tx) (result []*
FindAll() FindAll()
return return
} }
// 设置当前所有可用的媒介
func (this *MessageMediaDAO) UpdateMessageMedias(tx *dbs.Tx, mediaMaps []maps.Map) error {
// 新的媒介信息
mediaTypes := []string{}
for index, m := range mediaMaps {
order := len(mediaMaps) - index
mediaType := m.GetString("type")
mediaTypes = append(mediaTypes, mediaType)
name := m.GetString("name")
description := m.GetString("description")
userDescription := m.GetString("userDescription")
isOn := m.GetBool("isOn")
mediaId, err := this.Query(tx).
ResultPk().
Attr("type", mediaType).
FindInt64Col(0)
if err != nil {
return err
}
var op = NewMessageMediaOperator()
if mediaId > 0 {
op.Id = mediaId
}
op.Name = name
op.Type = mediaType
op.Description = description
op.UserDescription = userDescription
op.IsOn = isOn
op.Order = order
op.State = MessageMediaStateEnabled
err = this.Save(tx, op)
if err != nil {
return err
}
}
// 老的媒介信息
ones, err := this.Query(tx).
FindAll()
if err != nil {
return err
}
for _, one := range ones {
mediaType := one.(*MessageMedia).Type
if !lists.ContainsString(mediaTypes, mediaType) {
err := this.Query(tx).
Pk(one.(*MessageMedia).Id).
Set("state", MessageMediaStateDisabled).
UpdateQuickly()
if err != nil {
return err
}
}
}
return nil
}
// 根据类型查找媒介
func (this *MessageMediaDAO) FindEnabledMediaWithType(tx *dbs.Tx, mediaType string) (*MessageMedia, error) {
one, err := this.Query(tx).
Attr("type", mediaType).
State(MessageMediaStateEnabled).
Find()
if one == nil || err != nil {
return nil, err
}
return one.(*MessageMedia), nil
}

View File

@@ -98,24 +98,6 @@ func (this *MessageReceiverDAO) CreateReceiver(tx *dbs.Tx, role string, clusterI
return this.SaveInt64(tx, op) return this.SaveInt64(tx, op)
} }
// FindAllEnabledReceivers 查询接收人
func (this *MessageReceiverDAO) FindAllEnabledReceivers(tx *dbs.Tx, role string, clusterId int64, nodeId int64, serverId int64, messageType string) (result []*MessageReceiver, err error) {
query := this.Query(tx)
if len(messageType) > 0 {
query.Attr("type", []string{"*", messageType}) // *表示所有的
}
_, err = query.
Attr("role", role).
Attr("clusterId", clusterId).
Attr("nodeId", nodeId).
Attr("serverId", serverId).
State(MessageReceiverStateEnabled).
AscPk().
Slice(&result).
FindAll()
return
}
// CountAllEnabledReceivers 计算接收人数量 // CountAllEnabledReceivers 计算接收人数量
func (this *MessageReceiverDAO) CountAllEnabledReceivers(tx *dbs.Tx, role string, clusterId int64, nodeId int64, serverId int64, messageType string) (int64, error) { func (this *MessageReceiverDAO) CountAllEnabledReceivers(tx *dbs.Tx, role string, clusterId int64, nodeId int64, serverId int64, messageType string) (int64, error) {
query := this.Query(tx) query := this.Query(tx)
@@ -146,6 +128,8 @@ func (this *MessageReceiverDAO) FindEnabledBestFitReceivers(tx *dbs.Tx, role str
} else if nodeId > 0 { } else if nodeId > 0 {
query.Attr("nodeId", nodeId) query.Attr("nodeId", nodeId)
} else if clusterId > 0 { } else if clusterId > 0 {
query.Attr("serverId", 0)
query.Attr("nodeId", 0)
query.Attr("clusterId", clusterId) query.Attr("clusterId", clusterId)
} }
_, err = query. _, err = query.

View File

@@ -1,30 +0,0 @@
package models
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
_ "github.com/go-sql-driver/mysql"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/logs"
"testing"
)
func TestMessageReceiverDAO_FindEnabledBestFitReceivers(t *testing.T) {
var tx *dbs.Tx
{
receivers, err := NewMessageReceiverDAO().FindEnabledBestFitReceivers(tx, nodeconfigs.NodeRoleNode, 18, 1, 2, "*")
if err != nil {
t.Fatal(err)
}
logs.PrintAsJSON(receivers, t)
}
{
receivers, err := NewMessageReceiverDAO().FindEnabledBestFitReceivers(tx, nodeconfigs.NodeRoleNode, 30, 1, 2, "*")
if err != nil {
t.Fatal(err)
}
logs.PrintAsJSON(receivers, t)
}
}

View File

@@ -1,31 +1,19 @@
package models package models
import ( import (
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/goman" "github.com/TeaOSLab/EdgeAPI/internal/goman"
"github.com/TeaOSLab/EdgeAPI/internal/remotelogs" "github.com/TeaOSLab/EdgeAPI/internal/remotelogs"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/rands" "github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
timeutil "github.com/iwind/TeaGo/utils/time" timeutil "github.com/iwind/TeaGo/utils/time"
"time" "time"
) )
type MessageTaskStatus = int
const ( const (
MessageTaskStateEnabled = 1 // 已启用 MessageTaskStateEnabled = 1 // 已启用
MessageTaskStateDisabled = 0 // 已禁用 MessageTaskStateDisabled = 0 // 已禁用
MessageTaskStatusNone MessageTaskStatus = 0 // 普通状态
MessageTaskStatusSending MessageTaskStatus = 1 // 发送中
MessageTaskStatusSuccess MessageTaskStatus = 2 // 发送成功
MessageTaskStatusFailed MessageTaskStatus = 3 // 发送失败
) )
type MessageTaskDAO dbs.DAO type MessageTaskDAO dbs.DAO
@@ -94,151 +82,6 @@ func (this *MessageTaskDAO) FindEnabledMessageTask(tx *dbs.Tx, id int64) (*Messa
return result.(*MessageTask), err return result.(*MessageTask), err
} }
// CreateMessageTask 创建任务
func (this *MessageTaskDAO) CreateMessageTask(tx *dbs.Tx, recipientId int64, instanceId int64, user string, subject string, body string, isPrimary bool) (int64, error) {
if !teaconst.IsPlus {
return 0, nil
}
var hash = stringutil.Md5(types.String(recipientId) + "@" + types.String(instanceId) + "@" + user + "@" + subject + "@" + types.String(isPrimary))
recipientInstanceId, err := SharedMessageRecipientDAO.FindRecipientInstanceId(tx, recipientId)
if err != nil {
return 0, err
}
if recipientInstanceId > 0 {
hashLifeSeconds, err := SharedMessageMediaInstanceDAO.FindInstanceHashLifeSeconds(tx, recipientInstanceId)
if err != nil {
return 0, err
}
if hashLifeSeconds >= 0 { // 意味着此值如果小于0则不做判断
lastMessageAt, err := this.Query(tx).
Attr("hash", hash).
Result("createdAt").
DescPk().
FindInt64Col(0)
if err != nil {
return 0, err
}
// 对于同一个人N分钟内消息不重复发送
if hashLifeSeconds <= 0 {
hashLifeSeconds = 60
}
if lastMessageAt > 0 && time.Now().Unix()-lastMessageAt < int64(hashLifeSeconds) {
return 0, nil
}
}
}
var op = NewMessageTaskOperator()
op.RecipientId = recipientId
op.InstanceId = instanceId
op.Hash = hash
op.User = user
op.Subject = subject
op.Body = body
op.IsPrimary = isPrimary
op.Day = timeutil.Format("Ymd")
op.Status = MessageTaskStatusNone
op.State = MessageTaskStateEnabled
return this.SaveInt64(tx, op)
}
// FindSendingMessageTasks 查找需要发送的任务
func (this *MessageTaskDAO) FindSendingMessageTasks(tx *dbs.Tx, size int64) (result []*MessageTask, err error) {
if size <= 0 {
return nil, nil
}
_, err = this.Query(tx).
State(MessageTaskStateEnabled).
Attr("status", MessageTaskStatusNone).
Where("(recipientId=0 OR recipientId IN (SELECT id FROM "+SharedMessageRecipientDAO.Table+" WHERE state=1 AND isOn=1 AND (timeFrom IS NULL OR timeTo IS NULL OR :time BETWEEN timeFrom AND timeTo)))").
Param("time", timeutil.Format("H:i:s")).
Desc("isPrimary").
AscPk().
Limit(size).
Slice(&result).
FindAll()
return
}
// CountMessageTasksWithStatus 根据状态计算任务数量
func (this *MessageTaskDAO) CountMessageTasksWithStatus(tx *dbs.Tx, status MessageTaskStatus) (int64, error) {
return this.Query(tx).
State(MessageTaskStateEnabled).
Attr("status", status).
Count()
}
// ListMessageTasksWithStatus 根据状态列出单页任务
func (this *MessageTaskDAO) ListMessageTasksWithStatus(tx *dbs.Tx, status MessageTaskStatus, offset int64, size int64) (result []*MessageTask, err error) {
_, err = this.Query(tx).
State(MessageTaskStateEnabled).
Attr("status", status).
Desc("isPrimary").
AscPk().
Offset(offset).
Limit(size).
Slice(&result).
FindAll()
return
}
// UpdateMessageTaskStatus 设置发送的状态
func (this *MessageTaskDAO) UpdateMessageTaskStatus(tx *dbs.Tx, taskId int64, status MessageTaskStatus, result []byte) error {
if taskId <= 0 {
return errors.New("invalid taskId")
}
var op = NewMessageTaskOperator()
op.Id = taskId
op.Status = status
op.SentAt = time.Now().Unix()
if len(result) > 0 {
op.Result = result
}
return this.Save(tx, op)
}
// CreateMessageTasks 从集群、节点或者服务中创建任务
func (this *MessageTaskDAO) CreateMessageTasks(tx *dbs.Tx, role nodeconfigs.NodeRole, clusterId int64, nodeId int64, serverId int64, messageType MessageType, subject string, body string) error {
if !teaconst.IsPlus {
return nil
}
receivers, err := SharedMessageReceiverDAO.FindEnabledBestFitReceivers(tx, role, clusterId, nodeId, serverId, messageType)
if err != nil {
return err
}
allRecipientIds := []int64{}
for _, receiver := range receivers {
if receiver.RecipientId > 0 {
allRecipientIds = append(allRecipientIds, int64(receiver.RecipientId))
} else if receiver.RecipientGroupId > 0 {
recipientIds, err := SharedMessageRecipientDAO.FindAllEnabledAndOnRecipientIdsWithGroup(tx, int64(receiver.RecipientGroupId))
if err != nil {
return err
}
allRecipientIds = append(allRecipientIds, recipientIds...)
}
}
sentMap := map[int64]bool{} // recipientId => bool 用来检查是否已经发送,防止重复发送给某个接收人
for _, recipientId := range allRecipientIds {
_, ok := sentMap[recipientId]
if ok {
continue
}
sentMap[recipientId] = true
_, err := this.CreateMessageTask(tx, recipientId, 0, "", subject, body, false)
if err != nil {
return err
}
}
return nil
}
// CleanExpiredMessageTasks 清理 // CleanExpiredMessageTasks 清理
func (this *MessageTaskDAO) CleanExpiredMessageTasks(tx *dbs.Tx, days int) error { func (this *MessageTaskDAO) CleanExpiredMessageTasks(tx *dbs.Tx, days int) error {
if days <= 0 { if days <= 0 {

View File

@@ -0,0 +1,14 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package models
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/iwind/TeaGo/dbs"
)
// CreateMessageTasks 从集群、节点或者服务中创建任务
func (this *MessageTaskDAO) CreateMessageTasks(tx *dbs.Tx, role nodeconfigs.NodeRole, clusterId int64, nodeId int64, serverId int64, messageType MessageType, subject string, body string) error {
return nil
}

View File

@@ -8,20 +8,6 @@ import (
"testing" "testing"
) )
func TestMessageTaskDAO_FindSendingMessageTasks(t *testing.T) {
dbs.NotifyReady()
var tx *dbs.Tx
tasks, err := models.NewMessageTaskDAO().FindSendingMessageTasks(tx, 100)
if err != nil {
t.Fatal(err)
}
t.Log(len(tasks), "tasks")
for _, task := range tasks {
t.Log("task:", task.Id, "recipient:", task.RecipientId)
}
}
func TestMessageTaskDAO_CleanExpiredMessageTasks(t *testing.T) { func TestMessageTaskDAO_CleanExpiredMessageTasks(t *testing.T) {
var dao = models.NewMessageTaskDAO() var dao = models.NewMessageTaskDAO()
var tx *dbs.Tx var tx *dbs.Tx

View File

@@ -5,7 +5,6 @@ import (
"github.com/TeaOSLab/EdgeAPI/internal/goman" "github.com/TeaOSLab/EdgeAPI/internal/goman"
"github.com/TeaOSLab/EdgeAPI/internal/remotelogs" "github.com/TeaOSLab/EdgeAPI/internal/remotelogs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/go-sql-driver/mysql"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
@@ -14,8 +13,10 @@ import (
"github.com/iwind/TeaGo/rands" "github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time" timeutil "github.com/iwind/TeaGo/utils/time"
"regexp"
"sort" "sort"
"strconv" "strconv"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -40,6 +41,8 @@ func init() {
const MetricStatTablePartials = 20 // 表格Partial数量 const MetricStatTablePartials = 20 // 表格Partial数量
var metricHashRegexp = regexp.MustCompile(`^\w+$`)
func NewMetricStatDAO() *MetricStatDAO { func NewMetricStatDAO() *MetricStatDAO {
return dbs.NewDAO(&MetricStatDAO{ return dbs.NewDAO(&MetricStatDAO{
DAOObject: dbs.DAOObject{ DAOObject: dbs.DAOObject{
@@ -124,18 +127,30 @@ func (this *MetricStatDAO) DeleteItemStats(tx *dbs.Tx, itemId int64) error {
} }
// DeleteNodeItemStats 删除某个节点的统计数据 // DeleteNodeItemStats 删除某个节点的统计数据
func (this *MetricStatDAO) DeleteNodeItemStats(tx *dbs.Tx, nodeId int64, serverId int64, itemId int64, time string) error { func (this *MetricStatDAO) DeleteNodeItemStats(tx *dbs.Tx, nodeId int64, serverId int64, itemId int64, time string, keepKeys []string) error {
if serverId > 0 { if serverId > 0 {
_, err := this.Query(tx). var query = this.Query(tx).
Table(this.partialTable(serverId)). Table(this.partialTable(serverId)).
Attr("nodeId", nodeId). Attr("nodeId", nodeId).
Attr("serverId", serverId). Attr("serverId", serverId).
Attr("itemId", itemId). Attr("itemId", itemId).
Attr("time", time). Attr("time", time)
Delete() if len(keepKeys) > 0 {
if this.canIgnore(err) { query.Reuse(false)
var s []string
for _, k := range keepKeys {
if metricHashRegexp.MatchString(k) {
s = append(s, "'"+k+"@"+types.String(nodeId)+"'")
}
}
query.Where("hash NOT IN (" + strings.Join(s, ",") + ")")
}
err := query.
DeleteQuickly()
if err == nil || this.canIgnore(err) {
return nil return nil
} }
return err return err
} }
@@ -759,10 +774,5 @@ func (this *MetricStatDAO) canIgnore(err error) bool {
} }
// 忽略 Error 1213: Deadlock found 错误 // 忽略 Error 1213: Deadlock found 错误
mysqlErr, ok := err.(*mysql.MySQLError) return CheckSQLErrCode(err, 1213)
if ok && mysqlErr.Number == 1213 {
return true
}
return false
} }

View File

@@ -41,7 +41,7 @@ func TestMetricStatDAO_DeleteNodeItemStats(t *testing.T) {
defer func() { defer func() {
t.Log(time.Since(before).Seconds()*1000, "ms") t.Log(time.Since(before).Seconds()*1000, "ms")
}() }()
err := dao.DeleteNodeItemStats(nil, 1, 0, 1, timeutil.Format("Ymd")) err := dao.DeleteNodeItemStats(nil, 1, 0, 1, timeutil.Format("Ymd"), nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -4,7 +4,6 @@ import (
"github.com/TeaOSLab/EdgeAPI/internal/goman" "github.com/TeaOSLab/EdgeAPI/internal/goman"
"github.com/TeaOSLab/EdgeAPI/internal/remotelogs" "github.com/TeaOSLab/EdgeAPI/internal/remotelogs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/go-sql-driver/mysql"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
@@ -289,10 +288,5 @@ func (this *MetricSumStatDAO) canIgnore(err error) bool {
} }
// 忽略 Error 1213: Deadlock found 错误 // 忽略 Error 1213: Deadlock found 错误
mysqlErr, ok := err.(*mysql.MySQLError) return CheckSQLErrCode(err, 1213)
if ok && mysqlErr.Number == 1213 {
return true
}
return false
} }

View File

@@ -1,215 +0,0 @@
package models
import (
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
)
const (
MonitorNodeStateEnabled = 1 // 已启用
MonitorNodeStateDisabled = 0 // 已禁用
)
type MonitorNodeDAO dbs.DAO
func NewMonitorNodeDAO() *MonitorNodeDAO {
return dbs.NewDAO(&MonitorNodeDAO{
DAOObject: dbs.DAOObject{
DB: Tea.Env,
Table: "edgeMonitorNodes",
Model: new(MonitorNode),
PkName: "id",
},
}).(*MonitorNodeDAO)
}
var SharedMonitorNodeDAO *MonitorNodeDAO
func init() {
dbs.OnReady(func() {
SharedMonitorNodeDAO = NewMonitorNodeDAO()
})
}
// EnableMonitorNode 启用条目
func (this *MonitorNodeDAO) EnableMonitorNode(tx *dbs.Tx, id int64) error {
_, err := this.Query(tx).
Pk(id).
Set("state", MonitorNodeStateEnabled).
Update()
return err
}
// DisableMonitorNode 禁用条目
func (this *MonitorNodeDAO) DisableMonitorNode(tx *dbs.Tx, nodeId int64) error {
_, err := this.Query(tx).
Pk(nodeId).
Set("state", MonitorNodeStateDisabled).
Update()
if err != nil {
return err
}
// 删除运行日志
return SharedNodeLogDAO.DeleteNodeLogs(tx, nodeconfigs.NodeRoleMonitor, nodeId)
}
// FindEnabledMonitorNode 查找启用中的条目
func (this *MonitorNodeDAO) FindEnabledMonitorNode(tx *dbs.Tx, id int64) (*MonitorNode, error) {
result, err := this.Query(tx).
Pk(id).
Attr("state", MonitorNodeStateEnabled).
Find()
if result == nil {
return nil, err
}
return result.(*MonitorNode), err
}
// FindMonitorNodeName 根据主键查找名称
func (this *MonitorNodeDAO) FindMonitorNodeName(tx *dbs.Tx, id int64) (string, error) {
return this.Query(tx).
Pk(id).
Result("name").
FindStringCol("")
}
// FindAllEnabledMonitorNodes 列出所有可用监控节点
func (this *MonitorNodeDAO) FindAllEnabledMonitorNodes(tx *dbs.Tx) (result []*MonitorNode, err error) {
_, err = this.Query(tx).
State(MonitorNodeStateEnabled).
Desc("order").
AscPk().
Slice(&result).
FindAll()
return
}
// CountAllEnabledMonitorNodes 计算监控节点数量
func (this *MonitorNodeDAO) CountAllEnabledMonitorNodes(tx *dbs.Tx) (int64, error) {
return this.Query(tx).
State(MonitorNodeStateEnabled).
Count()
}
// ListEnabledMonitorNodes 列出单页的监控节点
func (this *MonitorNodeDAO) ListEnabledMonitorNodes(tx *dbs.Tx, offset int64, size int64) (result []*MonitorNode, err error) {
_, err = this.Query(tx).
State(MonitorNodeStateEnabled).
Offset(offset).
Limit(size).
Desc("order").
DescPk().
Slice(&result).
FindAll()
return
}
// CreateMonitorNode 创建监控节点
func (this *MonitorNodeDAO) CreateMonitorNode(tx *dbs.Tx, name string, description string, isOn bool) (nodeId int64, err error) {
uniqueId, err := this.GenUniqueId(tx)
if err != nil {
return 0, err
}
secret := rands.String(32)
err = NewApiTokenDAO().CreateAPIToken(tx, uniqueId, secret, nodeconfigs.NodeRoleMonitor)
if err != nil {
return
}
var op = NewMonitorNodeOperator()
op.IsOn = isOn
op.UniqueId = uniqueId
op.Secret = secret
op.Name = name
op.Description = description
op.State = NodeStateEnabled
err = this.Save(tx, op)
if err != nil {
return
}
return types.Int64(op.Id), nil
}
// UpdateMonitorNode 修改监控节点
func (this *MonitorNodeDAO) UpdateMonitorNode(tx *dbs.Tx, nodeId int64, name string, description string, isOn bool) error {
if nodeId <= 0 {
return errors.New("invalid nodeId")
}
var op = NewMonitorNodeOperator()
op.Id = nodeId
op.Name = name
op.Description = description
op.IsOn = isOn
err := this.Save(tx, op)
return err
}
// FindEnabledMonitorNodeWithUniqueId 根据唯一ID获取节点信息
func (this *MonitorNodeDAO) FindEnabledMonitorNodeWithUniqueId(tx *dbs.Tx, uniqueId string) (*MonitorNode, error) {
result, err := this.Query(tx).
Attr("uniqueId", uniqueId).
Attr("state", MonitorNodeStateEnabled).
Find()
if result == nil {
return nil, err
}
return result.(*MonitorNode), err
}
// FindEnabledMonitorNodeIdWithUniqueId 根据唯一ID获取节点ID
func (this *MonitorNodeDAO) FindEnabledMonitorNodeIdWithUniqueId(tx *dbs.Tx, uniqueId string) (int64, error) {
return this.Query(tx).
Attr("uniqueId", uniqueId).
Attr("state", MonitorNodeStateEnabled).
ResultPk().
FindInt64Col(0)
}
// GenUniqueId 生成唯一ID
func (this *MonitorNodeDAO) GenUniqueId(tx *dbs.Tx) (string, error) {
for {
uniqueId := rands.HexString(32)
ok, err := this.Query(tx).
Attr("uniqueId", uniqueId).
Exist()
if err != nil {
return "", err
}
if ok {
continue
}
return uniqueId, nil
}
}
// UpdateNodeStatus 更改节点状态
func (this *MonitorNodeDAO) UpdateNodeStatus(tx *dbs.Tx, nodeId int64, statusJSON []byte) error {
if statusJSON == nil {
return nil
}
_, err := this.Query(tx).
Pk(nodeId).
Set("status", string(statusJSON)).
Update()
return err
}
// CountAllLowerVersionNodes 计算所有节点中低于某个版本的节点数量
func (this *MonitorNodeDAO) CountAllLowerVersionNodes(tx *dbs.Tx, version string) (int64, error) {
return this.Query(tx).
State(MonitorNodeStateEnabled).
Attr("isOn", true).
Where("status IS NOT NULL").
Where("(JSON_EXTRACT(status, '$.buildVersionCode') IS NULL OR JSON_EXTRACT(status, '$.buildVersionCode')<:version)").
Param("version", utils.VersionToLong(version)).
Count()
}

View File

@@ -1,38 +0,0 @@
package models
import "github.com/iwind/TeaGo/dbs"
// MonitorNode 监控节点
type MonitorNode struct {
Id uint32 `field:"id"` // ID
IsOn bool `field:"isOn"` // 是否启用
UniqueId string `field:"uniqueId"` // 唯一ID
Secret string `field:"secret"` // 密钥
Name string `field:"name"` // 名称
Description string `field:"description"` // 描述
Order uint32 `field:"order"` // 排序
State uint8 `field:"state"` // 状态
CreatedAt uint64 `field:"createdAt"` // 创建时间
AdminId uint32 `field:"adminId"` // 管理员ID
Weight uint32 `field:"weight"` // 权重
Status dbs.JSON `field:"status"` // 运行状态
}
type MonitorNodeOperator struct {
Id interface{} // ID
IsOn interface{} // 是否启用
UniqueId interface{} // 唯一ID
Secret interface{} // 密钥
Name interface{} // 名称
Description interface{} // 描述
Order interface{} // 排序
State interface{} // 状态
CreatedAt interface{} // 创建时间
AdminId interface{} // 管理员ID
Weight interface{} // 权重
Status interface{} // 运行状态
}
func NewMonitorNodeOperator() *MonitorNodeOperator {
return &MonitorNodeOperator{}
}

View File

@@ -126,7 +126,7 @@ func (this *NodeClusterDAO) FindAllEnableClusterIds(tx *dbs.Tx) (result []int64,
} }
// CreateCluster 创建集群 // CreateCluster 创建集群
func (this *NodeClusterDAO) CreateCluster(tx *dbs.Tx, adminId int64, name string, grantId int64, installDir string, dnsDomainId int64, dnsName string, dnsTTL int32, cachePolicyId int64, httpFirewallPolicyId int64, systemServices map[string]maps.Map, globalServerConfig *serverconfigs.GlobalServerConfig, autoInstallNftables bool) (clusterId int64, err error) { func (this *NodeClusterDAO) CreateCluster(tx *dbs.Tx, adminId int64, name string, grantId int64, installDir string, dnsDomainId int64, dnsName string, dnsTTL int32, cachePolicyId int64, httpFirewallPolicyId int64, systemServices map[string]maps.Map, globalServerConfig *serverconfigs.GlobalServerConfig, autoInstallNftables bool, autoSystemTuning bool, autoTrimDisks bool) (clusterId int64, err error) {
uniqueId, err := this.GenUniqueId(tx) uniqueId, err := this.GenUniqueId(tx)
if err != nil { if err != nil {
return 0, err return 0, err
@@ -189,6 +189,8 @@ func (this *NodeClusterDAO) CreateCluster(tx *dbs.Tx, adminId int64, name string
op.UniqueId = uniqueId op.UniqueId = uniqueId
op.Secret = secret op.Secret = secret
op.AutoInstallNftables = autoInstallNftables op.AutoInstallNftables = autoInstallNftables
op.AutoSystemTuning = autoSystemTuning
op.AutoTrimDisks = autoTrimDisks
op.State = NodeClusterStateEnabled op.State = NodeClusterStateEnabled
err = this.Save(tx, op) err = this.Save(tx, op)
if err != nil { if err != nil {
@@ -199,7 +201,7 @@ func (this *NodeClusterDAO) CreateCluster(tx *dbs.Tx, adminId int64, name string
} }
// UpdateCluster 修改集群 // UpdateCluster 修改集群
func (this *NodeClusterDAO) UpdateCluster(tx *dbs.Tx, clusterId int64, name string, grantId int64, installDir string, timezone string, nodeMaxThreads int32, autoOpenPorts bool, clockConfig *nodeconfigs.ClockConfig, autoRemoteStart bool, autoInstallTables bool, sshParams *nodeconfigs.SSHParams) error { func (this *NodeClusterDAO) UpdateCluster(tx *dbs.Tx, clusterId int64, name string, grantId int64, installDir string, timezone string, nodeMaxThreads int32, autoOpenPorts bool, clockConfig *nodeconfigs.ClockConfig, autoRemoteStart bool, autoInstallTables bool, sshParams *nodeconfigs.SSHParams, autoSystemTuning bool, autoTrimDisks bool) error {
if clusterId <= 0 { if clusterId <= 0 {
return errors.New("invalid clusterId") return errors.New("invalid clusterId")
} }
@@ -226,6 +228,8 @@ func (this *NodeClusterDAO) UpdateCluster(tx *dbs.Tx, clusterId int64, name stri
op.AutoRemoteStart = autoRemoteStart op.AutoRemoteStart = autoRemoteStart
op.AutoInstallNftables = autoInstallTables op.AutoInstallNftables = autoInstallTables
op.AutoSystemTuning = autoSystemTuning
op.AutoTrimDisks = autoTrimDisks
if sshParams != nil { if sshParams != nil {
sshParamsJSON, err := json.Marshal(sshParams) sshParamsJSON, err := json.Marshal(sshParams)
@@ -262,13 +266,22 @@ func (this *NodeClusterDAO) CountAllEnabledClusters(tx *dbs.Tx, keyword string)
} }
// ListEnabledClusters 列出单页集群 // ListEnabledClusters 列出单页集群
func (this *NodeClusterDAO) ListEnabledClusters(tx *dbs.Tx, keyword string, offset, size int64) (result []*NodeCluster, err error) { func (this *NodeClusterDAO) ListEnabledClusters(tx *dbs.Tx, keyword string, idDesc bool, idAsc bool, offset, size int64) (result []*NodeCluster, err error) {
var query = this.Query(tx). var query = this.Query(tx).
State(NodeClusterStateEnabled) State(NodeClusterStateEnabled)
if len(keyword) > 0 { if len(keyword) > 0 {
query.Where("(name LIKE :keyword OR dnsName like :keyword OR (dnsDomainId > 0 AND dnsDomainId IN (SELECT id FROM "+dns.SharedDNSDomainDAO.Table+" WHERE name LIKE :keyword AND state=1)))"). query.Where("(name LIKE :keyword OR dnsName like :keyword OR (dnsDomainId > 0 AND dnsDomainId IN (SELECT id FROM "+dns.SharedDNSDomainDAO.Table+" WHERE name LIKE :keyword AND state=1)))").
Param("keyword", dbutils.QuoteLike(keyword)) Param("keyword", dbutils.QuoteLike(keyword))
} }
if idDesc {
query.DescPk()
} else if idAsc {
query.AscPk()
} else {
query.Desc("isPinned").DescPk()
}
_, err = query. _, err = query.
Result( Result(
NodeClusterField_Id, NodeClusterField_Id,
@@ -293,8 +306,6 @@ func (this *NodeClusterDAO) ListEnabledClusters(tx *dbs.Tx, keyword string, offs
Offset(offset). Offset(offset).
Limit(size). Limit(size).
Slice(&result). Slice(&result).
Desc("isPinned").
DescPk().
FindAll() FindAll()
return return
@@ -647,10 +658,10 @@ func (this *NodeClusterDAO) FindClusterTOAConfig(tx *dbs.Tx, clusterId int64, ca
return nil, err return nil, err
} }
if !IsNotNull([]byte(toa)) { if !IsNotNull([]byte(toa)) {
return nodeconfigs.DefaultTOAConfig(), nil return nodeconfigs.NewTOAConfig(), nil
} }
config := &nodeconfigs.TOAConfig{} var config = nodeconfigs.NewTOAConfig()
err = json.Unmarshal([]byte(toa), config) err = json.Unmarshal([]byte(toa), config)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -675,7 +686,7 @@ func (this *NodeClusterDAO) UpdateClusterTOA(tx *dbs.Tx, clusterId int64, toaJSO
if err != nil { if err != nil {
return err return err
} }
return this.NotifyUpdate(tx, clusterId) return this.NotifyTOAUpdate(tx, clusterId)
} }
// CountAllEnabledNodeClustersWithHTTPCachePolicyId 计算使用某个缓存策略的集群数量 // CountAllEnabledNodeClustersWithHTTPCachePolicyId 计算使用某个缓存策略的集群数量
@@ -950,11 +961,12 @@ func (this *NodeClusterDAO) GenUniqueId(tx *dbs.Tx) (string, error) {
// FindLatestNodeClusters 查询最近访问的集群 // FindLatestNodeClusters 查询最近访问的集群
func (this *NodeClusterDAO) FindLatestNodeClusters(tx *dbs.Tx, size int64) (result []*NodeCluster, err error) { func (this *NodeClusterDAO) FindLatestNodeClusters(tx *dbs.Tx, size int64) (result []*NodeCluster, err error) {
itemTable := SharedLatestItemDAO.Table var itemTable = SharedLatestItemDAO.Table
itemType := LatestItemTypeCluster var itemType = LatestItemTypeCluster
_, err = this.Query(tx). _, err = this.Query(tx).
Result(this.Table+".id", this.Table+".name"). Result(this.Table+".id", this.Table+".name").
Join(SharedLatestItemDAO, dbs.QueryJoinRight, this.Table+".id="+itemTable+".itemId AND "+itemTable+".itemType='"+itemType+"'"). Join(SharedLatestItemDAO, dbs.QueryJoinRight, this.Table+".id="+itemTable+".itemId AND "+itemTable+".itemType='"+itemType+"'").
Where(itemTable + ".updatedAt<=UNIX_TIMESTAMP()"). // VERY IMPORTANT
Asc("CEIL((UNIX_TIMESTAMP() - " + itemTable + ".updatedAt) / (7 * 86400))"). // 优先一个星期以内的 Asc("CEIL((UNIX_TIMESTAMP() - " + itemTable + ".updatedAt) / (7 * 86400))"). // 优先一个星期以内的
Desc(itemTable + ".count"). Desc(itemTable + ".count").
State(NodeClusterStateEnabled). State(NodeClusterStateEnabled).
@@ -1018,7 +1030,7 @@ func (this *NodeClusterDAO) FindClusterBasicInfo(tx *dbs.Tx, clusterId int64, ca
cluster, err := this.Query(tx). cluster, err := this.Query(tx).
Pk(clusterId). Pk(clusterId).
State(NodeClusterStateEnabled). State(NodeClusterStateEnabled).
Result("id", "name", "timeZone", "nodeMaxThreads", "cachePolicyId", "httpFirewallPolicyId", "autoOpenPorts", "webp", "uam", "cc", "httpPages", "http3", "isOn", "ddosProtection", "clock", "globalServerConfig", "autoInstallNftables"). Result("id", "name", "timeZone", "nodeMaxThreads", "cachePolicyId", "httpFirewallPolicyId", "autoOpenPorts", "webp", "uam", "cc", "httpPages", "http3", "isOn", "ddosProtection", "clock", "globalServerConfig", "autoInstallNftables", "autoSystemTuning", "networkSecurity", "autoTrimDisks", "secret").
Find() Find()
if err != nil || cluster == nil { if err != nil || cluster == nil {
return nil, err return nil, err
@@ -1040,7 +1052,7 @@ func (this *NodeClusterDAO) UpdateClusterWebPPolicy(tx *dbs.Tx, clusterId int64,
return err return err
} }
return this.NotifyUpdate(tx, clusterId) return this.NotifyWebPPolicyUpdate(tx, clusterId)
} }
webpPolicyJSON, err := json.Marshal(webpPolicy) webpPolicyJSON, err := json.Marshal(webpPolicy)
@@ -1055,7 +1067,7 @@ func (this *NodeClusterDAO) UpdateClusterWebPPolicy(tx *dbs.Tx, clusterId int64,
return err return err
} }
return this.NotifyUpdate(tx, clusterId) return this.NotifyWebPPolicyUpdate(tx, clusterId)
} }
// FindClusterWebPPolicy 查询WebP设置 // FindClusterWebPPolicy 查询WebP设置
@@ -1080,7 +1092,7 @@ func (this *NodeClusterDAO) FindClusterWebPPolicy(tx *dbs.Tx, clusterId int64, c
return nodeconfigs.DefaultWebPImagePolicy, nil return nodeconfigs.DefaultWebPImagePolicy, nil
} }
var policy = &nodeconfigs.WebPImagePolicy{} var policy = nodeconfigs.NewWebPImagePolicy()
err = json.Unmarshal(webpJSON, policy) err = json.Unmarshal(webpJSON, policy)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1266,6 +1278,57 @@ func (this *NodeClusterDAO) FindClusterHTTP3Policy(tx *dbs.Tx, clusterId int64,
return policy, nil return policy, nil
} }
// UpdateClusterNetworkSecurityPolicy 修改网络安全策略设置
func (this *NodeClusterDAO) UpdateClusterNetworkSecurityPolicy(tx *dbs.Tx, clusterId int64, networkSecurityPolicy *nodeconfigs.NetworkSecurityPolicy) error {
if networkSecurityPolicy == nil {
networkSecurityPolicy = nodeconfigs.NewNetworkSecurityPolicy()
}
networkSecurityPolicyJSON, err := json.Marshal(networkSecurityPolicy)
if err != nil {
return err
}
err = this.Query(tx).
Pk(clusterId).
Set("networkSecurity", networkSecurityPolicyJSON).
UpdateQuickly()
if err != nil {
return err
}
return this.NotifyNetworkSecurityUpdate(tx, clusterId)
}
// FindClusterNetworkSecurityPolicy 查询网络安全策略设置
func (this *NodeClusterDAO) FindClusterNetworkSecurityPolicy(tx *dbs.Tx, clusterId int64, cacheMap *utils.CacheMap) (*nodeconfigs.NetworkSecurityPolicy, error) {
var cacheKey = this.Table + ":FindClusterNetworkSecurityPolicy:" + types.String(clusterId)
if cacheMap != nil {
cache, ok := cacheMap.Get(cacheKey)
if ok {
return cache.(*nodeconfigs.NetworkSecurityPolicy), nil
}
}
networkSecurityPolicyJSON, err := this.Query(tx).
Pk(clusterId).
Result("networkSecurity").
FindJSONCol()
if err != nil {
return nil, err
}
if IsNull(networkSecurityPolicyJSON) {
return nodeconfigs.NewNetworkSecurityPolicy(), nil
}
var policy = nodeconfigs.NewNetworkSecurityPolicy()
err = json.Unmarshal(networkSecurityPolicyJSON, policy)
if err != nil {
return nil, err
}
return policy, nil
}
// UpdateClusterHTTPPagesPolicy 修改自定义页面设置 // UpdateClusterHTTPPagesPolicy 修改自定义页面设置
func (this *NodeClusterDAO) UpdateClusterHTTPPagesPolicy(tx *dbs.Tx, clusterId int64, httpPagesPolicy *nodeconfigs.HTTPPagesPolicy) error { func (this *NodeClusterDAO) UpdateClusterHTTPPagesPolicy(tx *dbs.Tx, clusterId int64, httpPagesPolicy *nodeconfigs.HTTPPagesPolicy) error {
if httpPagesPolicy == nil { if httpPagesPolicy == nil {
@@ -1449,17 +1512,28 @@ func (this *NodeClusterDAO) NotifyHTTP3Update(tx *dbs.Tx, clusterId int64) error
return SharedNodeTaskDAO.CreateClusterTask(tx, nodeconfigs.NodeRoleNode, clusterId, 0, 0, NodeTaskTypeHTTP3PolicyChanged) return SharedNodeTaskDAO.CreateClusterTask(tx, nodeconfigs.NodeRoleNode, clusterId, 0, 0, NodeTaskTypeHTTP3PolicyChanged)
} }
// NotifyNetworkSecurityUpdate 通知网络安全策略更新
func (this *NodeClusterDAO) NotifyNetworkSecurityUpdate(tx *dbs.Tx, clusterId int64) error {
return SharedNodeTaskDAO.CreateClusterTask(tx, nodeconfigs.NodeRoleNode, clusterId, 0, 0, NodeTaskTypeNetworkSecurityPolicyChanged)
}
// NotifyHTTPPagesPolicyUpdate 通知HTTP Pages更新 // NotifyHTTPPagesPolicyUpdate 通知HTTP Pages更新
func (this *NodeClusterDAO) NotifyHTTPPagesPolicyUpdate(tx *dbs.Tx, clusterId int64) error { func (this *NodeClusterDAO) NotifyHTTPPagesPolicyUpdate(tx *dbs.Tx, clusterId int64) error {
return SharedNodeTaskDAO.CreateClusterTask(tx, nodeconfigs.NodeRoleNode, clusterId, 0, 0, NodeTaskTypeHTTPPagesPolicyChanged) return SharedNodeTaskDAO.CreateClusterTask(tx, nodeconfigs.NodeRoleNode, clusterId, 0, 0, NodeTaskTypeHTTPPagesPolicyChanged)
} }
// NotifyTOAUpdate 通知TOA变化
func (this *NodeClusterDAO) NotifyTOAUpdate(tx *dbs.Tx, clusterId int64) error {
return SharedNodeTaskDAO.CreateClusterTask(tx, nodeconfigs.NodeRoleNode, clusterId, 0, 0, NodeTaskTypeTOAChanged)
}
// NotifyWebPPolicyUpdate 通知WebP策略更新
func (this *NodeClusterDAO) NotifyWebPPolicyUpdate(tx *dbs.Tx, clusterId int64) error {
return SharedNodeTaskDAO.CreateClusterTask(tx, nodeconfigs.NodeRoleNode, clusterId, 0, 0, NodeTaskTypeWebPPolicyChanged)
}
// NotifyDNSUpdate 通知DNS更新 // NotifyDNSUpdate 通知DNS更新
// TODO 更新新的DNS解析记录的同时需要删除老的DNS解析记录 // TODO 更新新的DNS解析记录的同时需要删除老的DNS解析记录
func (this *NodeClusterDAO) NotifyDNSUpdate(tx *dbs.Tx, clusterId int64) error { func (this *NodeClusterDAO) NotifyDNSUpdate(tx *dbs.Tx, clusterId int64) error {
err := dns.SharedDNSTaskDAO.CreateClusterTask(tx, clusterId, dns.DNSTaskTypeClusterChange) return dns.SharedDNSTaskDAO.CreateClusterTask(tx, clusterId, dns.DNSTaskTypeClusterChange)
if err != nil {
return err
}
return nil
} }

View File

@@ -43,6 +43,9 @@ const (
NodeClusterField_HttpPages dbs.FieldName = "httpPages" // 自定义页面设置 NodeClusterField_HttpPages dbs.FieldName = "httpPages" // 自定义页面设置
NodeClusterField_Cc dbs.FieldName = "cc" // CC设置 NodeClusterField_Cc dbs.FieldName = "cc" // CC设置
NodeClusterField_Http3 dbs.FieldName = "http3" // HTTP3设置 NodeClusterField_Http3 dbs.FieldName = "http3" // HTTP3设置
NodeClusterField_AutoSystemTuning dbs.FieldName = "autoSystemTuning" // 是否自动调整系统参数
NodeClusterField_NetworkSecurity dbs.FieldName = "networkSecurity" // 网络安全策略
NodeClusterField_AutoTrimDisks dbs.FieldName = "autoTrimDisks" // 是否自动执行TRIM
) )
// NodeCluster 节点集群 // NodeCluster 节点集群
@@ -87,6 +90,9 @@ type NodeCluster struct {
HttpPages dbs.JSON `field:"httpPages"` // 自定义页面设置 HttpPages dbs.JSON `field:"httpPages"` // 自定义页面设置
Cc dbs.JSON `field:"cc"` // CC设置 Cc dbs.JSON `field:"cc"` // CC设置
Http3 dbs.JSON `field:"http3"` // HTTP3设置 Http3 dbs.JSON `field:"http3"` // HTTP3设置
AutoSystemTuning bool `field:"autoSystemTuning"` // 是否自动调整系统参数
NetworkSecurity dbs.JSON `field:"networkSecurity"` // 网络安全策略
AutoTrimDisks bool `field:"autoTrimDisks"` // 是否自动执行TRIM
} }
type NodeClusterOperator struct { type NodeClusterOperator struct {
@@ -130,6 +136,9 @@ type NodeClusterOperator struct {
HttpPages any // 自定义页面设置 HttpPages any // 自定义页面设置
Cc any // CC设置 Cc any // CC设置
Http3 any // HTTP3设置 Http3 any // HTTP3设置
AutoSystemTuning any // 是否自动调整系统参数
NetworkSecurity any // 网络安全策略
AutoTrimDisks any // 是否自动执行TRIM
} }
func NewNodeClusterOperator() *NodeClusterOperator { func NewNodeClusterOperator() *NodeClusterOperator {

View File

@@ -37,7 +37,7 @@ func (this *NodeCluster) DecodeDDoSProtection() *ddosconfigs.ProtectionConfig {
return result return result
} }
// HasDDoSProtection 检查是否有DDOS设置 // HasDDoSProtection 检查是否有DDoS设置
func (this *NodeCluster) HasDDoSProtection() bool { func (this *NodeCluster) HasDDoSProtection() bool {
var config = this.DecodeDDoSProtection() var config = this.DecodeDDoSProtection()
if config != nil { if config != nil {
@@ -46,6 +46,27 @@ func (this *NodeCluster) HasDDoSProtection() bool {
return false return false
} }
// HasNetworkSecurityPolicy 检查是否有安全策略设置
func (this *NodeCluster) HasNetworkSecurityPolicy() bool {
var policy = this.DecodeNetworkSecurityPolicy()
if policy != nil {
return policy.IsOn()
}
return false
}
// DecodeNetworkSecurityPolicy 解析安全策略设置
func (this *NodeCluster) DecodeNetworkSecurityPolicy() *nodeconfigs.NetworkSecurityPolicy {
var policy = nodeconfigs.NewNetworkSecurityPolicy()
if IsNotNull(this.NetworkSecurity) {
err := json.Unmarshal(this.NetworkSecurity, policy)
if err != nil {
remotelogs.Error("NodeCluster.DecodeNetworkSecurityPolicy()", err.Error())
}
}
return policy
}
// DecodeClock 解析时钟配置 // DecodeClock 解析时钟配置
func (this *NodeCluster) DecodeClock() *nodeconfigs.ClockConfig { func (this *NodeCluster) DecodeClock() *nodeconfigs.ClockConfig {
var clock = nodeconfigs.DefaultClockConfig() var clock = nodeconfigs.DefaultClockConfig()

View File

@@ -18,7 +18,6 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/TeaOSLab/EdgeCommon/pkg/systemconfigs"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
@@ -27,6 +26,7 @@ import (
"github.com/iwind/TeaGo/rands" "github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time" timeutil "github.com/iwind/TeaGo/utils/time"
"sort"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -882,9 +882,28 @@ func (this *NodeDAO) FindNodeStatus(tx *dbs.Tx, nodeId int64) (*nodeconfigs.Node
return status, nil return status, nil
} }
// UpdateNodeIsOn 修改节点启用状态
func (this *NodeDAO) UpdateNodeIsOn(tx *dbs.Tx, nodeId int64, isOn bool) error {
if nodeId <= 0 {
return errors.New("invalid nodeId")
}
err := this.Query(tx).
Pk(nodeId).
Set("isOn", isOn).
UpdateQuickly()
if err != nil {
return err
}
return this.NotifyDNSUpdate(tx, nodeId)
}
// UpdateNodeIsActive 更改节点在线状态 // UpdateNodeIsActive 更改节点在线状态
func (this *NodeDAO) UpdateNodeIsActive(tx *dbs.Tx, nodeId int64, isActive bool) error { func (this *NodeDAO) UpdateNodeIsActive(tx *dbs.Tx, nodeId int64, isActive bool) error {
b := "true" if nodeId <= 0 {
return errors.New("invalid nodeId")
}
var b = "true"
if !isActive { if !isActive {
b = "false" b = "false"
} }
@@ -898,6 +917,9 @@ func (this *NodeDAO) UpdateNodeIsActive(tx *dbs.Tx, nodeId int64, isActive bool)
// UpdateNodeIsInstalled 设置节点安装状态 // UpdateNodeIsInstalled 设置节点安装状态
func (this *NodeDAO) UpdateNodeIsInstalled(tx *dbs.Tx, nodeId int64, isInstalled bool) error { func (this *NodeDAO) UpdateNodeIsInstalled(tx *dbs.Tx, nodeId int64, isInstalled bool) error {
if nodeId <= 0 {
return errors.New("invalid nodeId")
}
_, err := this.Query(tx). _, err := this.Query(tx).
Pk(nodeId). Pk(nodeId).
Set("isInstalled", isInstalled). Set("isInstalled", isInstalled).
@@ -1018,6 +1040,13 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64, dataMap *shared
} }
config.AllowedIPs = append(config.AllowedIPs, apiNodeIPs...) config.AllowedIPs = append(config.AllowedIPs, apiNodeIPs...)
// 当前的节点IP地址
nodeNodeIPs, err := SharedNodeIPAddressDAO.FindAllEnabledAddressStringsWithNode(tx, nodeId, nodeconfigs.NodeRoleNode)
if err != nil {
return nil, err
}
config.IPAddresses = nodeNodeIPs
// 所属集群 // 所属集群
var primaryClusterId = int64(node.ClusterId) var primaryClusterId = int64(node.ClusterId)
var clusterIds = []int64{primaryClusterId} var clusterIds = []int64{primaryClusterId}
@@ -1039,9 +1068,7 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64, dataMap *shared
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, clusterServer := range clusterServers { servers = append(servers, clusterServers...)
servers = append(servers, clusterServer)
}
} }
for _, server := range servers { for _, server := range servers {
@@ -1059,36 +1086,15 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64, dataMap *shared
} }
} }
// 全局设置
// TODO 根据用户的不同读取不同的全局设置
var settingCacheKey = "SharedSysSettingDAO:" + systemconfigs.SettingCodeServerGlobalConfig
settingJSONCache, ok := cacheMap.Get(settingCacheKey)
var settingJSON = []byte{}
if ok {
settingJSON = settingJSONCache.([]byte)
} else {
settingJSON, err = SharedSysSettingDAO.ReadSetting(tx, systemconfigs.SettingCodeServerGlobalConfig)
if err != nil {
return nil, err
}
cacheMap.Put(settingCacheKey, settingJSON)
}
if len(settingJSON) > 0 {
globalConfig := &serverconfigs.GlobalConfig{}
err = json.Unmarshal(settingJSON, globalConfig)
if err != nil {
return nil, err
}
config.GlobalConfig = globalConfig
}
var clusterIndex = 0 var clusterIndex = 0
config.WebPImagePolicies = map[int64]*nodeconfigs.WebPImagePolicy{} config.WebPImagePolicies = map[int64]*nodeconfigs.WebPImagePolicy{}
config.UAMPolicies = map[int64]*nodeconfigs.UAMPolicy{} config.UAMPolicies = map[int64]*nodeconfigs.UAMPolicy{}
config.HTTPCCPolicies = map[int64]*nodeconfigs.HTTPCCPolicy{} config.HTTPCCPolicies = map[int64]*nodeconfigs.HTTPCCPolicy{}
config.HTTP3Policies = map[int64]*nodeconfigs.HTTP3Policy{} config.HTTP3Policies = map[int64]*nodeconfigs.HTTP3Policy{}
config.HTTPPagesPolicies = map[int64]*nodeconfigs.HTTPPagesPolicy{} config.HTTPPagesPolicies = map[int64]*nodeconfigs.HTTPPagesPolicy{}
var cachePolicyIds = []int64{}
var allowIPMaps = map[string]bool{} var allowIPMaps = map[string]bool{}
for _, clusterId := range clusterIds { for _, clusterId := range clusterIds {
nodeCluster, err := SharedNodeClusterDAO.FindClusterBasicInfo(tx, clusterId, cacheMap) nodeCluster, err := SharedNodeClusterDAO.FindClusterBasicInfo(tx, clusterId, cacheMap)
@@ -1099,7 +1105,12 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64, dataMap *shared
continue continue
} }
// 节点IP地址 // 集群密钥
if len(config.ClusterSecret) == 0 {
config.ClusterSecret = nodeCluster.Secret
}
// 所有节点IP地址
nodeIPAddresses, err := SharedNodeIPAddressDAO.FindAllAccessibleIPAddressesWithClusterId(tx, nodeconfigs.NodeRoleNode, clusterId, cacheMap) nodeIPAddresses, err := SharedNodeIPAddressDAO.FindAllAccessibleIPAddressesWithClusterId(tx, nodeconfigs.NodeRoleNode, clusterId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1116,7 +1127,7 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64, dataMap *shared
// 防火墙 // 防火墙
var httpFirewallPolicyId = int64(nodeCluster.HttpFirewallPolicyId) var httpFirewallPolicyId = int64(nodeCluster.HttpFirewallPolicyId)
if httpFirewallPolicyId > 0 { if httpFirewallPolicyId > 0 {
firewallPolicy, err := SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, httpFirewallPolicyId, cacheMap) firewallPolicy, err := SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, httpFirewallPolicyId, true, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1128,12 +1139,15 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64, dataMap *shared
// 缓存策略 // 缓存策略
var httpCachePolicyId = int64(nodeCluster.CachePolicyId) var httpCachePolicyId = int64(nodeCluster.CachePolicyId)
if httpCachePolicyId > 0 { if httpCachePolicyId > 0 {
cachePolicy, err := SharedHTTPCachePolicyDAO.ComposeCachePolicy(tx, httpCachePolicyId, cacheMap) if !lists.ContainsInt64(cachePolicyIds, httpCachePolicyId) {
if err != nil { cachePolicyIds = append(cachePolicyIds, httpCachePolicyId)
return nil, err cachePolicy, err := SharedHTTPCachePolicyDAO.ComposeCachePolicy(tx, httpCachePolicyId, cacheMap)
} if err != nil {
if cachePolicy != nil { return nil, err
config.HTTPCachePolicies = append(config.HTTPCachePolicies, cachePolicy) }
if cachePolicy != nil {
config.HTTPCachePolicies = append(config.HTTPCachePolicies, cachePolicy)
}
} }
} }
@@ -1164,7 +1178,7 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64, dataMap *shared
// webp // webp
if IsNotNull(nodeCluster.Webp) { if IsNotNull(nodeCluster.Webp) {
var webpPolicy = &nodeconfigs.WebPImagePolicy{} var webpPolicy = nodeconfigs.NewWebPImagePolicy()
err = json.Unmarshal(nodeCluster.Webp, webpPolicy) err = json.Unmarshal(nodeCluster.Webp, webpPolicy)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1235,9 +1249,16 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64, dataMap *shared
} }
} }
// 自动安装nftables // 自动安装nftables等集群配置
if clusterIndex == 0 { if clusterIndex == 0 {
config.AutoInstallNftables = nodeCluster.AutoInstallNftables config.AutoInstallNftables = nodeCluster.AutoInstallNftables
config.AutoSystemTuning = nodeCluster.AutoSystemTuning
config.AutoTrimDisks = nodeCluster.AutoTrimDisks
}
// 安全设置
if clusterIndex == 0 {
config.NetworkSecurityPolicy = nodeCluster.DecodeNetworkSecurityPolicy()
} }
clusterIndex++ clusterIndex++
@@ -2118,12 +2139,18 @@ func (this *NodeDAO) FindParentNodeConfigs(tx *dbs.Tx, nodeId int64, groupId int
var secretHash = fmt.Sprintf("%x", sha256.Sum256([]byte(node.UniqueId+"@"+node.Secret))) var secretHash = fmt.Sprintf("%x", sha256.Sum256([]byte(node.UniqueId+"@"+node.Secret)))
for _, clusterId := range node.AllClusterIds() { for _, clusterId := range node.AllClusterIds() {
parentNodeConfigs, _ := result[clusterId] var parentNodeConfigs = result[clusterId]
parentNodeConfigs = append(parentNodeConfigs, &nodeconfigs.ParentNodeConfig{ parentNodeConfigs = append(parentNodeConfigs, &nodeconfigs.ParentNodeConfig{
Id: int64(node.Id), Id: int64(node.Id),
Addrs: addrStrings, Addrs: addrStrings,
SecretHash: secretHash, SecretHash: secretHash,
}) })
// 排序
sort.Slice(parentNodeConfigs, func(i, j int) bool {
return parentNodeConfigs[i].Id < parentNodeConfigs[j].Id
})
result[clusterId] = parentNodeConfigs result[clusterId] = parentNodeConfigs
} }
} }

View File

@@ -4,7 +4,10 @@
package models package models
import ( import (
"errors"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/types"
) )
func (this *NodeDAO) CountAllAuthorityNodes(tx *dbs.Tx) (int64, error) { func (this *NodeDAO) CountAllAuthorityNodes(tx *dbs.Tx) (int64, error) {
@@ -15,5 +18,18 @@ func (this *NodeDAO) CountAllAuthorityNodes(tx *dbs.Tx) (int64, error) {
} }
func (this *NodeDAO) CheckNodesLimit(tx *dbs.Tx) error { func (this *NodeDAO) CheckNodesLimit(tx *dbs.Tx) error {
var maxNodes = teaconst.DefaultMaxNodes
// 检查节点数量
if maxNodes > 0 {
count, err := this.CountAllAuthorityNodes(tx)
if err != nil {
return err
}
if count >= int64(maxNodes) {
return errors.New("超出最大节点数限制:" + types.String(maxNodes) + ",当前已用:" + types.String(count) + "请自行修改源码修改此限制EdgeAPI/internal/const/const_community.go 或者 购买商业版本授权。")
}
}
return nil return nil
} }

View File

@@ -89,7 +89,9 @@ func (this *NodeGrantDAO) CreateGrant(tx *dbs.Tx, adminId int64, name string, me
op.PrivateKey = privateKey op.PrivateKey = privateKey
op.Passphrase = passphrase op.Passphrase = passphrase
} }
op.Su = su if username != "root" { // only for non-root user
op.Su = su
}
op.Description = description op.Description = description
op.NodeId = nodeId op.NodeId = nodeId
op.State = NodeGrantStateEnabled op.State = NodeGrantStateEnabled
@@ -117,7 +119,11 @@ func (this *NodeGrantDAO) UpdateGrant(tx *dbs.Tx, grantId int64, name string, me
op.PrivateKey = privateKey op.PrivateKey = privateKey
op.Passphrase = passphrase op.Passphrase = passphrase
} }
op.Su = su if username != "root" { // only for non-root user
op.Su = su
} else {
op.Su = false
}
op.Description = description op.Description = description
op.NodeId = nodeId op.NodeId = nodeId
err := this.Save(tx, op) err := this.Save(tx, op)

View File

@@ -256,6 +256,32 @@ func (this *NodeIPAddressDAO) FindAllEnabledAddressesWithNode(tx *dbs.Tx, nodeId
return return
} }
// FindAllEnabledAddressStringsWithNode 查找节点的所有的IP地址地府传
func (this *NodeIPAddressDAO) FindAllEnabledAddressStringsWithNode(tx *dbs.Tx, nodeId int64, role nodeconfigs.NodeRole) (result []string, err error) {
if len(role) == 0 {
role = nodeconfigs.NodeRoleNode
}
ones, err := this.Query(tx).
Attr("nodeId", nodeId).
Attr("role", role).
State(NodeIPAddressStateEnabled).
Result("ip", "backupIP").
FindAll()
if err != nil {
return nil, err
}
for _, one := range ones {
var addr = one.(*NodeIPAddress)
result = append(result, addr.Ip)
if len(addr.BackupIP) > 0 {
result = append(result, addr.BackupIP)
}
}
return
}
// FindFirstNodeAccessIPAddress 查找节点的第一个可访问的IP地址 // FindFirstNodeAccessIPAddress 查找节点的第一个可访问的IP地址
func (this *NodeIPAddressDAO) FindFirstNodeAccessIPAddress(tx *dbs.Tx, nodeId int64, mustUp bool, role nodeconfigs.NodeRole) (ip string, addrId int64, err error) { func (this *NodeIPAddressDAO) FindFirstNodeAccessIPAddress(tx *dbs.Tx, nodeId int64, mustUp bool, role nodeconfigs.NodeRole) (ip string, addrId int64, err error) {
if len(role) == 0 { if len(role) == 0 {

View File

@@ -70,8 +70,7 @@ func (this *Node) DNSRouteCodesForDomainId(dnsDomainId int64) ([]string, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
domainRoutes, _ := routes[dnsDomainId] var domainRoutes = routes[dnsDomainId]
if len(domainRoutes) > 0 { if len(domainRoutes) > 0 {
sort.Strings(domainRoutes) sort.Strings(domainRoutes)
} }

View File

@@ -8,6 +8,7 @@ import (
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
"strings"
"time" "time"
) )
@@ -16,19 +17,24 @@ type NodeTaskType = string
const ( const (
// CDN相关 // CDN相关
NodeTaskTypeConfigChanged NodeTaskType = "configChanged" // 节点整体配置变化 NodeTaskTypeConfigChanged NodeTaskType = "configChanged" // 节点整体配置变化
NodeTaskTypeDDosProtectionChanged NodeTaskType = "ddosProtectionChanged" // 节点DDoS配置变更 NodeTaskTypeDDosProtectionChanged NodeTaskType = "ddosProtectionChanged" // 节点DDoS配置变更
NodeTaskTypeGlobalServerConfigChanged NodeTaskType = "globalServerConfigChanged" // 全局服务设置变化 NodeTaskTypeGlobalServerConfigChanged NodeTaskType = "globalServerConfigChanged" // 全局服务设置变化
NodeTaskTypeIPItemChanged NodeTaskType = "ipItemChanged" // IP条目变更 NodeTaskTypeIPListDeleted NodeTaskType = "ipListDeleted" // IPList被删除
NodeTaskTypeNodeVersionChanged NodeTaskType = "nodeVersionChanged" // 节点版本变化 NodeTaskTypeIPItemChanged NodeTaskType = "ipItemChanged" // IP条目变更
NodeTaskTypeScriptsChanged NodeTaskType = "scriptsChanged" // 脚本配置变化 NodeTaskTypeNodeVersionChanged NodeTaskType = "nodeVersionChanged" // 节点版本变化
NodeTaskTypeNodeLevelChanged NodeTaskType = "nodeLevelChanged" // 节点级别变化 NodeTaskTypeScriptsChanged NodeTaskType = "scriptsChanged" // 脚本配置变化
NodeTaskTypeUserServersStateChanged NodeTaskType = "userServersStateChanged" // 用户服务状态变化 NodeTaskTypeNodeLevelChanged NodeTaskType = "nodeLevelChanged" // 节点级别变化
NodeTaskTypeUAMPolicyChanged NodeTaskType = "uamPolicyChanged" // UAM策略变化 NodeTaskTypeUserServersStateChanged NodeTaskType = "userServersStateChanged" // 用户服务状态变化
NodeTaskTypeHTTPPagesPolicyChanged NodeTaskType = "httpPagesPolicyChanged" // 自定义页面变化 NodeTaskTypeUAMPolicyChanged NodeTaskType = "uamPolicyChanged" // UAM策略变化
NodeTaskTypeHTTPCCPolicyChanged NodeTaskType = "httpCCPolicyChanged" // CC策略变化 NodeTaskTypeHTTPPagesPolicyChanged NodeTaskType = "httpPagesPolicyChanged" // 自定义页面变化
NodeTaskTypeHTTP3PolicyChanged NodeTaskType = "http3PolicyChanged" // HTTP3策略变化 NodeTaskTypeHTTPCCPolicyChanged NodeTaskType = "httpCCPolicyChanged" // CC策略变化
NodeTaskTypeUpdatingServers NodeTaskType = "updatingServers" // 更新一组服务 NodeTaskTypeHTTP3PolicyChanged NodeTaskType = "http3PolicyChanged" // HTTP3策略变化
NodeTaskTypeNetworkSecurityPolicyChanged NodeTaskType = "networkSecurityPolicyChanged" // 网络安全策略变化
NodeTaskTypeWebPPolicyChanged NodeTaskType = "webPPolicyChanged" // WebP策略变化
NodeTaskTypeUpdatingServers NodeTaskType = "updatingServers" // 更新一组服务
NodeTaskTypeTOAChanged NodeTaskType = "toaChanged" // TOA配置变化
NodeTaskTypePlanChanged NodeTaskType = "planChanged" // 套餐变化
// NS相关 // NS相关
@@ -234,7 +240,7 @@ func (this *NodeTaskDAO) DeleteNodeTasks(tx *dbs.Tx, role string, nodeId int64)
} }
// DeleteAllNodeTasks 删除所有节点相关任务 // DeleteAllNodeTasks 删除所有节点相关任务
func (this *NodeTaskDAO)DeleteAllNodeTasks(tx *dbs.Tx) error { func (this *NodeTaskDAO) DeleteAllNodeTasks(tx *dbs.Tx) error {
return this.Query(tx). return this.Query(tx).
DeleteQuickly() DeleteQuickly()
} }
@@ -264,6 +270,23 @@ func (this *NodeTaskDAO) FindDoingNodeTasks(tx *dbs.Tx, role string, nodeId int6
// UpdateNodeTaskDone 修改节点任务的完成状态 // UpdateNodeTaskDone 修改节点任务的完成状态
func (this *NodeTaskDAO) UpdateNodeTaskDone(tx *dbs.Tx, taskId int64, isOk bool, errorMessage string) error { func (this *NodeTaskDAO) UpdateNodeTaskDone(tx *dbs.Tx, taskId int64, isOk bool, errorMessage string) error {
if isOk {
// 特殊任务删除
taskType, err := this.Query(tx).
Pk(taskId).
Result("type").
FindStringCol("")
if err != nil {
return err
}
if strings.HasPrefix(taskType, NodeTaskTypeIPListDeleted+"@") {
return this.Query(tx).
Pk(taskId).
DeleteQuickly()
}
}
// 其他任务标记为完成
var query = this.Query(tx). var query = this.Query(tx).
Pk(taskId) Pk(taskId)
if !isOk { if !isOk {
@@ -273,8 +296,9 @@ func (this *NodeTaskDAO) UpdateNodeTaskDone(tx *dbs.Tx, taskId int64, isOk bool,
} }
query.Set("version", version) query.Set("version", version)
} }
_, err := query. _, err := query.
Set("isDone", 1). Set("isDone", true).
Set("isOk", isOk). Set("isOk", isOk).
Set("error", errorMessage). Set("error", errorMessage).
Update() Update()

View File

@@ -54,3 +54,12 @@ func TestNodeTaskDAO_FindDoingNodeTasks(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
} }
func TestNodeTaskDAO_UpdateNodeTaskDone(t *testing.T) {
var tx *dbs.Tx
var dao = models.NewNodeTaskDAO()
err := dao.UpdateNodeTaskDone(tx, 1741, true, "")
if err != nil {
t.Fatal(err)
}
}

View File

@@ -1,17 +1,11 @@
package models package models
import ( import (
"fmt"
"github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
"strings"
"time"
) )
const ( const (
@@ -153,12 +147,13 @@ func (this *NodeThresholdDAO) FindAllEnabledAndOnClusterThresholds(tx *dbs.Tx, r
} }
// FindAllEnabledAndOnNodeThresholds 查询节点专属的阈值设置 // FindAllEnabledAndOnNodeThresholds 查询节点专属的阈值设置
func (this *NodeThresholdDAO) FindAllEnabledAndOnNodeThresholds(tx *dbs.Tx, role string, nodeId int64, item string) (result []*NodeThreshold, err error) { func (this *NodeThresholdDAO) FindAllEnabledAndOnNodeThresholds(tx *dbs.Tx, role string, clusterId int64, nodeId int64, item string) (result []*NodeThreshold, err error) {
if nodeId <= 0 { if clusterId <= 0 || nodeId <= 0 {
return return
} }
_, err = this.Query(tx). _, err = this.Query(tx).
Attr("role", role). Attr("role", role).
Attr("clusterId", clusterId).
Attr("nodeId", nodeId). Attr("nodeId", nodeId).
Attr("item", item). Attr("item", item).
Attr("isOn", true). Attr("isOn", true).
@@ -186,87 +181,3 @@ func (this *NodeThresholdDAO) CountAllEnabledThresholds(tx *dbs.Tx, role string,
query.State(NodeThresholdStateEnabled) query.State(NodeThresholdStateEnabled)
return query.Count() return query.Count()
} }
// FireNodeThreshold 触发相关阈值设置
func (this *NodeThresholdDAO) FireNodeThreshold(tx *dbs.Tx, role string, nodeId int64, item string) error {
clusterId, err := SharedNodeDAO.FindNodeClusterId(tx, nodeId)
if err != nil {
return err
}
if clusterId == 0 {
return nil
}
// 集群相关阈值
var thresholds []*NodeThreshold
{
clusterThresholds, err := this.FindAllEnabledAndOnClusterThresholds(tx, role, clusterId, item)
if err != nil {
return err
}
thresholds = append(thresholds, clusterThresholds...)
}
// 节点相关阈值
{
nodeThresholds, err := this.FindAllEnabledAndOnNodeThresholds(tx, role, nodeId, item)
if err != nil {
return err
}
thresholds = append(thresholds, nodeThresholds...)
}
if len(thresholds) > 0 {
for _, threshold := range thresholds {
if len(threshold.Param) == 0 || threshold.Duration <= 0 {
continue
}
paramValue, err := SharedNodeValueDAO.SumNodeValues(tx, role, nodeId, item, threshold.Param, threshold.SumMethod, types.Int32(threshold.Duration), threshold.DurationUnit)
if err != nil {
return err
}
originValue := nodeconfigs.UnmarshalNodeValue(threshold.Value)
thresholdValue := types.Float64(originValue)
isMatched := nodeconfigs.CompareNodeValue(threshold.Operator, paramValue, thresholdValue)
if isMatched {
// TODO 执行其他动作
// 是否已经通知过
if threshold.NotifyDuration > 0 && threshold.NotifiedAt > 0 && time.Now().Unix()-int64(threshold.NotifiedAt) < int64(threshold.NotifyDuration*60) {
continue
}
// 创建消息
nodeName, err := SharedNodeDAO.FindNodeName(tx, nodeId)
if err != nil {
return err
}
itemName := nodeconfigs.FindNodeValueItemName(threshold.Item)
paramName := nodeconfigs.FindNodeValueItemParamName(threshold.Item, threshold.Param)
operatorName := nodeconfigs.FindNodeValueOperatorName(threshold.Operator)
subject := "节点 \"" + nodeName + "\" " + itemName + " 达到阈值"
body := "节点 \"" + nodeName + "\" " + itemName + " 达到阈值\n阈值设置" + paramName + " " + operatorName + " " + originValue + "\n当前值" + fmt.Sprintf("%.2f", paramValue) + "\n触发时间" + timeutil.Format("Y-m-d H:i:s")
if len(threshold.Message) > 0 {
body = threshold.Message
body = strings.Replace(body, "${item.name}", itemName, -1)
body = strings.Replace(body, "${value}", fmt.Sprintf("%.2f", paramValue), -1)
}
err = SharedMessageDAO.CreateNodeMessage(tx, role, clusterId, nodeId, MessageTypeThresholdSatisfied, MessageLevelWarning, subject, body, maps.Map{}.AsJSON(), true)
if err != nil {
return err
}
// 设置通知时间
_, err = this.Query(tx).
Pk(threshold.Id).
Set("notifiedAt", time.Now().Unix()).
Update()
if err != nil {
return err
}
}
}
}
return nil
}

View File

@@ -0,0 +1,12 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package models
import "github.com/iwind/TeaGo/dbs"
// FireNodeThreshold 触发相关阈值设置
func (this *NodeThresholdDAO) FireNodeThreshold(tx *dbs.Tx, role string, nodeId int64, item string) error {
// stub
return nil
}

View File

@@ -227,6 +227,8 @@ func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx,
return err return err
} }
op.Oss = ossConfigJSON op.Oss = ossConfigJSON
} else {
op.Oss = dbs.SQL("NULL")
} }
op.Description = description op.Description = description
@@ -303,6 +305,19 @@ func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx,
return this.NotifyUpdate(tx, originId) return this.NotifyUpdate(tx, originId)
} }
// UpdateOriginIsOn 修改源站是否启用
func (this *OriginDAO) UpdateOriginIsOn(tx *dbs.Tx, originId int64, isOn bool) error {
err := this.Query(tx).
Pk(originId).
Set("isOn", isOn).
UpdateQuickly()
if err != nil {
return err
}
return this.NotifyUpdate(tx, originId)
}
// CloneOrigin 复制源站 // CloneOrigin 复制源站
func (this *OriginDAO) CloneOrigin(tx *dbs.Tx, fromOriginId int64) (newOriginId int64, err error) { func (this *OriginDAO) CloneOrigin(tx *dbs.Tx, fromOriginId int64) (newOriginId int64, err error) {
if fromOriginId <= 0 { if fromOriginId <= 0 {
@@ -400,6 +415,7 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, dataMap *
} }
// addr // addr
var isOSS = false
if IsNotNull(origin.Addr) { if IsNotNull(origin.Addr) {
var addr = &serverconfigs.NetworkAddressConfig{} var addr = &serverconfigs.NetworkAddressConfig{}
err = json.Unmarshal(origin.Addr, addr) err = json.Unmarshal(origin.Addr, addr)
@@ -407,10 +423,11 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, dataMap *
return nil, err return nil, err
} }
config.Addr = addr config.Addr = addr
isOSS = ossconfigs.IsOSSProtocol(string(addr.Protocol))
} }
// oss // oss
if IsNotNull(origin.Oss) { if isOSS && IsNotNull(origin.Oss) {
var ossConfig = ossconfigs.NewOSSConfig() var ossConfig = ossconfigs.NewOSSConfig()
err = json.Unmarshal(origin.Oss, ossConfig) err = json.Unmarshal(origin.Oss, ossConfig)
if err != nil { if err != nil {
@@ -534,6 +551,17 @@ func (this *OriginDAO) CheckUserOrigin(tx *dbs.Tx, userId int64, originId int64)
return SharedReverseProxyDAO.CheckUserReverseProxy(tx, userId, reverseProxyId) return SharedReverseProxyDAO.CheckUserReverseProxy(tx, userId, reverseProxyId)
} }
// ExistsOrigin 检查源站是否存在
func (this *OriginDAO) ExistsOrigin(tx *dbs.Tx, originId int64) (bool, error) {
if originId <= 0 {
return false, nil
}
return this.Query(tx).
Pk(originId).
State(OriginStateEnabled).
Exist()
}
// NotifyUpdate 通知更新 // NotifyUpdate 通知更新
func (this *OriginDAO) NotifyUpdate(tx *dbs.Tx, originId int64) error { func (this *OriginDAO) NotifyUpdate(tx *dbs.Tx, originId int64) error {
reverseProxyId, err := SharedReverseProxyDAO.FindReverseProxyContainsOriginId(tx, originId) reverseProxyId, err := SharedReverseProxyDAO.FindReverseProxyContainsOriginId(tx, originId)

View File

@@ -49,7 +49,12 @@ func (this *PlanDAO) EnablePlan(tx *dbs.Tx, id uint32) error {
// DisablePlan 禁用条目 // DisablePlan 禁用条目
func (this *PlanDAO) DisablePlan(tx *dbs.Tx, id int64) error { func (this *PlanDAO) DisablePlan(tx *dbs.Tx, id int64) error {
_, err := this.Query(tx). clusterId, err := this.FindPlanClusterId(tx, id)
if err != nil {
return err
}
_, err = this.Query(tx).
Pk(id). Pk(id).
Set("state", PlanStateDisabled). Set("state", PlanStateDisabled).
Update() Update()
@@ -57,19 +62,32 @@ func (this *PlanDAO) DisablePlan(tx *dbs.Tx, id int64) error {
return err return err
} }
return this.NotifyUpdate(tx, id) return this.NotifyUpdate(tx, id, clusterId)
} }
// FindEnabledPlan 查找启用中的条目 // FindEnabledPlan 查找启用中的条目
func (this *PlanDAO) FindEnabledPlan(tx *dbs.Tx, id int64) (*Plan, error) { func (this *PlanDAO) FindEnabledPlan(tx *dbs.Tx, planId int64, cacheMap *utils.CacheMap) (*Plan, error) {
var cacheKey = this.Table + ":FindEnabledPlan:" + types.String(planId)
if cacheMap != nil {
cache, _ := cacheMap.Get(cacheKey)
if cache != nil {
return cache.(*Plan), nil
}
}
result, err := this.Query(tx). result, err := this.Query(tx).
Pk(id). Pk(planId).
Attr("state", PlanStateEnabled). Attr("state", PlanStateEnabled).
Find() Find()
if result == nil { if result == nil {
return nil, err return nil, err
} }
return result.(*Plan), err
if cacheMap != nil {
cacheMap.Put(cacheKey, result)
}
return result.(*Plan), nil
} }
// FindPlanName 根据主键查找名称 // FindPlanName 根据主键查找名称
@@ -162,18 +180,18 @@ func (this *PlanDAO) FindEnabledPlanTrafficLimit(tx *dbs.Tx, planId int64, cache
return config, nil return config, nil
} }
// NotifyUpdate 通知变更 // FindPlanClusterId 查找套餐所属集群
func (this *PlanDAO) NotifyUpdate(tx *dbs.Tx, planId int64) error { func (this *PlanDAO) FindPlanClusterId(tx *dbs.Tx, planId int64) (clusterId int64, err error) {
// 这里不要加入状态参数,因为需要适应删除后的更新 return this.Query(tx).
clusterId, err := this.Query(tx).
Pk(planId). Pk(planId).
Result("clusterId"). Result("clusterId").
FindInt64Col(0) FindInt64Col(0)
if err != nil { }
return err
} // NotifyUpdate 通知变更
if clusterId > 0 { func (this *PlanDAO) NotifyUpdate(tx *dbs.Tx, planId int64, clusterId int64) error {
return SharedNodeClusterDAO.NotifyUpdate(tx, clusterId) if clusterId <= 0 {
} return nil
return nil }
return SharedNodeClusterDAO.NotifyUpdate(tx, clusterId)
} }

View File

@@ -0,0 +1,6 @@
package models_test
import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/iwind/TeaGo/bootstrap"
)

View File

@@ -2,39 +2,89 @@ package models
import "github.com/iwind/TeaGo/dbs" import "github.com/iwind/TeaGo/dbs"
const (
PlanField_Id dbs.FieldName = "id" // ID
PlanField_IsOn dbs.FieldName = "isOn" // 是否启用
PlanField_Name dbs.FieldName = "name" // 套餐名
PlanField_Description dbs.FieldName = "description" // 套餐简介
PlanField_ClusterId dbs.FieldName = "clusterId" // 集群ID
PlanField_TrafficLimit dbs.FieldName = "trafficLimit" // 流量限制
PlanField_BandwidthLimitPerNode dbs.FieldName = "bandwidthLimitPerNode" // 单节点带宽限制
PlanField_Features dbs.FieldName = "features" // 允许的功能
PlanField_HasFullFeatures dbs.FieldName = "hasFullFeatures" // 是否有完整的功能
PlanField_TrafficPrice dbs.FieldName = "trafficPrice" // 流量价格设定
PlanField_BandwidthPrice dbs.FieldName = "bandwidthPrice" // 带宽价格
PlanField_MonthlyPrice dbs.FieldName = "monthlyPrice" // 月付
PlanField_SeasonallyPrice dbs.FieldName = "seasonallyPrice" // 季付
PlanField_YearlyPrice dbs.FieldName = "yearlyPrice" // 年付
PlanField_PriceType dbs.FieldName = "priceType" // 价格类型
PlanField_Order dbs.FieldName = "order" // 排序
PlanField_State dbs.FieldName = "state" // 状态
PlanField_TotalServers dbs.FieldName = "totalServers" // 可以绑定的网站数量
PlanField_TotalServerNamesPerServer dbs.FieldName = "totalServerNamesPerServer" // 每个网站可以绑定的域名数量
PlanField_TotalServerNames dbs.FieldName = "totalServerNames" // 总域名数量
PlanField_MonthlyRequests dbs.FieldName = "monthlyRequests" // 每月访问量额度
PlanField_DailyRequests dbs.FieldName = "dailyRequests" // 每日访问量额度
PlanField_DailyWebsocketConnections dbs.FieldName = "dailyWebsocketConnections" // 每日Websocket连接数
PlanField_MonthlyWebsocketConnections dbs.FieldName = "monthlyWebsocketConnections" // 每月Websocket连接数
PlanField_MaxUploadSize dbs.FieldName = "maxUploadSize" // 最大上传
)
// Plan 用户套餐 // Plan 用户套餐
type Plan struct { type Plan struct {
Id uint32 `field:"id"` // ID Id uint32 `field:"id"` // ID
IsOn bool `field:"isOn"` // 是否启用 IsOn bool `field:"isOn"` // 是否启用
Name string `field:"name"` // 套餐名 Name string `field:"name"` // 套餐名
ClusterId uint32 `field:"clusterId"` // 集群ID Description string `field:"description"` // 套餐简介
TrafficLimit dbs.JSON `field:"trafficLimit"` // 流量限制 ClusterId uint32 `field:"clusterId"` // 集群ID
Features dbs.JSON `field:"features"` // 允许的功能 TrafficLimit dbs.JSON `field:"trafficLimit"` // 流量限制
TrafficPrice dbs.JSON `field:"trafficPrice"` // 流量价格设定 BandwidthLimitPerNode dbs.JSON `field:"bandwidthLimitPerNode"` // 单节点带宽限制
BandwidthPrice dbs.JSON `field:"bandwidthPrice"` // 带宽价格 Features dbs.JSON `field:"features"` // 允许的功能
MonthlyPrice float64 `field:"monthlyPrice"` // 月付 HasFullFeatures bool `field:"hasFullFeatures"` // 是否有完整的功能
SeasonallyPrice float64 `field:"seasonallyPrice"` // 季付 TrafficPrice dbs.JSON `field:"trafficPrice"` // 流量价格设定
YearlyPrice float64 `field:"yearlyPrice"` // 年付 BandwidthPrice dbs.JSON `field:"bandwidthPrice"` // 带宽价格
PriceType string `field:"priceType"` // 价格类型 MonthlyPrice float64 `field:"monthlyPrice"` // 月付
Order uint32 `field:"order"` // 排序 SeasonallyPrice float64 `field:"seasonallyPrice"` // 季付
State uint8 `field:"state"` // 状态 YearlyPrice float64 `field:"yearlyPrice"` // 年付
PriceType string `field:"priceType"` // 价格类型
Order uint32 `field:"order"` // 排序
State uint8 `field:"state"` // 状态
TotalServers uint32 `field:"totalServers"` // 可以绑定的网站数量
TotalServerNamesPerServer uint32 `field:"totalServerNamesPerServer"` // 每个网站可以绑定的域名数量
TotalServerNames uint32 `field:"totalServerNames"` // 总域名数量
MonthlyRequests uint64 `field:"monthlyRequests"` // 每月访问量额度
DailyRequests uint64 `field:"dailyRequests"` // 每日访问量额度
DailyWebsocketConnections uint64 `field:"dailyWebsocketConnections"` // 每日Websocket连接数
MonthlyWebsocketConnections uint64 `field:"monthlyWebsocketConnections"` // 每月Websocket连接数
MaxUploadSize dbs.JSON `field:"maxUploadSize"` // 最大上传
} }
type PlanOperator struct { type PlanOperator struct {
Id interface{} // ID Id any // ID
IsOn interface{} // 是否启用 IsOn any // 是否启用
Name interface{} // 套餐名 Name any // 套餐名
ClusterId interface{} // 集群ID Description any // 套餐简介
TrafficLimit interface{} // 流量限制 ClusterId any // 集群ID
Features interface{} // 允许的功能 TrafficLimit any // 流量限制
TrafficPrice interface{} // 流量价格设定 BandwidthLimitPerNode any // 单节点带宽限制
BandwidthPrice interface{} // 带宽价格 Features any // 允许的功能
MonthlyPrice interface{} // 月付 HasFullFeatures any // 是否有完整的功能
SeasonallyPrice interface{} // 季付 TrafficPrice any // 流量价格设定
YearlyPrice interface{} // 年付 BandwidthPrice any // 带宽价格
PriceType interface{} // 价格类型 MonthlyPrice any // 月付
Order interface{} // 排序 SeasonallyPrice any // 季付
State interface{} // 状态 YearlyPrice any // 年付
PriceType any // 价格类型
Order any // 排序
State any // 状态
TotalServers any // 可以绑定的网站数量
TotalServerNamesPerServer any // 每个网站可以绑定的域名数量
TotalServerNames any // 总域名数量
MonthlyRequests any // 每月访问量额度
DailyRequests any // 每日访问量额度
DailyWebsocketConnections any // 每日Websocket连接数
MonthlyWebsocketConnections any // 每月Websocket连接数
MaxUploadSize any // 最大上传
} }
func NewPlanOperator() *PlanOperator { func NewPlanOperator() *PlanOperator {

View File

@@ -0,0 +1,71 @@
package posts
import (
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
)
const (
PostCategoryStateEnabled = 1 // 已启用
PostCategoryStateDisabled = 0 // 已禁用
)
type PostCategoryDAO dbs.DAO
func NewPostCategoryDAO() *PostCategoryDAO {
return dbs.NewDAO(&PostCategoryDAO{
DAOObject: dbs.DAOObject{
DB: Tea.Env,
Table: "edgePostCategories",
Model: new(PostCategory),
PkName: "id",
},
}).(*PostCategoryDAO)
}
var SharedPostCategoryDAO *PostCategoryDAO
func init() {
dbs.OnReady(func() {
SharedPostCategoryDAO = NewPostCategoryDAO()
})
}
// EnablePostCategory 启用条目
func (this *PostCategoryDAO) EnablePostCategory(tx *dbs.Tx, categoryId int64) error {
_, err := this.Query(tx).
Pk(categoryId).
Set("state", PostCategoryStateEnabled).
Update()
return err
}
// DisablePostCategory 禁用条目
func (this *PostCategoryDAO) DisablePostCategory(tx *dbs.Tx, categoryId int64) error {
_, err := this.Query(tx).
Pk(categoryId).
Set("state", PostCategoryStateDisabled).
Update()
return err
}
// FindEnabledPostCategory 查找启用中的条目
func (this *PostCategoryDAO) FindEnabledPostCategory(tx *dbs.Tx, categoryId int64) (*PostCategory, error) {
result, err := this.Query(tx).
Pk(categoryId).
State(PostCategoryStateEnabled).
Find()
if result == nil {
return nil, err
}
return result.(*PostCategory), err
}
// FindPostCategoryName 根据主键查找名称
func (this *PostCategoryDAO) FindPostCategoryName(tx *dbs.Tx, categoryId int64) (string, error) {
return this.Query(tx).
Pk(categoryId).
Result("name").
FindStringCol("")
}

View File

@@ -1,4 +1,4 @@
package models package posts_test
import ( import (
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"

View File

@@ -0,0 +1,35 @@
package posts
import "github.com/iwind/TeaGo/dbs"
const (
PostCategoryField_Id dbs.FieldName = "id" // ID
PostCategoryField_Name dbs.FieldName = "name" // 分类名称
PostCategoryField_IsOn dbs.FieldName = "isOn" // 是否启用
PostCategoryField_Code dbs.FieldName = "code" // 代号
PostCategoryField_Order dbs.FieldName = "order" // 排序
PostCategoryField_State dbs.FieldName = "state" // 分类状态
)
// PostCategory 文章分类
type PostCategory struct {
Id uint32 `field:"id"` // ID
Name string `field:"name"` // 分类名称
IsOn bool `field:"isOn"` // 是否启用
Code string `field:"code"` // 代号
Order uint32 `field:"order"` // 排序
State uint8 `field:"state"` // 分类状态
}
type PostCategoryOperator struct {
Id any // ID
Name any // 分类名称
IsOn any // 是否启用
Code any // 代号
Order any // 排序
State any // 分类状态
}
func NewPostCategoryOperator() *PostCategoryOperator {
return &PostCategoryOperator{}
}

View File

@@ -0,0 +1 @@
package posts

View File

@@ -0,0 +1,63 @@
package posts
import (
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
)
const (
PostStateEnabled = 1 // 已启用
PostStateDisabled = 0 // 已禁用
)
type PostDAO dbs.DAO
func NewPostDAO() *PostDAO {
return dbs.NewDAO(&PostDAO{
DAOObject: dbs.DAOObject{
DB: Tea.Env,
Table: "edgePosts",
Model: new(Post),
PkName: "id",
},
}).(*PostDAO)
}
var SharedPostDAO *PostDAO
func init() {
dbs.OnReady(func() {
SharedPostDAO = NewPostDAO()
})
}
// EnablePost 启用条目
func (this *PostDAO) EnablePost(tx *dbs.Tx, postId int64) error {
_, err := this.Query(tx).
Pk(postId).
Set("state", PostStateEnabled).
Update()
return err
}
// DisablePost 禁用条目
func (this *PostDAO) DisablePost(tx *dbs.Tx, postId int64) error {
_, err := this.Query(tx).
Pk(postId).
Set("state", PostStateDisabled).
Update()
return err
}
// FindEnabledPost 查找启用中的条目
func (this *PostDAO) FindEnabledPost(tx *dbs.Tx, postId int64) (*Post, error) {
result, err := this.Query(tx).
Pk(postId).
State(PostStateEnabled).
Find()
if result == nil {
return nil, err
}
return result.(*Post), err
}

View File

@@ -0,0 +1,6 @@
package posts_test
import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/iwind/TeaGo/bootstrap"
)

View File

@@ -0,0 +1,50 @@
package posts
import "github.com/iwind/TeaGo/dbs"
const (
PostField_Id dbs.FieldName = "id" // ID
PostField_CategoryId dbs.FieldName = "categoryId" // 文章分类
PostField_Type dbs.FieldName = "type" // 类型normal, url
PostField_Url dbs.FieldName = "url" // URL
PostField_Subject dbs.FieldName = "subject" // 标题
PostField_Body dbs.FieldName = "body" // 内容
PostField_CreatedAt dbs.FieldName = "createdAt" // 创建时间
PostField_IsPublished dbs.FieldName = "isPublished" // 是否已发布
PostField_PublishedAt dbs.FieldName = "publishedAt" // 发布时间
PostField_ProductCode dbs.FieldName = "productCode" // 产品代号
PostField_State dbs.FieldName = "state" // 状态
)
// Post 文章管理
type Post struct {
Id uint32 `field:"id"` // ID
CategoryId uint32 `field:"categoryId"` // 文章分类
Type string `field:"type"` // 类型normal, url
Url string `field:"url"` // URL
Subject string `field:"subject"` // 标题
Body string `field:"body"` // 内容
CreatedAt uint64 `field:"createdAt"` // 创建时间
IsPublished bool `field:"isPublished"` // 是否已发布
PublishedAt uint64 `field:"publishedAt"` // 发布时间
ProductCode string `field:"productCode"` // 产品代号
State uint8 `field:"state"` // 状态
}
type PostOperator struct {
Id any // ID
CategoryId any // 文章分类
Type any // 类型normal, url
Url any // URL
Subject any // 标题
Body any // 内容
CreatedAt any // 创建时间
IsPublished any // 是否已发布
PublishedAt any // 发布时间
ProductCode any // 产品代号
State any // 状态
}
func NewPostOperator() *PostOperator {
return &PostOperator{}
}

View File

@@ -0,0 +1 @@
package posts

View File

@@ -98,6 +98,14 @@ func (this *RegionCountryDAO) FindRegionCountryName(tx *dbs.Tx, id int64) (strin
return name, nil return name, nil
} }
// FindRegionCountryRouteCode 查找国家|地区线路代号
func (this *RegionCountryDAO) FindRegionCountryRouteCode(tx *dbs.Tx, countryId int64) (string, error) {
return this.Query(tx).
Attr("valueId", countryId).
Result("routeCode").
FindStringCol("")
}
// FindCountryIdWithDataId 根据数据ID查找国家 // FindCountryIdWithDataId 根据数据ID查找国家
func (this *RegionCountryDAO) FindCountryIdWithDataId(tx *dbs.Tx, dataId string) (int64, error) { func (this *RegionCountryDAO) FindCountryIdWithDataId(tx *dbs.Tx, dataId string) (int64, error) {
return this.Query(tx). return this.Query(tx).
@@ -127,6 +135,9 @@ func (this *RegionCountryDAO) CreateCountry(tx *dbs.Tx, name string, dataId stri
pinyinResult = append(pinyinResult, strings.Join(piece, " ")) pinyinResult = append(pinyinResult, strings.Join(piece, " "))
} }
pinyinJSON, err := json.Marshal([]string{strings.Join(pinyinResult, " ")}) pinyinJSON, err := json.Marshal([]string{strings.Join(pinyinResult, " ")})
if err != nil {
return 0, err
}
op.Pinyin = pinyinJSON op.Pinyin = pinyinJSON
codes := []string{name} codes := []string{name}

View File

@@ -14,11 +14,12 @@ const (
RegionCountryField_DataId dbs.FieldName = "dataId" // 原始数据ID RegionCountryField_DataId dbs.FieldName = "dataId" // 原始数据ID
RegionCountryField_Pinyin dbs.FieldName = "pinyin" // 拼音 RegionCountryField_Pinyin dbs.FieldName = "pinyin" // 拼音
RegionCountryField_IsCommon dbs.FieldName = "isCommon" // 是否常用 RegionCountryField_IsCommon dbs.FieldName = "isCommon" // 是否常用
RegionCountryField_RouteCode dbs.FieldName = "routeCode" // 线路代号
) )
// RegionCountry 区域-国家/地区 // RegionCountry 区域-国家/地区
type RegionCountry struct { type RegionCountry struct {
Id1 uint32 `field:"id"` // ID Id uint32 `field:"id"` // ID
ValueId uint32 `field:"valueId"` // 实际ID ValueId uint32 `field:"valueId"` // 实际ID
ValueCode string `field:"valueCode"` // 值代号 ValueCode string `field:"valueCode"` // 值代号
Name string `field:"name"` // 名称 Name string `field:"name"` // 名称
@@ -29,6 +30,7 @@ type RegionCountry struct {
DataId string `field:"dataId"` // 原始数据ID DataId string `field:"dataId"` // 原始数据ID
Pinyin dbs.JSON `field:"pinyin"` // 拼音 Pinyin dbs.JSON `field:"pinyin"` // 拼音
IsCommon bool `field:"isCommon"` // 是否常用 IsCommon bool `field:"isCommon"` // 是否常用
RouteCode string `field:"routeCode"` // 线路代号
} }
type RegionCountryOperator struct { type RegionCountryOperator struct {
@@ -43,6 +45,7 @@ type RegionCountryOperator struct {
DataId any // 原始数据ID DataId any // 原始数据ID
Pinyin any // 拼音 Pinyin any // 拼音
IsCommon any // 是否常用 IsCommon any // 是否常用
RouteCode any // 线路代号
} }
func NewRegionCountryOperator() *RegionCountryOperator { func NewRegionCountryOperator() *RegionCountryOperator {

View File

@@ -11,6 +11,7 @@ import (
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
"sort" "sort"
"strconv" "strconv"
"strings"
) )
const ( const (
@@ -18,6 +19,8 @@ const (
RegionProvinceStateDisabled = 0 // 已禁用 RegionProvinceStateDisabled = 0 // 已禁用
) )
var RegionProvinceSuffixes = []string{"省", "州", "区", "大区", "特区", "港", "岛", "环礁", "谷地", "山", "口岸", "郡", "县", "城", "河", "河畔", "市"}
type RegionProvinceDAO dbs.DAO type RegionProvinceDAO dbs.DAO
func NewRegionProvinceDAO() *RegionProvinceDAO { func NewRegionProvinceDAO() *RegionProvinceDAO {
@@ -77,6 +80,14 @@ func (this *RegionProvinceDAO) FindRegionProvinceName(tx *dbs.Tx, id int64) (str
FindStringCol("") FindStringCol("")
} }
// FindRegionCountryId 获取省份对应的国家|地区
func (this *RegionProvinceDAO) FindRegionCountryId(tx *dbs.Tx, provinceId int64) (int64, error) {
return this.Query(tx).
Attr("valueId", provinceId).
Result("countryId").
FindInt64Col(0)
}
// FindProvinceIdWithDataId 根据数据ID查找省份 // FindProvinceIdWithDataId 根据数据ID查找省份
func (this *RegionProvinceDAO) FindProvinceIdWithDataId(tx *dbs.Tx, dataId string) (int64, error) { func (this *RegionProvinceDAO) FindProvinceIdWithDataId(tx *dbs.Tx, dataId string) (int64, error) {
return this.Query(tx). return this.Query(tx).
@@ -87,6 +98,37 @@ func (this *RegionProvinceDAO) FindProvinceIdWithDataId(tx *dbs.Tx, dataId strin
// FindProvinceIdWithName 根据省份名查找省份ID // FindProvinceIdWithName 根据省份名查找省份ID
func (this *RegionProvinceDAO) FindProvinceIdWithName(tx *dbs.Tx, countryId int64, provinceName string) (int64, error) { func (this *RegionProvinceDAO) FindProvinceIdWithName(tx *dbs.Tx, countryId int64, provinceName string) (int64, error) {
{
provinceId, err := this.findProvinceIdWithExactName(tx, countryId, provinceName)
if err != nil {
return 0, err
}
if provinceId > 0 {
return provinceId, nil
}
}
// 候选词
for _, suffix := range RegionProvinceSuffixes {
var name string
if strings.HasSuffix(provinceName, suffix) {
name = strings.TrimSuffix(provinceName, suffix)
} else {
name = provinceName + suffix
}
provinceId, err := this.findProvinceIdWithExactName(tx, countryId, name)
if err != nil {
return 0, err
}
if provinceId > 0 {
return provinceId, nil
}
}
return 0, nil
}
func (this *RegionProvinceDAO) findProvinceIdWithExactName(tx *dbs.Tx, countryId int64, provinceName string) (int64, error) {
return this.Query(tx). return this.Query(tx).
Attr("countryId", countryId). Attr("countryId", countryId).
Where("(name=:provinceName OR customName=:provinceName OR JSON_CONTAINS(codes, :provinceNameJSON) OR JSON_CONTAINS(customCodes, :provinceNameJSON))"). Where("(name=:provinceName OR customName=:provinceName OR JSON_CONTAINS(codes, :provinceNameJSON) OR JSON_CONTAINS(customCodes, :provinceNameJSON))").

View File

@@ -26,6 +26,25 @@ func TestRegionProvinceDAO_FindProvinceIdWithName(t *testing.T) {
} }
} }
func TestRegionProvinceDAO_FindProvinceIdWithName_Suffix(t *testing.T) {
dbs.NotifyReady()
var tx *dbs.Tx
for _, name := range []string{
"维埃纳",
"维埃纳省",
"维埃纳大区",
"维埃纳市",
"维埃纳小区", // expect 0
} {
provinceId, err := SharedRegionProvinceDAO.FindProvinceIdWithName(tx, 74, name)
if err != nil {
t.Fatal(err)
}
t.Log(name, "=>", provinceId)
}
}
func TestRegionProvinceDAO_FindSimilarProvinces(t *testing.T) { func TestRegionProvinceDAO_FindSimilarProvinces(t *testing.T) {
dbs.NotifyReady() dbs.NotifyReady()

View File

@@ -12,11 +12,12 @@ const (
RegionProvinceField_CustomCodes dbs.FieldName = "customCodes" // 自定义代号 RegionProvinceField_CustomCodes dbs.FieldName = "customCodes" // 自定义代号
RegionProvinceField_State dbs.FieldName = "state" // 状态 RegionProvinceField_State dbs.FieldName = "state" // 状态
RegionProvinceField_DataId dbs.FieldName = "dataId" // 原始数据ID RegionProvinceField_DataId dbs.FieldName = "dataId" // 原始数据ID
RegionProvinceField_RouteCode dbs.FieldName = "routeCode" // 线路代号
) )
// RegionProvince 区域-省份 // RegionProvince 区域-省份
type RegionProvince struct { type RegionProvince struct {
Id1 uint32 `field:"id"` // ID Id uint32 `field:"id"` // ID
ValueId uint32 `field:"valueId"` // 实际ID ValueId uint32 `field:"valueId"` // 实际ID
CountryId uint32 `field:"countryId"` // 国家ID CountryId uint32 `field:"countryId"` // 国家ID
Name string `field:"name"` // 名称 Name string `field:"name"` // 名称
@@ -25,6 +26,7 @@ type RegionProvince struct {
CustomCodes dbs.JSON `field:"customCodes"` // 自定义代号 CustomCodes dbs.JSON `field:"customCodes"` // 自定义代号
State uint8 `field:"state"` // 状态 State uint8 `field:"state"` // 状态
DataId string `field:"dataId"` // 原始数据ID DataId string `field:"dataId"` // 原始数据ID
RouteCode string `field:"routeCode"` // 线路代号
} }
type RegionProvinceOperator struct { type RegionProvinceOperator struct {
@@ -37,6 +39,7 @@ type RegionProvinceOperator struct {
CustomCodes any // 自定义代号 CustomCodes any // 自定义代号
State any // 状态 State any // 状态
DataId any // 原始数据ID DataId any // 原始数据ID
RouteCode any // 线路代号
} }
func NewRegionProvinceOperator() *RegionProvinceOperator { func NewRegionProvinceOperator() *RegionProvinceOperator {

View File

@@ -99,7 +99,7 @@ func (this *ReverseProxyDAO) ComposeReverseProxyConfig(tx *dbs.Tx, reverseProxyI
return nil, nil return nil, nil
} }
var config = &serverconfigs.ReverseProxyConfig{} var config = serverconfigs.NewReverseProxyConfig()
config.Id = int64(reverseProxy.Id) config.Id = int64(reverseProxy.Id)
config.IsOn = reverseProxy.IsOn config.IsOn = reverseProxy.IsOn
config.RequestHostType = types.Int8(reverseProxy.RequestHostType) config.RequestHostType = types.Int8(reverseProxy.RequestHostType)
@@ -109,6 +109,8 @@ func (this *ReverseProxyDAO) ComposeReverseProxyConfig(tx *dbs.Tx, reverseProxyI
config.StripPrefix = reverseProxy.StripPrefix config.StripPrefix = reverseProxy.StripPrefix
config.AutoFlush = reverseProxy.AutoFlush == 1 config.AutoFlush = reverseProxy.AutoFlush == 1
config.FollowRedirects = reverseProxy.FollowRedirects == 1 config.FollowRedirects = reverseProxy.FollowRedirects == 1
config.Retry50X = reverseProxy.Retry50X
config.Retry40X = reverseProxy.Retry40X
var schedulingConfig = &serverconfigs.SchedulingConfig{} var schedulingConfig = &serverconfigs.SchedulingConfig{}
if IsNotNull(reverseProxy.Scheduling) { if IsNotNull(reverseProxy.Scheduling) {
@@ -218,6 +220,8 @@ func (this *ReverseProxyDAO) CreateReverseProxy(tx *dbs.Tx, adminId int64, userI
op.AdminId = adminId op.AdminId = adminId
op.UserId = userId op.UserId = userId
op.RequestHostType = serverconfigs.RequestHostTypeProxyServer op.RequestHostType = serverconfigs.RequestHostTypeProxyServer
op.Retry50X = false
op.Retry40X = false
defaultHeaders := []string{"X-Real-IP", "X-Forwarded-For", "X-Forwarded-By", "X-Forwarded-Host", "X-Forwarded-Proto"} defaultHeaders := []string{"X-Real-IP", "X-Forwarded-For", "X-Forwarded-By", "X-Forwarded-Host", "X-Forwarded-Proto"}
defaultHeadersJSON, err := json.Marshal(defaultHeaders) defaultHeadersJSON, err := json.Marshal(defaultHeaders)
@@ -372,14 +376,14 @@ func (this *ReverseProxyDAO) UpdateReverseProxyScheduling(tx *dbs.Tx, reversePro
} }
// UpdateReverseProxyPrimaryOrigins 修改主要源站 // UpdateReverseProxyPrimaryOrigins 修改主要源站
func (this *ReverseProxyDAO) UpdateReverseProxyPrimaryOrigins(tx *dbs.Tx, reverseProxyId int64, origins []byte) error { func (this *ReverseProxyDAO) UpdateReverseProxyPrimaryOrigins(tx *dbs.Tx, reverseProxyId int64, originRefs []byte) error {
if reverseProxyId <= 0 { if reverseProxyId <= 0 {
return errors.New("invalid reverseProxyId") return errors.New("invalid reverseProxyId")
} }
var op = NewReverseProxyOperator() var op = NewReverseProxyOperator()
op.Id = reverseProxyId op.Id = reverseProxyId
if len(origins) > 0 { if len(originRefs) > 0 {
op.PrimaryOrigins = origins op.PrimaryOrigins = originRefs
} else { } else {
op.PrimaryOrigins = "[]" op.PrimaryOrigins = "[]"
} }
@@ -425,7 +429,9 @@ func (this *ReverseProxyDAO) UpdateReverseProxy(tx *dbs.Tx,
maxConns int32, maxConns int32,
maxIdleConns int32, maxIdleConns int32,
proxyProtocolJSON []byte, proxyProtocolJSON []byte,
followRedirects bool) error { followRedirects bool,
retry50X bool,
retry40X bool) error {
if reverseProxyId <= 0 { if reverseProxyId <= 0 {
return errors.New("invalid reverseProxyId") return errors.New("invalid reverseProxyId")
} }
@@ -490,6 +496,9 @@ func (this *ReverseProxyDAO) UpdateReverseProxy(tx *dbs.Tx,
op.ProxyProtocol = proxyProtocolJSON op.ProxyProtocol = proxyProtocolJSON
} }
op.Retry50X = retry50X
op.Retry40X = retry40X
err = this.Save(tx, op) err = this.Save(tx, op)
if err != nil { if err != nil {
return err return err

View File

@@ -2,6 +2,35 @@ package models
import "github.com/iwind/TeaGo/dbs" import "github.com/iwind/TeaGo/dbs"
const (
ReverseProxyField_Id dbs.FieldName = "id" // ID
ReverseProxyField_AdminId dbs.FieldName = "adminId" // 管理员ID
ReverseProxyField_UserId dbs.FieldName = "userId" // 用户ID
ReverseProxyField_TemplateId dbs.FieldName = "templateId" // 模版ID
ReverseProxyField_IsOn dbs.FieldName = "isOn" // 是否启用
ReverseProxyField_Scheduling dbs.FieldName = "scheduling" // 调度算法
ReverseProxyField_PrimaryOrigins dbs.FieldName = "primaryOrigins" // 主要源站
ReverseProxyField_BackupOrigins dbs.FieldName = "backupOrigins" // 备用源站
ReverseProxyField_StripPrefix dbs.FieldName = "stripPrefix" // 去除URL前缀
ReverseProxyField_RequestHostType dbs.FieldName = "requestHostType" // 请求Host类型
ReverseProxyField_RequestHost dbs.FieldName = "requestHost" // 请求Host
ReverseProxyField_RequestHostExcludingPort dbs.FieldName = "requestHostExcludingPort" // 移除请求Host中的域名
ReverseProxyField_RequestURI dbs.FieldName = "requestURI" // 请求URI
ReverseProxyField_AutoFlush dbs.FieldName = "autoFlush" // 是否自动刷新缓冲区
ReverseProxyField_AddHeaders dbs.FieldName = "addHeaders" // 自动添加的Header列表
ReverseProxyField_State dbs.FieldName = "state" // 状态
ReverseProxyField_CreatedAt dbs.FieldName = "createdAt" // 创建时间
ReverseProxyField_ConnTimeout dbs.FieldName = "connTimeout" // 连接超时时间
ReverseProxyField_ReadTimeout dbs.FieldName = "readTimeout" // 读取超时时间
ReverseProxyField_IdleTimeout dbs.FieldName = "idleTimeout" // 空闲超时时间
ReverseProxyField_MaxConns dbs.FieldName = "maxConns" // 最大并发连接数
ReverseProxyField_MaxIdleConns dbs.FieldName = "maxIdleConns" // 最大空闲连接数
ReverseProxyField_ProxyProtocol dbs.FieldName = "proxyProtocol" // Proxy Protocol配置
ReverseProxyField_FollowRedirects dbs.FieldName = "followRedirects" // 回源跟随
ReverseProxyField_Retry50X dbs.FieldName = "retry50X" // 启用50X重试
ReverseProxyField_Retry40X dbs.FieldName = "retry40X" // 启用40X重试
)
// ReverseProxy 反向代理配置 // ReverseProxy 反向代理配置
type ReverseProxy struct { type ReverseProxy struct {
Id uint32 `field:"id"` // ID Id uint32 `field:"id"` // ID
@@ -28,33 +57,37 @@ type ReverseProxy struct {
MaxIdleConns uint32 `field:"maxIdleConns"` // 最大空闲连接数 MaxIdleConns uint32 `field:"maxIdleConns"` // 最大空闲连接数
ProxyProtocol dbs.JSON `field:"proxyProtocol"` // Proxy Protocol配置 ProxyProtocol dbs.JSON `field:"proxyProtocol"` // Proxy Protocol配置
FollowRedirects uint8 `field:"followRedirects"` // 回源跟随 FollowRedirects uint8 `field:"followRedirects"` // 回源跟随
Retry50X bool `field:"retry50X"` // 启用50X重试
Retry40X bool `field:"retry40X"` // 启用40X重试
} }
type ReverseProxyOperator struct { type ReverseProxyOperator struct {
Id interface{} // ID Id any // ID
AdminId interface{} // 管理员ID AdminId any // 管理员ID
UserId interface{} // 用户ID UserId any // 用户ID
TemplateId interface{} // 模版ID TemplateId any // 模版ID
IsOn interface{} // 是否启用 IsOn any // 是否启用
Scheduling interface{} // 调度算法 Scheduling any // 调度算法
PrimaryOrigins interface{} // 主要源站 PrimaryOrigins any // 主要源站
BackupOrigins interface{} // 备用源站 BackupOrigins any // 备用源站
StripPrefix interface{} // 去除URL前缀 StripPrefix any // 去除URL前缀
RequestHostType interface{} // 请求Host类型 RequestHostType any // 请求Host类型
RequestHost interface{} // 请求Host RequestHost any // 请求Host
RequestHostExcludingPort interface{} // 移除请求Host中的域名 RequestHostExcludingPort any // 移除请求Host中的域名
RequestURI interface{} // 请求URI RequestURI any // 请求URI
AutoFlush interface{} // 是否自动刷新缓冲区 AutoFlush any // 是否自动刷新缓冲区
AddHeaders interface{} // 自动添加的Header列表 AddHeaders any // 自动添加的Header列表
State interface{} // 状态 State any // 状态
CreatedAt interface{} // 创建时间 CreatedAt any // 创建时间
ConnTimeout interface{} // 连接超时时间 ConnTimeout any // 连接超时时间
ReadTimeout interface{} // 读取超时时间 ReadTimeout any // 读取超时时间
IdleTimeout interface{} // 空闲超时时间 IdleTimeout any // 空闲超时时间
MaxConns interface{} // 最大并发连接数 MaxConns any // 最大并发连接数
MaxIdleConns interface{} // 最大空闲连接数 MaxIdleConns any // 最大空闲连接数
ProxyProtocol interface{} // Proxy Protocol配置 ProxyProtocol any // Proxy Protocol配置
FollowRedirects interface{} // 回源跟随 FollowRedirects any // 回源跟随
Retry50X any // 启用50X重试
Retry40X any // 启用40X重试
} }
func NewReverseProxyOperator() *ReverseProxyOperator { func NewReverseProxyOperator() *ReverseProxyOperator {

View File

@@ -1 +1,31 @@
package models package models
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/logs"
)
// DecodePrimaryOrigins 解析主要源站
func (this *ReverseProxy) DecodePrimaryOrigins() []*serverconfigs.OriginRef {
var refs = []*serverconfigs.OriginRef{}
if IsNotNull(this.PrimaryOrigins) {
err := json.Unmarshal(this.PrimaryOrigins, &refs)
if err != nil {
logs.Error(err)
}
}
return refs
}
// DecodeBackupOrigins 解析备用源站
func (this *ReverseProxy) DecodeBackupOrigins() []*serverconfigs.OriginRef {
var refs = []*serverconfigs.OriginRef{}
if IsNotNull(this.BackupOrigins) {
err := json.Unmarshal(this.BackupOrigins, &refs)
if err != nil {
logs.Error(err)
}
}
return refs
}

View File

@@ -25,7 +25,7 @@ import (
type ServerBandwidthStatDAO dbs.DAO type ServerBandwidthStatDAO dbs.DAO
const ( const (
ServerBandwidthStatTablePartials = 20 // 分表数量 ServerBandwidthStatTablePartitions = 20 // 分表数量
) )
func init() { func init() {
@@ -63,28 +63,29 @@ func init() {
} }
// UpdateServerBandwidth 写入数据 // UpdateServerBandwidth 写入数据
// 暂时不使用region区分 // 现在不需要把 userPlanId 加入到数据表unique key中因为只会影响5分钟统计影响非常有限
func (this *ServerBandwidthStatDAO) UpdateServerBandwidth(tx *dbs.Tx, userId int64, serverId int64, regionId int64, day string, timeAt string, bytes int64, totalBytes int64, cachedBytes int64, attackBytes int64, countRequests int64, countCachedRequests int64, countAttackRequests int64) error { func (this *ServerBandwidthStatDAO) UpdateServerBandwidth(tx *dbs.Tx, userId int64, serverId int64, regionId int64, userPlanId int64, day string, timeAt string, bandwidthBytes int64, totalBytes int64, cachedBytes int64, attackBytes int64, countRequests int64, countCachedRequests int64, countAttackRequests int64, countIPs int64) error {
if serverId <= 0 { if serverId <= 0 {
return errors.New("invalid server id '" + types.String(serverId) + "'") return errors.New("invalid server id '" + types.String(serverId) + "'")
} }
return this.Query(tx). return this.Query(tx).
Table(this.partialTable(serverId)). Table(this.partialTable(serverId)).
Param("bytes", bytes). Param("bytes", bandwidthBytes).
Param("totalBytes", totalBytes). Param("totalBytes", totalBytes).
Param("cachedBytes", cachedBytes). Param("cachedBytes", cachedBytes).
Param("attackBytes", attackBytes). Param("attackBytes", attackBytes).
Param("countRequests", countRequests). Param("countRequests", countRequests).
Param("countCachedRequests", countCachedRequests). Param("countCachedRequests", countCachedRequests).
Param("countAttackRequests", countAttackRequests). Param("countAttackRequests", countAttackRequests).
Param("countIPs", countIPs).
InsertOrUpdateQuickly(maps.Map{ InsertOrUpdateQuickly(maps.Map{
"userId": userId, "userId": userId,
"serverId": serverId, "serverId": serverId,
"regionId": regionId, "regionId": regionId,
"day": day, "day": day,
"timeAt": timeAt, "timeAt": timeAt,
"bytes": bytes, "bytes": bandwidthBytes,
"totalBytes": totalBytes, "totalBytes": totalBytes,
"avgBytes": totalBytes / 300, "avgBytes": totalBytes / 300,
"cachedBytes": cachedBytes, "cachedBytes": cachedBytes,
@@ -92,6 +93,8 @@ func (this *ServerBandwidthStatDAO) UpdateServerBandwidth(tx *dbs.Tx, userId int
"countRequests": countRequests, "countRequests": countRequests,
"countCachedRequests": countCachedRequests, "countCachedRequests": countCachedRequests,
"countAttackRequests": countAttackRequests, "countAttackRequests": countAttackRequests,
"userPlanId": userPlanId,
"countIPs": countIPs,
}, maps.Map{ }, maps.Map{
"bytes": dbs.SQL("bytes+:bytes"), "bytes": dbs.SQL("bytes+:bytes"),
"avgBytes": dbs.SQL("(totalBytes+:totalBytes)/300"), // 因为生成SQL语句时会自动将avgBytes排在totalBytes之前所以这里不用担心先后顺序的问题 "avgBytes": dbs.SQL("(totalBytes+:totalBytes)/300"), // 因为生成SQL语句时会自动将avgBytes排在totalBytes之前所以这里不用担心先后顺序的问题
@@ -101,6 +104,7 @@ func (this *ServerBandwidthStatDAO) UpdateServerBandwidth(tx *dbs.Tx, userId int
"countRequests": dbs.SQL("countRequests+:countRequests"), "countRequests": dbs.SQL("countRequests+:countRequests"),
"countCachedRequests": dbs.SQL("countCachedRequests+:countCachedRequests"), "countCachedRequests": dbs.SQL("countCachedRequests+:countCachedRequests"),
"countAttackRequests": dbs.SQL("countAttackRequests+:countAttackRequests"), "countAttackRequests": dbs.SQL("countAttackRequests+:countAttackRequests"),
"countIPs": dbs.SQL("countIPs+:countIPs"),
}) })
} }
@@ -379,14 +383,18 @@ func (this *ServerBandwidthStatDAO) FindAllServerStatsWithMonth(tx *dbs.Tx, serv
} }
// FindMonthlyPercentile 获取某月内百分位 // FindMonthlyPercentile 获取某月内百分位
func (this *ServerBandwidthStatDAO) FindMonthlyPercentile(tx *dbs.Tx, serverId int64, month string, percentile int, useAvg bool) (result int64, err error) { func (this *ServerBandwidthStatDAO) FindMonthlyPercentile(tx *dbs.Tx, serverId int64, month string, percentile int, useAvg bool, noPlan bool, minSamples int) (result int64, err error) {
if percentile <= 0 { if percentile <= 0 {
percentile = 95 percentile = 95
} }
// 如果是100%以上,则快速返回 // 如果是100%以上,则快速返回
if percentile >= 100 { if percentile >= 100 {
result, err = this.Query(tx). var query = this.Query(tx)
if noPlan {
query.Attr("userPlanId", 0)
}
result, err = query.
Table(this.partialTable(serverId)). Table(this.partialTable(serverId)).
Attr("serverId", serverId). Attr("serverId", serverId).
Result(this.bytesField(useAvg)). Result(this.bytesField(useAvg)).
@@ -398,7 +406,11 @@ func (this *ServerBandwidthStatDAO) FindMonthlyPercentile(tx *dbs.Tx, serverId i
} }
// 总数量 // 总数量
total, err := this.Query(tx). var totalQuery = this.Query(tx)
if noPlan {
totalQuery.Attr("userPlanId", 0)
}
total, err := totalQuery.
Table(this.partialTable(serverId)). Table(this.partialTable(serverId)).
Attr("serverId", serverId). Attr("serverId", serverId).
Between("day", month+"01", month+"31"). Between("day", month+"01", month+"31").
@@ -406,7 +418,7 @@ func (this *ServerBandwidthStatDAO) FindMonthlyPercentile(tx *dbs.Tx, serverId i
if err != nil { if err != nil {
return 0, err return 0, err
} }
if total == 0 { if total == 0 || total < int64(minSamples) {
return 0, nil return 0, nil
} }
@@ -417,7 +429,11 @@ func (this *ServerBandwidthStatDAO) FindMonthlyPercentile(tx *dbs.Tx, serverId i
} }
// 查询 nth 位置 // 查询 nth 位置
result, err = this.Query(tx). var query = this.Query(tx)
if noPlan {
query.Attr("userPlanId", 0)
}
result, err = query.
Table(this.partialTable(serverId)). Table(this.partialTable(serverId)).
Attr("serverId", serverId). Attr("serverId", serverId).
Result(this.bytesField(useAvg)). Result(this.bytesField(useAvg)).
@@ -713,7 +729,7 @@ func (this *ServerBandwidthStatDAO) SumDailyStat(tx *dbs.Tx, serverId int64, reg
var query = this.Query(tx). var query = this.Query(tx).
Table(this.partialTable(serverId)). Table(this.partialTable(serverId)).
Result("SUM(totalBytes) AS totalBytes, SUM(cachedBytes) AS cachedBytes, SUM(countRequests) AS countRequests, SUM(countCachedRequests) AS countCachedRequests, SUM(countAttackRequests) AS countAttackRequests, SUM(attackBytes) AS attackBytes") Result("SUM(totalBytes) AS totalBytes, SUM(cachedBytes) AS cachedBytes, SUM(countRequests) AS countRequests, SUM(countCachedRequests) AS countCachedRequests, SUM(countAttackRequests) AS countAttackRequests, SUM(attackBytes) AS attackBytes, SUM(countIPs) AS countIPs")
query.Attr("serverId", serverId) query.Attr("serverId", serverId)
@@ -742,9 +758,78 @@ func (this *ServerBandwidthStatDAO) SumDailyStat(tx *dbs.Tx, serverId int64, reg
stat.CountCachedRequests = one.GetInt64("countCachedRequests") stat.CountCachedRequests = one.GetInt64("countCachedRequests")
stat.CountAttackRequests = one.GetInt64("countAttackRequests") stat.CountAttackRequests = one.GetInt64("countAttackRequests")
stat.AttackBytes = one.GetInt64("attackBytes") stat.AttackBytes = one.GetInt64("attackBytes")
stat.CountIPs = one.GetInt64("countIPs")
return return
} }
// SumMonthlyBytes 统计某个网站单月总流量
func (this *ServerBandwidthStatDAO) SumMonthlyBytes(tx *dbs.Tx, serverId int64, month string, noPlan bool) (int64, error) {
if !regexputils.YYYYMM.MatchString(month) {
return 0, errors.New("invalid month '" + month + "'")
}
// 兼容以往版本
hasFullData, err := this.HasFullData(tx, serverId, month)
if err != nil {
return 0, err
}
if !hasFullData {
return SharedServerDailyStatDAO.SumMonthlyBytes(tx, serverId, month)
}
var query = this.Query(tx)
if noPlan {
query.Attr("userPlanId", 0)
}
return query.
Table(this.partialTable(serverId)).
Between("day", month+"01", month+"31").
Attr("serverId", serverId).
SumInt64("totalBytes", 0)
}
// SumServerMonthlyWithRegion 根据服务计算某月合计
// month 格式为YYYYMM
func (this *ServerBandwidthStatDAO) SumServerMonthlyWithRegion(tx *dbs.Tx, serverId int64, regionId int64, month string, noPlan bool) (int64, error) {
var query = this.Query(tx)
query.Table(this.partialTable(serverId))
if regionId > 0 {
query.Attr("regionId", regionId)
}
if noPlan {
query.Attr("userPlanId", 0)
}
return query.Between("day", month+"01", month+"31").
Attr("serverId", serverId).
SumInt64("totalBytes", 0)
}
// FindDistinctServerIdsWithoutPlanAtPartition 查找没有绑定套餐的有流量网站
func (this *ServerBandwidthStatDAO) FindDistinctServerIdsWithoutPlanAtPartition(tx *dbs.Tx, partitionIndex int, month string) (serverIds []int64, err error) {
ones, err := this.Query(tx).
Table(this.partialTable(int64(partitionIndex))).
Between("day", month+"01", month+"31").
Attr("userPlanId", 0). // 没有绑定套餐
Result("DISTINCT serverId").
FindAll()
if err != nil {
return nil, err
}
for _, one := range ones {
var serverId = int64(one.(*ServerBandwidthStat).ServerId)
if serverId <= 0 {
continue
}
serverIds = append(serverIds, serverId)
}
return
}
// CountPartitions 查看分区数量
func (this *ServerBandwidthStatDAO) CountPartitions() int {
return ServerBandwidthStatTablePartitions
}
// CleanDays 清理过期数据 // CleanDays 清理过期数据
func (this *ServerBandwidthStatDAO) CleanDays(tx *dbs.Tx, days int) error { func (this *ServerBandwidthStatDAO) CleanDays(tx *dbs.Tx, days int) error {
var day = timeutil.Format("Ymd", time.Now().AddDate(0, 0, -days)) // 保留大约3个月的数据 var day = timeutil.Format("Ymd", time.Now().AddDate(0, 0, -days)) // 保留大约3个月的数据
@@ -777,9 +862,9 @@ func (this *ServerBandwidthStatDAO) CleanDefaultDays(tx *dbs.Tx, defaultDays int
func (this *ServerBandwidthStatDAO) runBatch(f func(table string, locker *sync.Mutex) error) error { func (this *ServerBandwidthStatDAO) runBatch(f func(table string, locker *sync.Mutex) error) error {
var locker = &sync.Mutex{} var locker = &sync.Mutex{}
var wg = sync.WaitGroup{} var wg = sync.WaitGroup{}
wg.Add(ServerBandwidthStatTablePartials) wg.Add(ServerBandwidthStatTablePartitions)
var resultErr error var resultErr error
for i := 0; i < ServerBandwidthStatTablePartials; i++ { for i := 0; i < ServerBandwidthStatTablePartitions; i++ {
var table = this.partialTable(int64(i)) var table = this.partialTable(int64(i))
go func(table string) { go func(table string) {
defer wg.Done() defer wg.Done()
@@ -796,7 +881,7 @@ func (this *ServerBandwidthStatDAO) runBatch(f func(table string, locker *sync.M
// 获取分区表 // 获取分区表
func (this *ServerBandwidthStatDAO) partialTable(serverId int64) string { func (this *ServerBandwidthStatDAO) partialTable(serverId int64) string {
return this.Table + "_" + types.String(serverId%int64(ServerBandwidthStatTablePartials)) return this.Table + "_" + types.String(serverId%int64(ServerBandwidthStatTablePartitions))
} }
// 获取字节字段 // 获取字节字段
@@ -844,6 +929,11 @@ func (this *ServerBandwidthStatDAO) fixServerStats(stats []*ServerBandwidthStat,
// HasFullData 检查一个月是否完整数据 // HasFullData 检查一个月是否完整数据
// 是为了兼容以前数据,以前的表中没有缓存流量、请求数等字段 // 是为了兼容以前数据,以前的表中没有缓存流量、请求数等字段
func (this *ServerBandwidthStatDAO) HasFullData(tx *dbs.Tx, serverId int64, month string) (bool, error) { func (this *ServerBandwidthStatDAO) HasFullData(tx *dbs.Tx, serverId int64, month string) (bool, error) {
// 最迟在2024年完成过渡
if time.Now().Year() >= 2024 {
return true, nil
}
var monthKey = month + "@" + types.String(serverId) var monthKey = month + "@" + types.String(serverId)
if !regexputils.YYYYMM.MatchString(month) { if !regexputils.YYYYMM.MatchString(month) {

View File

@@ -16,7 +16,7 @@ import (
func TestServerBandwidthStatDAO_UpdateServerBandwidth(t *testing.T) { func TestServerBandwidthStatDAO_UpdateServerBandwidth(t *testing.T) {
var dao = models.NewServerBandwidthStatDAO() var dao = models.NewServerBandwidthStatDAO()
var tx *dbs.Tx var tx *dbs.Tx
err := dao.UpdateServerBandwidth(tx, 1, 1, 0, timeutil.Format("Ymd"), timeutil.FormatTime("Hi", time.Now().Unix()/300*300), 1024, 300, 0, 0, 0, 0, 0) err := dao.UpdateServerBandwidth(tx, 1, 1, 0, 0, timeutil.Format("Ymd"), timeutil.FormatTime("Hi", time.Now().Unix()/300*300), 1024, 300, 0, 0, 0, 0, 0, 0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -33,7 +33,7 @@ func TestSeverBandwidthStatDAO_InsertManyStats(t *testing.T) {
} }
var day = timeutil.Format("Ymd", time.Now().AddDate(0, 0, -rands.Int(0, 200))) var day = timeutil.Format("Ymd", time.Now().AddDate(0, 0, -rands.Int(0, 200)))
var minute = fmt.Sprintf("%02d%02d", rands.Int(0, 23), rands.Int(0, 59)) var minute = fmt.Sprintf("%02d%02d", rands.Int(0, 23), rands.Int(0, 59))
err := dao.UpdateServerBandwidth(tx, 1, int64(rands.Int(1, 10000)), 0, day, minute, 1024, 300, 0, 0, 0, 0, 0) err := dao.UpdateServerBandwidth(tx, 1, int64(rands.Int(1, 10000)), 0, 0, day, minute, 1024, 300, 0, 0, 0, 0, 0, 0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -44,8 +44,10 @@ func TestSeverBandwidthStatDAO_InsertManyStats(t *testing.T) {
func TestServerBandwidthStatDAO_FindMonthlyPercentile(t *testing.T) { func TestServerBandwidthStatDAO_FindMonthlyPercentile(t *testing.T) {
var dao = models.NewServerBandwidthStatDAO() var dao = models.NewServerBandwidthStatDAO()
var tx *dbs.Tx var tx *dbs.Tx
t.Log(dao.FindMonthlyPercentile(tx, 23, timeutil.Format("Ym"), 95, false)) t.Log(dao.FindMonthlyPercentile(tx, 23, timeutil.Format("Ym"), 95, false, false, 0))
t.Log(dao.FindMonthlyPercentile(tx, 23, timeutil.Format("Ym"), 95, true)) t.Log(dao.FindMonthlyPercentile(tx, 23, timeutil.Format("Ym"), 95, true, false, 0))
t.Log(dao.FindMonthlyPercentile(tx, 23, timeutil.Format("Ym"), 95, true, false, 100))
t.Log(dao.FindMonthlyPercentile(tx, 23, timeutil.Format("Ym"), 95, true, true, 0))
} }
func TestServerBandwidthStatDAO_FindAllServerStatsWithMonth(t *testing.T) { func TestServerBandwidthStatDAO_FindAllServerStatsWithMonth(t *testing.T) {
@@ -114,3 +116,32 @@ func TestServerBandwidthStatDAO_FindBandwidthStatsBetweenDays(t *testing.T) {
t.Log(stat.Day, stat.TimeAt, "bytes:", stat.Bytes, "bits:", stat.Bits) t.Log(stat.Day, stat.TimeAt, "bytes:", stat.Bytes, "bits:", stat.Bits)
} }
} }
func TestServerBandwidthStatDAO_SumServerMonthlyWithRegion(t *testing.T) {
var dao = models.NewServerBandwidthStatDAO()
var tx *dbs.Tx
{
totalBytes, err := dao.SumServerMonthlyWithRegion(tx, 23, 0, timeutil.Format("Ym"), false)
if err != nil {
t.Fatal(err)
}
t.Log("with plan:", totalBytes)
}
{
totalBytes, err := dao.SumServerMonthlyWithRegion(tx, 23, 0, timeutil.Format("Ym"), true)
if err != nil {
t.Fatal(err)
}
t.Log("without plan:", totalBytes)
}
}
func TestServerBandwidthStatDAO_SumMonthlyBytes(t *testing.T) {
var dao = models.NewServerBandwidthStatDAO()
var tx *dbs.Tx
totalBytes, err := dao.SumMonthlyBytes(tx, 23, timeutil.Format("Ym"), false)
if err != nil {
t.Fatal(err)
}
t.Log("total bytes:", totalBytes)
}

View File

@@ -1,11 +1,33 @@
package models package models
import "github.com/iwind/TeaGo/dbs"
const (
ServerBandwidthStatField_Id dbs.FieldName = "id" // ID
ServerBandwidthStatField_UserId dbs.FieldName = "userId" // 用户ID
ServerBandwidthStatField_ServerId dbs.FieldName = "serverId" // 服务ID
ServerBandwidthStatField_RegionId dbs.FieldName = "regionId" // 区域ID
ServerBandwidthStatField_UserPlanId dbs.FieldName = "userPlanId" // 用户套餐ID
ServerBandwidthStatField_Day dbs.FieldName = "day" // 日期YYYYMMDD
ServerBandwidthStatField_TimeAt dbs.FieldName = "timeAt" // 时间点HHMM
ServerBandwidthStatField_Bytes dbs.FieldName = "bytes" // 带宽字节
ServerBandwidthStatField_AvgBytes dbs.FieldName = "avgBytes" // 平均流量
ServerBandwidthStatField_CachedBytes dbs.FieldName = "cachedBytes" // 缓存的流量
ServerBandwidthStatField_AttackBytes dbs.FieldName = "attackBytes" // 攻击流量
ServerBandwidthStatField_CountRequests dbs.FieldName = "countRequests" // 请求数
ServerBandwidthStatField_CountCachedRequests dbs.FieldName = "countCachedRequests" // 缓存的请求数
ServerBandwidthStatField_CountAttackRequests dbs.FieldName = "countAttackRequests" // 攻击请求数
ServerBandwidthStatField_TotalBytes dbs.FieldName = "totalBytes" // 总流量
ServerBandwidthStatField_CountIPs dbs.FieldName = "countIPs" // 独立IP
)
// ServerBandwidthStat 服务峰值带宽统计 // ServerBandwidthStat 服务峰值带宽统计
type ServerBandwidthStat struct { type ServerBandwidthStat struct {
Id uint64 `field:"id"` // ID Id uint64 `field:"id"` // ID
UserId uint64 `field:"userId"` // 用户ID UserId uint64 `field:"userId"` // 用户ID
ServerId uint64 `field:"serverId"` // 服务ID ServerId uint64 `field:"serverId"` // 服务ID
RegionId uint32 `field:"regionId"` // 区域ID RegionId uint32 `field:"regionId"` // 区域ID
UserPlanId uint64 `field:"userPlanId"` // 用户套餐ID
Day string `field:"day"` // 日期YYYYMMDD Day string `field:"day"` // 日期YYYYMMDD
TimeAt string `field:"timeAt"` // 时间点HHMM TimeAt string `field:"timeAt"` // 时间点HHMM
Bytes uint64 `field:"bytes"` // 带宽字节 Bytes uint64 `field:"bytes"` // 带宽字节
@@ -16,6 +38,7 @@ type ServerBandwidthStat struct {
CountCachedRequests uint64 `field:"countCachedRequests"` // 缓存的请求数 CountCachedRequests uint64 `field:"countCachedRequests"` // 缓存的请求数
CountAttackRequests uint64 `field:"countAttackRequests"` // 攻击请求数 CountAttackRequests uint64 `field:"countAttackRequests"` // 攻击请求数
TotalBytes uint64 `field:"totalBytes"` // 总流量 TotalBytes uint64 `field:"totalBytes"` // 总流量
CountIPs uint64 `field:"countIPs"` // 独立IP
} }
type ServerBandwidthStatOperator struct { type ServerBandwidthStatOperator struct {
@@ -23,6 +46,7 @@ type ServerBandwidthStatOperator struct {
UserId any // 用户ID UserId any // 用户ID
ServerId any // 服务ID ServerId any // 服务ID
RegionId any // 区域ID RegionId any // 区域ID
UserPlanId any // 用户套餐ID
Day any // 日期YYYYMMDD Day any // 日期YYYYMMDD
TimeAt any // 时间点HHMM TimeAt any // 时间点HHMM
Bytes any // 带宽字节 Bytes any // 带宽字节
@@ -33,6 +57,7 @@ type ServerBandwidthStatOperator struct {
CountCachedRequests any // 缓存的请求数 CountCachedRequests any // 缓存的请求数
CountAttackRequests any // 攻击请求数 CountAttackRequests any // 攻击请求数
TotalBytes any // 总流量 TotalBytes any // 总流量
CountIPs any // 独立IP
} }
func NewServerBandwidthStatOperator() *ServerBandwidthStatOperator { func NewServerBandwidthStatOperator() *ServerBandwidthStatOperator {

View File

@@ -119,7 +119,7 @@ func (this *ServerDailyStatDAO) SaveStats(tx *dbs.Tx, stats []*pb.ServerDailySta
// 更新流量限制状态 // 更新流量限制状态
if stat.CheckTrafficLimiting { if stat.CheckTrafficLimiting {
trafficLimitConfig, err := SharedServerDAO.CalculateServerTrafficLimitConfig(tx, stat.ServerId, cacheMap) trafficLimitConfig, err := SharedServerDAO.FindServerTrafficLimitConfig(tx, stat.ServerId, cacheMap)
if err != nil { if err != nil {
return err return err
} }
@@ -129,7 +129,7 @@ func (this *ServerDailyStatDAO) SaveStats(tx *dbs.Tx, stats []*pb.ServerDailySta
return err return err
} }
err = SharedServerDAO.UpdateServerTrafficLimitStatus(tx, trafficLimitConfig, stat.ServerId, false) err = SharedServerDAO.RenewServerTrafficLimitStatus(tx, trafficLimitConfig, stat.ServerId, false)
if err != nil { if err != nil {
return err return err
} }
@@ -140,6 +140,7 @@ func (this *ServerDailyStatDAO) SaveStats(tx *dbs.Tx, stats []*pb.ServerDailySta
return nil return nil
} }
// SumCurrentDailyStat 查找当前时刻的数据统计 // SumCurrentDailyStat 查找当前时刻的数据统计
func (this *ServerDailyStatDAO) SumCurrentDailyStat(tx *dbs.Tx, serverId int64) (*ServerDailyStat, error) { func (this *ServerDailyStatDAO) SumCurrentDailyStat(tx *dbs.Tx, serverId int64) (*ServerDailyStat, error) {
var day = timeutil.Format("Ymd") var day = timeutil.Format("Ymd")
@@ -164,7 +165,7 @@ func (this *ServerDailyStatDAO) SumServerMonthlyWithRegion(tx *dbs.Tx, serverId
if regionId > 0 { if regionId > 0 {
query.Attr("regionId", regionId) query.Attr("regionId", regionId)
} }
return query.Between("day", month+"01", month+"32"). return query.Between("day", month+"01", month+"31").
Attr("serverId", serverId). Attr("serverId", serverId).
SumInt64("bytes", 0) SumInt64("bytes", 0)
} }
@@ -178,7 +179,7 @@ func (this *ServerDailyStatDAO) SumUserMonthlyWithoutPlan(tx *dbs.Tx, userId int
} }
return query. return query.
Attr("planId", 0). Attr("planId", 0).
Between("day", month+"01", month+"32"). Between("day", month+"01", month+"31").
Attr("userId", userId). Attr("userId", userId).
SumInt64("bytes", 0) SumInt64("bytes", 0)
} }
@@ -190,7 +191,7 @@ func (this *ServerDailyStatDAO) SumUserMonthlyPeek(tx *dbs.Tx, userId int64, reg
if regionId > 0 { if regionId > 0 {
query.Attr("regionId", regionId) query.Attr("regionId", regionId)
} }
max, err := query.Between("day", month+"01", month+"32"). max, err := query.Between("day", month+"01", month+"31").
Attr("userId", userId). Attr("userId", userId).
Max("bytes", 0) Max("bytes", 0)
if err != nil { if err != nil {
@@ -644,7 +645,7 @@ func (this *ServerDailyStatDAO) FindStatsBetweenDays(tx *dbs.Tx, userId int64, s
// month YYYYMM // month YYYYMM
func (this *ServerDailyStatDAO) FindMonthlyStatsWithPlan(tx *dbs.Tx, month string) (result []*ServerDailyStat, err error) { func (this *ServerDailyStatDAO) FindMonthlyStatsWithPlan(tx *dbs.Tx, month string) (result []*ServerDailyStat, err error) {
_, err = this.Query(tx). _, err = this.Query(tx).
Between("day", month+"01", month+"32"). Between("day", month+"01", month+"31").
Gt("planId", 0). Gt("planId", 0).
Slice(&result). Slice(&result).
FindAll() FindAll()

View File

@@ -3,11 +3,13 @@ package models
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const" teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns" "github.com/TeaOSLab/EdgeAPI/internal/db/models/dns"
dbutils "github.com/TeaOSLab/EdgeAPI/internal/db/utils" dbutils "github.com/TeaOSLab/EdgeAPI/internal/db/utils"
"github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils" "github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils"
"github.com/TeaOSLab/EdgeAPI/internal/utils/regexputils"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils" "github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
@@ -156,7 +158,6 @@ func (this *ServerDAO) CreateServer(tx *dbs.Tx,
httpsJSON []byte, httpsJSON []byte,
tcpJSON []byte, tcpJSON []byte,
tlsJSON []byte, tlsJSON []byte,
unixJSON []byte,
udpJSON []byte, udpJSON []byte,
webId int64, webId int64,
reverseProxyJSON []byte, reverseProxyJSON []byte,
@@ -204,9 +205,6 @@ func (this *ServerDAO) CreateServer(tx *dbs.Tx,
if IsNotNull(tlsJSON) { if IsNotNull(tlsJSON) {
op.Tls = tlsJSON op.Tls = tlsJSON
} }
if IsNotNull(unixJSON) {
op.Unix = unixJSON
}
if IsNotNull(udpJSON) { if IsNotNull(udpJSON) {
op.Udp = udpJSON op.Udp = udpJSON
} }
@@ -755,14 +753,14 @@ func (this *ServerDAO) UpdateServerAuditing(tx *dbs.Tx, serverId int64, result *
return this.NotifyDNSUpdate(tx, serverId) return this.NotifyDNSUpdate(tx, serverId)
} }
// UpdateServerReverseProxy 修改反向代理配置 // UpdateServerReverseProxyRef 修改反向代理配置
func (this *ServerDAO) UpdateServerReverseProxy(tx *dbs.Tx, serverId int64, config []byte) error { func (this *ServerDAO) UpdateServerReverseProxyRef(tx *dbs.Tx, serverId int64, reverseProxyRefJSON []byte) error {
if serverId <= 0 { if serverId <= 0 {
return errors.New("serverId should not be smaller than 0") return errors.New("serverId should not be smaller than 0")
} }
var op = NewServerOperator() var op = NewServerOperator()
op.Id = serverId op.Id = serverId
op.ReverseProxy = JSONBytes(config) op.ReverseProxy = JSONBytes(reverseProxyRefJSON)
err := this.Save(tx, op) err := this.Save(tx, op)
if err != nil { if err != nil {
return err return err
@@ -771,6 +769,28 @@ func (this *ServerDAO) UpdateServerReverseProxy(tx *dbs.Tx, serverId int64, conf
return this.NotifyUpdate(tx, serverId) return this.NotifyUpdate(tx, serverId)
} }
// CreateServerReverseProxyRef 创建反向代理配置
func (this *ServerDAO) CreateServerReverseProxyRef(tx *dbs.Tx, userId int64, serverId int64) (reverseProxyId int64, err error) {
reverseProxyId, err = SharedReverseProxyDAO.CreateReverseProxy(tx, 0, userId, nil, []byte("[]"), []byte("[]"))
if err != nil {
return 0, err
}
var reverseProxyRef = &serverconfigs.ReverseProxyRef{
IsPrior: false,
IsOn: true,
ReverseProxyId: reverseProxyId,
}
reverseProxyRefJSON, err := json.Marshal(reverseProxyRef)
if err != nil {
return 0, err
}
err = this.UpdateServerReverseProxyRef(tx, serverId, reverseProxyRefJSON)
if err != nil {
return 0, err
}
return reverseProxyId, nil
}
// CountAllEnabledServers 计算所有可用服务数量 // CountAllEnabledServers 计算所有可用服务数量
func (this *ServerDAO) CountAllEnabledServers(tx *dbs.Tx) (int64, error) { func (this *ServerDAO) CountAllEnabledServers(tx *dbs.Tx) (int64, error) {
return this.Query(tx). return this.Query(tx).
@@ -782,7 +802,7 @@ func (this *ServerDAO) CountAllEnabledServers(tx *dbs.Tx) (int64, error) {
// 参数: // 参数:
// //
// groupId 分组ID如果为-1则搜索没有分组的服务 // groupId 分组ID如果为-1则搜索没有分组的服务
func (this *ServerDAO) CountAllEnabledServersMatch(tx *dbs.Tx, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag configutils.BoolState, protocolFamilies []string) (int64, error) { func (this *ServerDAO) CountAllEnabledServersMatch(tx *dbs.Tx, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag configutils.BoolState, protocolFamilies []string, userPlanId int64) (int64, error) {
query := this.Query(tx). query := this.Query(tx).
State(ServerStateEnabled) State(ServerStateEnabled)
if groupId > 0 { if groupId > 0 {
@@ -829,6 +849,10 @@ func (this *ServerDAO) CountAllEnabledServersMatch(tx *dbs.Tx, groupId int64, ke
query.Where("(" + strings.Join(protocolConds, " OR ") + ")") query.Where("(" + strings.Join(protocolConds, " OR ") + ")")
} }
if userPlanId > 0 {
query.Attr("userPlanId", userPlanId)
}
return query.Count() return query.Count()
} }
@@ -1208,18 +1232,6 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server, ignoreCer
} }
} }
// Unix
if IsNotNull(server.Unix) {
var unixConfig = &serverconfigs.UnixProtocolConfig{}
err := json.Unmarshal(server.Unix, unixConfig)
if err != nil {
return nil, err
}
if !forNode || unixConfig.IsOn {
config.Unix = unixConfig
}
}
// UDP // UDP
if IsNotNull(server.Udp) { if IsNotNull(server.Udp) {
var udpConfig = &serverconfigs.UDPProtocolConfig{} var udpConfig = &serverconfigs.UDPProtocolConfig{}
@@ -1316,41 +1328,22 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server, ignoreCer
} }
// 套餐是否依然有效 // 套餐是否依然有效
plan, err := SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId)) config.UserPlan = &serverconfigs.UserPlanConfig{
if err != nil { Id: int64(userPlan.Id),
return nil, err DayTo: userPlan.DayTo,
} PlanId: int64(userPlan.PlanId),
if plan != nil {
config.UserPlan = &serverconfigs.UserPlanConfig{
DayTo: userPlan.DayTo,
Plan: &serverconfigs.PlanConfig{
Id: int64(plan.Id),
Name: plan.Name,
},
}
if len(plan.TrafficLimit) > 0 && (config.TrafficLimit == nil || !config.TrafficLimit.IsOn) {
var trafficLimitConfig = &serverconfigs.TrafficLimitConfig{}
err = json.Unmarshal(plan.TrafficLimit, trafficLimitConfig)
if err != nil {
return nil, err
}
config.TrafficLimit = trafficLimitConfig
}
} }
} }
} }
if config.TrafficLimit != nil && config.TrafficLimit.IsOn && !config.TrafficLimit.IsEmpty() { if len(server.TrafficLimitStatus) > 0 {
if len(server.TrafficLimitStatus) > 0 { var status = &serverconfigs.TrafficLimitStatus{}
var status = &serverconfigs.TrafficLimitStatus{} err := json.Unmarshal(server.TrafficLimitStatus, status)
err := json.Unmarshal(server.TrafficLimitStatus, status) if err != nil {
if err != nil { return nil, err
return nil, err }
} if status.IsValid() {
if status.IsValid() { config.TrafficLimitStatus = status
config.TrafficLimitStatus = status
}
} }
} }
@@ -1375,8 +1368,8 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server, ignoreCer
return config, nil return config, nil
} }
// FindReverseProxyRef 根据条件获取反向代理配置 // FindServerReverseProxyRef 根据条件获取反向代理配置
func (this *ServerDAO) FindReverseProxyRef(tx *dbs.Tx, serverId int64) (*serverconfigs.ReverseProxyRef, error) { func (this *ServerDAO) FindServerReverseProxyRef(tx *dbs.Tx, serverId int64) (*serverconfigs.ReverseProxyRef, error) {
reverseProxy, err := this.Query(tx). reverseProxy, err := this.Query(tx).
Pk(serverId). Pk(serverId).
Result("reverseProxy"). Result("reverseProxy").
@@ -1387,7 +1380,7 @@ func (this *ServerDAO) FindReverseProxyRef(tx *dbs.Tx, serverId int64) (*serverc
if len(reverseProxy) == 0 || reverseProxy == "null" { if len(reverseProxy) == 0 || reverseProxy == "null" {
return nil, nil return nil, nil
} }
config := &serverconfigs.ReverseProxyRef{} var config = &serverconfigs.ReverseProxyRef{}
err = json.Unmarshal([]byte(reverseProxy), config) err = json.Unmarshal([]byte(reverseProxy), config)
return config, err return config, err
} }
@@ -1794,6 +1787,7 @@ func (this *ServerDAO) FindServerUserId(tx *dbs.Tx, serverId int64) (userId int6
} }
// FindServerUserPlanId 查找服务的套餐ID // FindServerUserPlanId 查找服务的套餐ID
// TODO 需要缓存
func (this *ServerDAO) FindServerUserPlanId(tx *dbs.Tx, serverId int64) (userPlanId int64, err error) { func (this *ServerDAO) FindServerUserPlanId(tx *dbs.Tx, serverId int64) (userPlanId int64, err error) {
return this.Query(tx). return this.Query(tx).
Pk(serverId). Pk(serverId).
@@ -2039,11 +2033,12 @@ func (this *ServerDAO) GenDNSName(tx *dbs.Tx) (string, error) {
// FindLatestServers 查询最近访问的服务 // FindLatestServers 查询最近访问的服务
func (this *ServerDAO) FindLatestServers(tx *dbs.Tx, size int64) (result []*Server, err error) { func (this *ServerDAO) FindLatestServers(tx *dbs.Tx, size int64) (result []*Server, err error) {
itemTable := SharedLatestItemDAO.Table var itemTable = SharedLatestItemDAO.Table
itemType := LatestItemTypeServer var itemType = LatestItemTypeServer
_, err = this.Query(tx). _, err = this.Query(tx).
Result(this.Table+".id", this.Table+".name"). Result(this.Table+".id", this.Table+".name").
Join(SharedLatestItemDAO, dbs.QueryJoinRight, this.Table+".id="+itemTable+".itemId AND "+itemTable+".itemType='"+itemType+"'"). Join(SharedLatestItemDAO, dbs.QueryJoinRight, this.Table+".id="+itemTable+".itemId AND "+itemTable+".itemType='"+itemType+"'").
Where(itemTable + ".updatedAt<=UNIX_TIMESTAMP()"). // VERY IMPORTANT
Asc("CEIL((UNIX_TIMESTAMP() - " + itemTable + ".updatedAt) / (7 * 86400))"). // 优先一个星期以内的 Asc("CEIL((UNIX_TIMESTAMP() - " + itemTable + ".updatedAt) / (7 * 86400))"). // 优先一个星期以内的
Desc(itemTable + ".count"). Desc(itemTable + ".count").
State(NodeClusterStateEnabled). State(NodeClusterStateEnabled).
@@ -2306,94 +2301,17 @@ func (this *ServerDAO) FindServerTrafficLimitConfig(tx *dbs.Tx, serverId int64,
return nil, err return nil, err
} }
var limit = &serverconfigs.TrafficLimitConfig{}
if serverOne == nil {
return limit, nil
}
var trafficLimit = serverOne.(*Server).TrafficLimit
if len(trafficLimit) > 0 {
err = json.Unmarshal([]byte(trafficLimit), limit)
if err != nil {
return nil, err
}
}
if cacheMap != nil {
cacheMap.Put(cacheKey, limit)
}
return limit, nil
}
// CalculateServerTrafficLimitConfig 计算服务的流量限制
// TODO 优化性能
func (this *ServerDAO) CalculateServerTrafficLimitConfig(tx *dbs.Tx, serverId int64, cacheMap *utils.CacheMap) (*serverconfigs.TrafficLimitConfig, error) {
if cacheMap == nil {
cacheMap = utils.NewCacheMap()
}
var cacheKey = this.Table + ":FindServerTrafficLimitConfig:" + types.String(serverId)
result, ok := cacheMap.Get(cacheKey)
if ok {
return result.(*serverconfigs.TrafficLimitConfig), nil
}
serverOne, err := this.Query(tx).
Pk(serverId).
Result("trafficLimit", "userPlanId").
Find()
if err != nil {
return nil, err
}
var limitConfig = &serverconfigs.TrafficLimitConfig{} var limitConfig = &serverconfigs.TrafficLimitConfig{}
if serverOne == nil { if serverOne == nil {
return limitConfig, nil return limitConfig, nil
} }
var trafficLimit = serverOne.(*Server).TrafficLimit var trafficLimitJSON = serverOne.(*Server).TrafficLimit
var userPlanId = int64(serverOne.(*Server).UserPlanId)
if len(trafficLimit) == 0 { if len(trafficLimitJSON) > 0 {
if userPlanId > 0 { err = json.Unmarshal(trafficLimitJSON, limitConfig)
userPlan, err := SharedUserPlanDAO.FindEnabledUserPlan(tx, userPlanId, cacheMap) if err != nil {
if err != nil { return nil, err
return nil, err
}
if userPlan != nil {
planLimit, err := SharedPlanDAO.FindEnabledPlanTrafficLimit(tx, int64(userPlan.PlanId), cacheMap)
if err != nil {
return nil, err
}
if planLimit != nil {
return planLimit, nil
}
}
}
return limitConfig, nil
}
err = json.Unmarshal(trafficLimit, limitConfig)
if err != nil {
return nil, err
}
if !limitConfig.IsOn {
if userPlanId > 0 {
userPlan, err := SharedUserPlanDAO.FindEnabledUserPlan(tx, userPlanId, cacheMap)
if err != nil {
return nil, err
}
if userPlan != nil {
planLimit, err := SharedPlanDAO.FindEnabledPlanTrafficLimit(tx, int64(userPlan.PlanId), cacheMap)
if err != nil {
return nil, err
}
if planLimit != nil {
return planLimit, nil
}
}
} }
} }
@@ -2423,13 +2341,41 @@ func (this *ServerDAO) UpdateServerTrafficLimitConfig(tx *dbs.Tx, serverId int64
} }
// 更新状态 // 更新状态
return this.UpdateServerTrafficLimitStatus(tx, trafficLimitConfig, serverId, true) return this.RenewServerTrafficLimitStatus(tx, trafficLimitConfig, serverId, true)
} }
// UpdateServerTrafficLimitStatus 修改服务的流量限制状态 // RenewServerTrafficLimitStatus 根据限流配置更新网站的流量限制状态
func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitConfig *serverconfigs.TrafficLimitConfig, serverId int64, isUpdatingConfig bool) error { func (this *ServerDAO) RenewServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitConfig *serverconfigs.TrafficLimitConfig, serverId int64, isUpdatingConfig bool) error {
if serverId <= 0 {
return nil
}
if !trafficLimitConfig.IsOn { if !trafficLimitConfig.IsOn {
if isUpdatingConfig { if isUpdatingConfig {
var oldStatus = &serverconfigs.TrafficLimitStatus{}
trafficLimitStatus, err := this.Query(tx).
Pk(serverId).
Result("trafficLimitStatus").
FindJSONCol()
if err != nil {
return err
}
if IsNotNull(trafficLimitStatus) {
err = json.Unmarshal(trafficLimitStatus, oldStatus)
if err != nil {
return err
}
if oldStatus.PlanId == 0 /** 说明是网站自行设置的限制 **/ {
err = this.Query(tx).
Pk(serverId).
Set("trafficLimitStatus", dbs.SQL("NULL")).
UpdateQuickly()
if err != nil {
return err
}
}
}
return this.NotifyUpdate(tx, serverId) return this.NotifyUpdate(tx, serverId)
} }
return nil return nil
@@ -2464,9 +2410,11 @@ func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitCo
var untilDay = "" var untilDay = ""
// daily // daily
var dateType = ""
if trafficLimitConfig.DailyBytes() > 0 { if trafficLimitConfig.DailyBytes() > 0 {
if server.TrafficDay == timeutil.Format("Ymd") && server.TotalDailyTraffic >= float64(trafficLimitConfig.DailyBytes())/(1<<30) { if server.TrafficDay == timeutil.Format("Ymd") && server.TotalDailyTraffic >= float64(trafficLimitConfig.DailyBytes())/(1<<30) {
untilDay = timeutil.Format("Ymd") untilDay = timeutil.Format("Ymd")
dateType = "day"
} }
} }
@@ -2474,6 +2422,7 @@ func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitCo
if server.TrafficMonth == timeutil.Format("Ym") && trafficLimitConfig.MonthlyBytes() > 0 { if server.TrafficMonth == timeutil.Format("Ym") && trafficLimitConfig.MonthlyBytes() > 0 {
if server.TotalMonthlyTraffic >= float64(trafficLimitConfig.MonthlyBytes())/(1<<30) { if server.TotalMonthlyTraffic >= float64(trafficLimitConfig.MonthlyBytes())/(1<<30) {
untilDay = timeutil.Format("Ym32") untilDay = timeutil.Format("Ym32")
dateType = "month"
} }
} }
@@ -2481,12 +2430,17 @@ func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitCo
if trafficLimitConfig.TotalBytes() > 0 { if trafficLimitConfig.TotalBytes() > 0 {
if server.TotalTraffic >= float64(trafficLimitConfig.TotalBytes())/(1<<30) { if server.TotalTraffic >= float64(trafficLimitConfig.TotalBytes())/(1<<30) {
untilDay = "30000101" untilDay = "30000101"
dateType = "total"
} }
} }
var isChanged = oldStatus.UntilDay != untilDay var isChanged = oldStatus.UntilDay != untilDay
if isChanged { if isChanged {
statusJSON, err := json.Marshal(&serverconfigs.TrafficLimitStatus{UntilDay: untilDay}) statusJSON, err := json.Marshal(&serverconfigs.TrafficLimitStatus{
UntilDay: untilDay,
DateType: dateType,
TargetType: serverconfigs.TrafficLimitTargetTraffic,
})
if err != nil { if err != nil {
return err return err
} }
@@ -2507,6 +2461,93 @@ func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitCo
return nil return nil
} }
// UpdateServerTrafficLimitStatus 修改网站的流量限制状态
func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, serverId int64, day string, planId int64, dateType string, targetType string) error {
if !regexputils.YYYYMMDD.MatchString(day) {
return errors.New("invalid 'day' format")
}
if serverId <= 0 {
return nil
}
// lookup old status
statusJSON, err := this.Query(tx).
Pk(serverId).
Result(ServerField_TrafficLimitStatus).
FindJSONCol()
if err != nil {
return err
}
if IsNotNull(statusJSON) {
var oldStatus = &serverconfigs.TrafficLimitStatus{}
err = json.Unmarshal(statusJSON, oldStatus)
if err != nil {
return err
}
if len(oldStatus.UntilDay) > 0 &&
oldStatus.UntilDay >= day /** 如果已经限制,且比当前日期长,则无需重复 **/ &&
oldStatus.PlanId == planId {
// no need to change
return nil
}
}
var status = &serverconfigs.TrafficLimitStatus{
UntilDay: day,
PlanId: planId,
DateType: dateType,
TargetType: targetType,
}
statusJSON, err = json.Marshal(status)
if err != nil {
return err
}
err = this.Query(tx).
Pk(serverId).
Set(ServerField_TrafficLimitStatus, statusJSON).
UpdateQuickly()
if err != nil {
return err
}
return this.NotifyUpdate(tx, serverId)
}
// UpdateServersTrafficLimitStatusWithUserPlanId 修改某个套餐下的网站的流量限制状态
func (this *ServerDAO) UpdateServersTrafficLimitStatusWithUserPlanId(tx *dbs.Tx, userPlanId int64, day string, planId int64, dateType string, targetType serverconfigs.TrafficLimitTarget) error {
if userPlanId <= 0 {
return nil
}
servers, err := this.Query(tx).
State(ServerStateEnabled).
Attr("userPlanId", userPlanId).
ResultPk().
FindAll()
if err != nil {
return err
}
for _, server := range servers {
var serverId = int64(server.(*Server).Id)
err = this.UpdateServerTrafficLimitStatus(tx, serverId, day, planId, dateType, targetType)
if err != nil {
return err
}
}
return nil
}
// ResetServersTrafficLimitStatusWithPlanId 重置某个套餐相关网站限流状态
func (this *ServerDAO) ResetServersTrafficLimitStatusWithPlanId(tx *dbs.Tx, planId int64) error {
return this.Query(tx).
Where("JSON_EXTRACT(trafficLimitStatus, '$.planId')=:planId").
Param("planId", planId).
Set("trafficLimitStatus", dbs.SQL("NULL")).
UpdateQuickly()
}
// IncreaseServerTotalTraffic 增加服务的总流量 // IncreaseServerTotalTraffic 增加服务的总流量
func (this *ServerDAO) IncreaseServerTotalTraffic(tx *dbs.Tx, serverId int64, bytes int64) error { func (this *ServerDAO) IncreaseServerTotalTraffic(tx *dbs.Tx, serverId int64, bytes int64) error {
if serverId <= 0 { if serverId <= 0 {
@@ -2548,17 +2589,16 @@ func (this *ServerDAO) FindEnabledServerIdWithUserPlanId(tx *dbs.Tx, userPlanId
FindInt64Col(0) FindInt64Col(0)
} }
// FindEnabledServerWithUserPlanId 查找使用某个套餐的服务 // FindEnabledServersWithUserPlanId 查找使用某个套餐的网站
func (this *ServerDAO) FindEnabledServerWithUserPlanId(tx *dbs.Tx, userPlanId int64) (*Server, error) { func (this *ServerDAO) FindEnabledServersWithUserPlanId(tx *dbs.Tx, userPlanId int64) (result []*Server, err error) {
one, err := this.Query(tx). _, err = this.Query(tx).
State(ServerStateEnabled). State(ServerStateEnabled).
Attr("userPlanId", userPlanId). Attr("userPlanId", userPlanId).
Result("id", "name", "serverNames", "type"). Result("id", "name", "serverNames", "type").
Find() AscPk().
if err != nil || one == nil { Slice(&result).
return nil, err FindAll()
} return
return one.(*Server), nil
} }
// UpdateServersClusterIdWithPlanId 修改套餐所在集群 // UpdateServersClusterIdWithPlanId 修改套餐所在集群
@@ -2576,13 +2616,17 @@ func (this *ServerDAO) UpdateServerUserPlanId(tx *dbs.Tx, serverId int64, userPl
return errors.New("serverId should not be smaller than 0") return errors.New("serverId should not be smaller than 0")
} }
oldClusterId, err := this.Query(tx). oldServerOne, queryErr := SharedServerDAO.
Query(tx).
Pk(serverId). Pk(serverId).
Result("clusterId"). Result("clusterId", "userPlanId").
FindInt64Col(0) Find()
if err != nil { if queryErr != nil || oldServerOne == nil {
return err return queryErr
} }
var oldServer = oldServerOne.(*Server)
var oldClusterId = int64(oldServer.ClusterId)
var oldUserPlanId = int64(oldServer.UserPlanId)
// 取消套餐 // 取消套餐
if userPlanId <= 0 { if userPlanId <= 0 {
@@ -2614,6 +2658,15 @@ func (this *ServerDAO) UpdateServerUserPlanId(tx *dbs.Tx, serverId int64, userPl
if err != nil { if err != nil {
return err return err
} }
// 重置以往的用户套餐状态
if oldUserPlanId > 0 {
err = SharedUserPlanStatDAO.ResetUserPlanStatsWithUserPlanId(tx, oldUserPlanId)
if err != nil {
return err
}
}
err = this.NotifyUpdate(tx, serverId) err = this.NotifyUpdate(tx, serverId)
if err != nil { if err != nil {
return err return err
@@ -2643,7 +2696,7 @@ func (this *ServerDAO) UpdateServerUserPlanId(tx *dbs.Tx, serverId int64, userPl
return errors.New("can not find user plan with id '" + types.String(userPlanId) + "'") return errors.New("can not find user plan with id '" + types.String(userPlanId) + "'")
} }
plan, err := SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId)) plan, err := SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId), nil)
if err != nil { if err != nil {
return err return err
} }
@@ -2661,6 +2714,21 @@ func (this *ServerDAO) UpdateServerUserPlanId(tx *dbs.Tx, serverId int64, userPl
if err != nil { if err != nil {
return err return err
} }
// 重置以往的用户套餐统计状态
if oldUserPlanId > 0 {
err = SharedUserPlanStatDAO.ResetUserPlanStatsWithUserPlanId(tx, oldUserPlanId)
if err != nil {
return err
}
}
// 重置当前用户套餐统计状态
err = SharedUserPlanStatDAO.ResetUserPlanStatsWithUserPlanId(tx, userPlanId)
if err != nil {
return err
}
err = this.NotifyUpdate(tx, serverId) err = this.NotifyUpdate(tx, serverId)
if err != nil { if err != nil {
return err return err
@@ -2881,6 +2949,100 @@ func (this *ServerDAO) FindEnabledServersWithIds(tx *dbs.Tx, serverIds []int64)
return return
} }
// CountAllServerNamesWithUserId 计算某个用户下的所有域名数
func (this *ServerDAO) CountAllServerNamesWithUserId(tx *dbs.Tx, userId int64, userPlanId int64) (int64, error) {
if userId <= 0 {
return 0, nil
}
var query = this.Query(tx).
Attr("userId", userId).
State(ServerStateEnabled).
Where("JSON_TYPE(plainServerNames)='ARRAY'")
if userPlanId > 0 {
query.Attr("userPlanId", userPlanId)
}
return query.
SumInt64("JSON_LENGTH(plainServerNames)", 0)
}
// CountServerNames 计算某个网站下的所有域名数
func (this *ServerDAO) CountServerNames(tx *dbs.Tx, serverId int64) (int64, error) {
if serverId <= 0 {
return 0, nil
}
return this.Query(tx).
Result("JSON_LENGTH(plainServerNames)").
Pk(serverId).
State(ServerStateEnabled).
Where("JSON_TYPE(plainServerNames)='ARRAY'").
FindInt64Col(0)
}
// CheckServerPlanQuota 检查网站套餐限制
func (this *ServerDAO) CheckServerPlanQuota(tx *dbs.Tx, serverId int64, countServerNames int) error {
if serverId <= 0 {
return errors.New("invalid 'serverId'")
}
if countServerNames <= 0 {
return nil
}
userPlanId, err := this.FindServerUserPlanId(tx, serverId)
if err != nil {
return err
}
if userPlanId <= 0 {
return nil
}
userPlan, err := SharedUserPlanDAO.FindEnabledUserPlan(tx, userPlanId, nil)
if err != nil {
return err
}
if userPlan == nil {
return fmt.Errorf("invalid user plan with id %q", types.String(userPlanId))
}
if userPlan.IsExpired() {
return errors.New("the user plan has been expired")
}
if userPlan.UserId == 0 {
return nil
}
plan, err := SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId), nil)
if err != nil {
return err
}
if plan == nil {
return fmt.Errorf("invalid plan with id %q", types.String(userPlan.PlanId))
}
if plan.TotalServerNames > 0 {
totalServerNames, err := this.CountAllServerNamesWithUserId(tx, int64(userPlan.UserId), userPlanId)
if err != nil {
return err
}
if totalServerNames+int64(countServerNames) > int64(plan.TotalServerNames) {
return errors.New("server names over plan quota")
}
}
if plan.TotalServerNamesPerServer > 0 {
if countServerNames > types.Int(plan.TotalServerNamesPerServer) {
return errors.New("server names per server over plan quota")
}
}
return nil
}
// ExistsServer 检查网站是否存在
func (this *ServerDAO) ExistsServer(tx *dbs.Tx, serverId int64) (bool, error) {
if serverId <= 0 {
return false, nil
}
return this.Query(tx).
Pk(serverId).
State(ServerStateEnabled).
Exist()
}
// NotifyUpdate 同步服务所在的集群 // NotifyUpdate 同步服务所在的集群
func (this *ServerDAO) NotifyUpdate(tx *dbs.Tx, serverId int64) error { func (this *ServerDAO) NotifyUpdate(tx *dbs.Tx, serverId int64) error {
if serverId <= 0 { if serverId <= 0 {

View File

@@ -10,26 +10,26 @@ import (
) )
// CopyServerConfigToServers 拷贝服务配置到一组服务 // CopyServerConfigToServers 拷贝服务配置到一组服务
func (this *ServerDAO) CopyServerConfigToServers(tx *dbs.Tx, fromServerId int64, toServerIds []int64, configCode serverconfigs.ConfigCode) error { func (this *ServerDAO) CopyServerConfigToServers(tx *dbs.Tx, fromServerId int64, toServerIds []int64, configCode serverconfigs.ConfigCode, wafCopyRegions bool) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
// CopyServerConfigToGroups 拷贝服务配置到分组 // CopyServerConfigToGroups 拷贝服务配置到分组
func (this *ServerDAO) CopyServerConfigToGroups(tx *dbs.Tx, fromServerId int64, groupIds []int64, configCode string) error { func (this *ServerDAO) CopyServerConfigToGroups(tx *dbs.Tx, fromServerId int64, groupIds []int64, configCode string, wafCopyRegions bool) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
// CopyServerConfigToCluster 拷贝服务配置到集群 // CopyServerConfigToCluster 拷贝服务配置到集群
func (this *ServerDAO) CopyServerConfigToCluster(tx *dbs.Tx, fromServerId int64, clusterId int64, configCode string) error { func (this *ServerDAO) CopyServerConfigToCluster(tx *dbs.Tx, fromServerId int64, clusterId int64, configCode string, wafCopyRegions bool) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
// CopyServerConfigToUser 拷贝服务配置到用户 // CopyServerConfigToUser 拷贝服务配置到用户
func (this *ServerDAO) CopyServerConfigToUser(tx *dbs.Tx, fromServerId int64, userId int64, configCode string) error { func (this *ServerDAO) CopyServerConfigToUser(tx *dbs.Tx, fromServerId int64, userId int64, configCode string, wafCopyRegions bool) error {
return errors.New("not implemented") return errors.New("not implemented")
} }
// CopyServerUAMConfigs 复制UAM设置 // CopyServerUAMConfigs 复制UAM设置
func (this *ServerDAO) CopyServerUAMConfigs(tx *dbs.Tx, fromServerId int64, toServerIds []int64) error { func (this *ServerDAO) CopyServerUAMConfigs(tx *dbs.Tx, fromServerId int64, toServerIds []int64, wafCopyRegions bool) error {
return errors.New("not implemented") return errors.New("not implemented")
} }

View File

@@ -0,0 +1,12 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package models
import "github.com/iwind/TeaGo/dbs"
// ResetServersTrafficLimitStatusWithUserPlanId 重置用户套餐相关网站限流状态
func (this *ServerDAO) ResetServersTrafficLimitStatusWithUserPlanId(tx *dbs.Tx, userPlanId int64) error {
// stub
return nil
}

View File

@@ -242,7 +242,7 @@ func TestServerDAO_FindEnabledServerWithDomain(t *testing.T) {
} }
} }
func TestServerDAO_UpdateServerTrafficLimitStatus(t *testing.T) { func TestServerDAO_RenewServerTrafficLimitStatus(t *testing.T) {
dbs.NotifyReady() dbs.NotifyReady()
var tx *dbs.Tx var tx *dbs.Tx
@@ -250,7 +250,7 @@ func TestServerDAO_UpdateServerTrafficLimitStatus(t *testing.T) {
defer func() { defer func() {
t.Log(time.Since(before).Seconds()*1000, "ms") t.Log(time.Since(before).Seconds()*1000, "ms")
}() }()
err := models.NewServerDAO().UpdateServerTrafficLimitStatus(tx, &serverconfigs.TrafficLimitConfig{ err := models.NewServerDAO().RenewServerTrafficLimitStatus(tx, &serverconfigs.TrafficLimitConfig{
IsOn: true, IsOn: true,
DailySize: &shared.SizeCapacity{Count: 1, Unit: "mb"}, DailySize: &shared.SizeCapacity{Count: 1, Unit: "mb"},
MonthlySize: &shared.SizeCapacity{Count: 10, Unit: "mb"}, MonthlySize: &shared.SizeCapacity{Count: 10, Unit: "mb"},
@@ -263,40 +263,15 @@ func TestServerDAO_UpdateServerTrafficLimitStatus(t *testing.T) {
t.Log("ok") t.Log("ok")
} }
func TestServerDAO_CalculateServerTrafficLimitConfig(t *testing.T) { func TestServerDAO_UpdateServerTrafficLimitStatus(t *testing.T) {
dbs.NotifyReady() dbs.NotifyReady()
var dao = models.NewServerDAO()
var tx *dbs.Tx var tx *dbs.Tx
before := time.Now() err := dao.UpdateServerTrafficLimitStatus(tx, 23, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 20)), 14, "day", "traffic")
defer func() {
t.Log(time.Since(before).Seconds()*1000, "ms")
}()
var cacheMap = utils.NewCacheMap()
config, err := models.SharedServerDAO.CalculateServerTrafficLimitConfig(tx, 23, cacheMap)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
logs.PrintAsJSON(config, t)
}
func TestServerDAO_CalculateServerTrafficLimitConfig_Cache(t *testing.T) {
dbs.NotifyReady()
var tx *dbs.Tx
before := time.Now()
defer func() {
t.Log(time.Since(before).Seconds()*1000, "ms")
}()
var cacheMap = utils.NewCacheMap()
for i := 0; i < 10; i++ {
config, err := models.SharedServerDAO.CalculateServerTrafficLimitConfig(tx, 23, cacheMap)
if err != nil {
t.Fatal(err)
}
_ = config
}
} }
func TestServerDAO_FindBytes(t *testing.T) { func TestServerDAO_FindBytes(t *testing.T) {

View File

@@ -1,6 +1,7 @@
package models package models
import ( import (
"context"
"encoding/json" "encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/remotelogs" "github.com/TeaOSLab/EdgeAPI/internal/remotelogs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
@@ -78,7 +79,7 @@ func (this *Server) DecodeHTTPSPorts() (ports []int) {
if err != nil { if err != nil {
return nil return nil
} }
err = config.Init(nil) err = config.Init(context.TODO())
if err != nil { if err != nil {
return nil return nil
} }
@@ -120,7 +121,7 @@ func (this *Server) DecodeTLSPorts() (ports []int) {
if err != nil { if err != nil {
return nil return nil
} }
err = config.Init(nil) err = config.Init(context.TODO())
if err != nil { if err != nil {
return nil return nil
} }

Some files were not shown because too many files have changed in this diff Show More