Files
dns-server/shield/manager.go
Alex Yang 5a1d44c3b3 beta1
2025-11-23 19:35:02 +08:00

1052 lines
26 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package shield
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"sync"
"time"
"dns-server/config"
"dns-server/logger"
)
// regexRule 正则规则结构,包含编译后的表达式和原始字符串
type regexRule struct {
pattern *regexp.Regexp
original string
}
// ShieldStatsData 用于持久化的Shield统计数据
type ShieldStatsData struct {
BlockedDomainsCount map[string]int `json:"blockedDomainsCount"`
ResolvedDomainsCount map[string]int `json:"resolvedDomainsCount"`
LastSaved time.Time `json:"lastSaved"`
}
// ShieldManager 屏蔽管理器
type ShieldManager struct {
config *config.ShieldConfig
domainRules map[string]bool
domainExceptions map[string]bool
regexRules []regexRule
regexExceptions []regexRule
hostsMap map[string]string
blockedDomainsCount map[string]int
resolvedDomainsCount map[string]int
rulesMutex sync.RWMutex
updateCtx context.Context
updateCancel context.CancelFunc
updateRunning bool
}
// NewShieldManager 创建屏蔽管理器实例
func NewShieldManager(config *config.ShieldConfig) *ShieldManager {
ctx, cancel := context.WithCancel(context.Background())
manager := &ShieldManager{
config: config,
domainRules: make(map[string]bool),
domainExceptions: make(map[string]bool),
regexRules: []regexRule{},
regexExceptions: []regexRule{},
hostsMap: make(map[string]string),
blockedDomainsCount: make(map[string]int),
resolvedDomainsCount: make(map[string]int),
updateCtx: ctx,
updateCancel: cancel,
}
// 加载已保存的计数数据
manager.loadStatsData()
return manager
}
// LoadRules 加载屏蔽规则
func (m *ShieldManager) LoadRules() error {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
// 清空现有规则
m.domainRules = make(map[string]bool)
m.domainExceptions = make(map[string]bool)
m.regexRules = []regexRule{}
m.regexExceptions = []regexRule{}
m.hostsMap = make(map[string]string)
// 保留计数数据,不随规则重新加载而清空
// 加载本地规则文件
if err := m.loadLocalRules(); err != nil {
logger.Error("加载本地规则失败", "error", err)
// 继续执行,不返回错误
}
// 加载远程规则
if err := m.loadRemoteRules(); err != nil {
logger.Error("加载远程规则失败", "error", err)
// 继续执行,不返回错误
}
// 加载hosts文件
if err := m.loadHosts(); err != nil {
logger.Error("加载hosts文件失败", "error", err)
// 继续执行,不返回错误
}
logger.Info(fmt.Sprintf("规则加载完成,域名规则: %d, 排除规则: %d, 正则规则: %d, hosts规则: %d",
len(m.domainRules), len(m.domainExceptions), len(m.regexRules), len(m.hostsMap)))
return nil
}
// loadLocalRules 加载本地规则文件
func (m *ShieldManager) loadLocalRules() error {
if m.config.LocalRulesFile == "" {
return nil
}
file, err := os.Open(m.config.LocalRulesFile)
if err != nil {
return err
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
m.parseRule(line)
}
return scanner.Err()
}
// loadRemoteRules 加载远程规则
func (m *ShieldManager) loadRemoteRules() error {
for _, url := range m.config.RemoteRules {
if err := m.fetchRemoteRules(url); err != nil {
logger.Error("获取远程规则失败", "url", url, "error", err)
continue
}
}
return nil
}
// getCacheFilePath 根据URL生成缓存文件路径
func (m *ShieldManager) getCacheFilePath(url string) string {
// 使用URL的哈希值作为文件名避免文件名冲突
hash := fmt.Sprintf("%x", url)
// 简单处理,移除特殊字符,确保文件名合法
hash = strings.ReplaceAll(hash, "/", "_")
hash = strings.ReplaceAll(hash, "\\", "_")
return filepath.Join(m.config.RemoteRulesCacheDir, hash+".rules")
}
// shouldUpdateCache 检查缓存是否需要更新
func (m *ShieldManager) shouldUpdateCache(cacheFile string) bool {
// 检查文件是否存在
if _, err := os.Stat(cacheFile); os.IsNotExist(err) {
return true
}
// 检查文件修改时间
fileInfo, err := os.Stat(cacheFile)
if err != nil {
return true
}
// 如果缓存文件超过更新间隔时间,则需要更新
return time.Since(fileInfo.ModTime()) > time.Duration(m.config.UpdateInterval)*time.Second
}
// fetchRemoteRules 从远程URL获取规则
func (m *ShieldManager) fetchRemoteRules(url string) error {
// 获取缓存文件路径
cacheFile := m.getCacheFilePath(url)
// 尝试从缓存加载
hasLoadedFromCache := false
if !m.shouldUpdateCache(cacheFile) {
if err := m.loadCachedRules(cacheFile); err == nil {
logger.Info("从缓存加载远程规则", "url", url)
hasLoadedFromCache = true
}
}
// 从远程获取规则
resp, err := http.Get(url)
if err != nil {
// 如果从远程获取失败但已经从缓存加载成功则返回nil
if hasLoadedFromCache {
logger.Warn("远程规则更新失败,使用缓存版本", "url", url, "error", err)
return nil
}
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
// 如果状态码不正确但已经从缓存加载成功则返回nil
if hasLoadedFromCache {
logger.Warn("远程规则更新失败,使用缓存版本", "url", url, "statusCode", resp.StatusCode)
return nil
}
return fmt.Errorf("远程服务器返回错误状态码: %d", resp.StatusCode)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
// 保存规则到缓存
if err := m.saveRemoteRulesToCache(cacheFile, body); err != nil {
logger.Warn("保存远程规则缓存失败", "url", url, "error", err)
// 继续处理,不返回错误
}
// 解析并加载规则
lines := strings.Split(string(body), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
m.parseRule(line)
}
return nil
}
// loadCachedRules 从缓存文件加载规则
func (m *ShieldManager) loadCachedRules(filePath string) error {
file, err := os.Open(filePath)
if err != nil {
return err
}
defer file.Close()
body, err := ioutil.ReadAll(file)
if err != nil {
return err
}
lines := strings.Split(string(body), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
m.parseRule(line)
}
return nil
}
// saveRemoteRulesToCache 保存远程规则到缓存文件
func (m *ShieldManager) saveRemoteRulesToCache(filePath string, data []byte) error {
// 确保缓存目录存在
if err := os.MkdirAll(m.config.RemoteRulesCacheDir, 0755); err != nil {
return err
}
// 写入文件
return ioutil.WriteFile(filePath, data, 0644)
}
// loadHosts 加载hosts文件
func (m *ShieldManager) loadHosts() error {
if m.config.HostsFile == "" {
return nil
}
file, err := os.Open(m.config.HostsFile)
if err != nil {
return err
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
parts := strings.Fields(line)
if len(parts) >= 2 {
ip := parts[0]
for i := 1; i < len(parts); i++ {
m.hostsMap[parts[i]] = ip
}
}
}
return scanner.Err()
}
// parseRule 解析规则行
func (m *ShieldManager) parseRule(line string) {
// 处理注释
if strings.HasPrefix(line, "!") || strings.HasPrefix(line, "#") || line == "" {
return
}
// 移除规则选项部分(暂时不处理规则选项)
if strings.Contains(line, "$") {
parts := strings.SplitN(line, "$", 2)
line = parts[0]
// 规则选项暂时不处理
}
// 处理排除规则 (@@前缀表示取消屏蔽)
isException := false
if strings.HasPrefix(line, "@@") {
isException = true
line = strings.TrimPrefix(line, "@@")
}
// 处理不同类型的规则
switch {
case strings.HasPrefix(line, "||") && strings.HasSuffix(line, "^"):
// AdGuardHome域名规则格式: ||example.com^
domain := strings.TrimSuffix(strings.TrimPrefix(line, "||"), "^")
m.addDomainRule(domain, !isException)
case strings.HasPrefix(line, "||"):
// 精确域名匹配规则
domain := strings.TrimPrefix(line, "||")
m.addDomainRule(domain, !isException)
case strings.HasPrefix(line, "*"):
// 通配符规则,转换为正则表达式
pattern := strings.ReplaceAll(line, "*", ".*")
pattern = "^" + pattern + "$"
if re, err := regexp.Compile(pattern); err == nil {
// 保存原始规则字符串
m.addRegexRule(re, line, !isException)
}
case strings.HasPrefix(line, "/") && strings.HasSuffix(line, "/"):
// 正则表达式规则
pattern := strings.TrimPrefix(strings.TrimSuffix(line, "/"), "/")
if re, err := regexp.Compile(pattern); err == nil {
// 保存原始规则字符串
m.addRegexRule(re, line, !isException)
}
case strings.HasPrefix(line, "|") && strings.HasSuffix(line, "|"):
// 完整URL匹配规则
urlPattern := strings.TrimPrefix(strings.TrimSuffix(line, "|"), "|")
// 将URL模式转换为正则表达式
pattern := "^" + regexp.QuoteMeta(urlPattern) + "$"
if re, err := regexp.Compile(pattern); err == nil {
m.addRegexRule(re, line, !isException)
}
case strings.HasPrefix(line, "|"):
// URL开头匹配规则
urlPattern := strings.TrimPrefix(line, "|")
pattern := "^" + regexp.QuoteMeta(urlPattern)
if re, err := regexp.Compile(pattern); err == nil {
m.addRegexRule(re, line, !isException)
}
case strings.HasSuffix(line, "|"):
// URL结尾匹配规则
urlPattern := strings.TrimSuffix(line, "|")
pattern := regexp.QuoteMeta(urlPattern) + "$"
if re, err := regexp.Compile(pattern); err == nil {
m.addRegexRule(re, line, !isException)
}
default:
// 默认作为普通域名规则
m.addDomainRule(line, !isException)
}
}
// parseRuleOptions 解析规则选项
func (m *ShieldManager) parseRuleOptions(optionsStr string) map[string]string {
options := make(map[string]string)
parts := strings.Split(optionsStr, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
if strings.Contains(part, "=") {
kv := strings.SplitN(part, "=", 2)
options[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1])
} else {
options[part] = ""
}
}
return options
}
// addDomainRule 添加域名规则,支持是否为阻止规则
func (m *ShieldManager) addDomainRule(domain string, block bool) {
if block {
m.domainRules[domain] = true
// 添加所有子域名的匹配支持
parts := strings.Split(domain, ".")
if len(parts) > 1 {
// 为二级域名和顶级域名添加规则
for i := 0; i < len(parts)-1; i++ {
subdomain := strings.Join(parts[i:], ".")
m.domainRules[subdomain] = true
}
}
} else {
// 添加到排除规则
m.domainExceptions[domain] = true
// 为子域名也添加排除规则
parts := strings.Split(domain, ".")
if len(parts) > 1 {
for i := 0; i < len(parts)-1; i++ {
subdomain := strings.Join(parts[i:], ".")
m.domainExceptions[subdomain] = true
}
}
}
}
// addRegexRule 添加正则表达式规则,支持是否为阻止规则
func (m *ShieldManager) addRegexRule(re *regexp.Regexp, original string, block bool) {
rule := regexRule{
pattern: re,
original: original,
}
if block {
m.regexRules = append(m.regexRules, rule)
} else {
// 添加到排除规则
m.regexExceptions = append(m.regexExceptions, rule)
}
}
// IsBlocked 检查域名是否被屏蔽
func (m *ShieldManager) IsBlocked(domain string) bool {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
// 预处理域名,去除可能的端口号
if strings.Contains(domain, ":") {
parts := strings.Split(domain, ":")
domain = parts[0]
}
// 首先检查排除规则(优先级最高)
// 检查域名排除规则
if m.domainExceptions[domain] {
return false
}
// 检查子域名排除规则
parts := strings.Split(domain, ".")
for i := 0; i < len(parts)-1; i++ {
subdomain := strings.Join(parts[i:], ".")
if m.domainExceptions[subdomain] {
return false
}
}
// 检查正则表达式排除规则
for _, re := range m.regexExceptions {
if re.pattern.MatchString(domain) {
return false
}
}
// 然后检查阻止规则
// 检查精确域名匹配
if m.domainRules[domain] {
return true
}
// 检查子域名匹配AdGuardHome风格
// 从最长的子域名开始匹配,确保优先级正确
for i := 0; i < len(parts)-1; i++ {
subdomain := strings.Join(parts[i:], ".")
if m.domainRules[subdomain] {
return true
}
}
// 检查正则表达式匹配
for _, re := range m.regexRules {
if re.pattern.MatchString(domain) {
return true
}
}
return false
}
// RecordBlockedDomain 记录被屏蔽的域名
func (m *ShieldManager) RecordBlockedDomain(domain string) {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
m.blockedDomainsCount[domain]++
}
// RecordResolvedDomain 记录被解析的域名
func (m *ShieldManager) RecordResolvedDomain(domain string) {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
m.resolvedDomainsCount[domain]++
}
// GetTopBlockedDomains 获取最常被屏蔽的域名
func (m *ShieldManager) GetTopBlockedDomains(limit int) []map[string]interface{} {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
// 如果没有数据,返回空数组
if len(m.blockedDomainsCount) == 0 {
return []map[string]interface{}{}
}
// 转换为切片以便排序
type domainCount struct {
Domain string
Count int
}
var domains []domainCount
for domain, count := range m.blockedDomainsCount {
domains = append(domains, domainCount{Domain: domain, Count: count})
}
// 按计数降序排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].Count > domains[j].Count
})
// 限制返回数量
if len(domains) > limit {
domains = domains[:limit]
}
// 转换为API响应格式
result := make([]map[string]interface{}, len(domains))
for i, item := range domains {
result[i] = map[string]interface{}{
"domain": item.Domain,
"count": item.Count,
}
}
return result
}
// GetTopResolvedDomains 获取最常被解析的域名
func (m *ShieldManager) GetTopResolvedDomains(limit int) []map[string]interface{} {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
// 如果没有数据,返回空数组
if len(m.resolvedDomainsCount) == 0 {
return []map[string]interface{}{}
}
// 转换为切片以便排序
type domainCount struct {
Domain string
Count int
}
var domains []domainCount
for domain, count := range m.resolvedDomainsCount {
domains = append(domains, domainCount{Domain: domain, Count: count})
}
// 按计数降序排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].Count > domains[j].Count
})
// 限制返回数量
if len(domains) > limit {
domains = domains[:limit]
}
// 转换为API响应格式
result := make([]map[string]interface{}, len(domains))
for i, item := range domains {
result[i] = map[string]interface{}{
"domain": item.Domain,
"count": item.Count,
}
}
return result
}
// GetHostsIP 获取hosts文件中的IP映射
func (m *ShieldManager) GetHostsIP(domain string) (string, bool) {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
ip, exists := m.hostsMap[domain]
return ip, exists
}
// AddRule 添加屏蔽规则
func (m *ShieldManager) AddRule(rule string) error {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
// 解析并添加规则到内存
m.parseRule(rule)
// 持久化保存规则到文件
if m.config.LocalRulesFile != "" {
if err := m.saveRulesToFile(); err != nil {
logger.Error("保存规则到文件失败", "error", err)
return err
}
}
return nil
}
// RemoveRule 删除屏蔽规则
func (m *ShieldManager) RemoveRule(rule string) error {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
removed := false
// 清理规则,移除可能的修饰符
cleanRule := rule
// 移除规则结束符
cleanRule = strings.TrimSuffix(cleanRule, "^")
// 尝试多种可能的规则格式
formatsToTry := []string{cleanRule}
// 根据规则类型添加可能的格式变体
if strings.HasPrefix(cleanRule, "@@||") {
// 已有的排除规则格式,也尝试去掉前缀
formatsToTry = append(formatsToTry, strings.TrimPrefix(cleanRule, "@@||"))
} else if strings.HasPrefix(cleanRule, "||") {
// 已有的域名规则格式,也尝试去掉前缀
formatsToTry = append(formatsToTry, strings.TrimPrefix(cleanRule, "||"))
} else {
// 可能是裸域名,尝试添加前缀
formatsToTry = append(formatsToTry, "||"+cleanRule, "@@||"+cleanRule)
}
// 尝试所有可能的格式变体来删除规则
for _, format := range formatsToTry {
if removed {
break
}
// 尝试删除排除规则
if strings.HasPrefix(format, "@@||") {
domain := strings.TrimPrefix(format, "@@||")
if _, exists := m.domainExceptions[domain]; exists {
delete(m.domainExceptions, domain)
removed = true
break
}
} else if strings.HasPrefix(format, "||") {
// 尝试删除域名规则
domain := strings.TrimPrefix(format, "||")
if _, exists := m.domainRules[domain]; exists {
delete(m.domainRules, domain)
removed = true
break
}
} else {
// 尝试直接作为域名删除
if _, exists := m.domainRules[format]; exists {
delete(m.domainRules, format)
removed = true
break
}
if _, exists := m.domainExceptions[format]; exists {
delete(m.domainExceptions, format)
removed = true
break
}
}
}
// 处理正则表达式规则
if !removed && strings.HasPrefix(cleanRule, "/") && strings.HasSuffix(cleanRule, "/") {
pattern := strings.TrimPrefix(strings.TrimSuffix(cleanRule, "/"), "/")
// 检查是否在正则表达式规则中
newRegexRules := []regexRule{}
for _, re := range m.regexRules {
if re.pattern.String() != pattern {
newRegexRules = append(newRegexRules, re)
} else {
removed = true
}
}
m.regexRules = newRegexRules
// 如果没有从正则规则中找到,检查是否在正则排除规则中
if !removed {
newRegexExceptions := []regexRule{}
for _, re := range m.regexExceptions {
if re.pattern.String() != pattern {
newRegexExceptions = append(newRegexExceptions, re)
} else {
removed = true
}
}
m.regexExceptions = newRegexExceptions
}
}
// 处理通配符和URL匹配规则
if !removed && (strings.HasPrefix(cleanRule, "*") || strings.HasSuffix(cleanRule, "*") || strings.HasPrefix(cleanRule, "|") || strings.HasSuffix(cleanRule, "|")) {
// 遍历所有规则,找到匹配的规则进行删除
for domain := range m.domainRules {
if domain == cleanRule || domain == rule {
delete(m.domainRules, domain)
removed = true
break
}
}
if !removed {
for domain := range m.domainExceptions {
if domain == cleanRule || domain == rule {
delete(m.domainExceptions, domain)
removed = true
break
}
}
}
}
// 如果有规则被删除,持久化保存更改
if removed && m.config.LocalRulesFile != "" {
if err := m.saveRulesToFile(); err != nil {
logger.Error("保存规则到文件失败", "error", err)
return err
}
}
return nil
}
// StartAutoUpdate 启动自动更新
func (m *ShieldManager) StartAutoUpdate() {
if m.updateRunning {
return
}
m.updateRunning = true
go func() {
ticker := time.NewTicker(time.Duration(m.config.UpdateInterval) * time.Second)
defer ticker.Stop()
// 启动自动保存计数数据
go m.startAutoSaveStats()
for {
select {
case <-ticker.C:
logger.Info("开始自动更新规则")
if err := m.LoadRules(); err != nil {
logger.Error("自动更新规则失败", "error", err)
} else {
logger.Info("自动更新规则成功")
}
case <-m.updateCtx.Done():
// 保存计数数据
m.saveStatsData()
m.updateRunning = false
return
}
}
}()
logger.Info("规则自动更新已启动", "interval", m.config.UpdateInterval)
// 如果是首次启动,先保存一次数据确保目录存在
go m.saveStatsData()
}
// StopAutoUpdate 停止自动更新
func (m *ShieldManager) StopAutoUpdate() {
m.updateRunning = false
m.updateCancel()
// 保存计数数据
m.saveStatsData()
logger.Info("规则自动更新已停止")
}
// saveRulesToFile 保存规则到文件
func (m *ShieldManager) saveRulesToFile() error {
var rules []string
// 添加域名规则
for domain := range m.domainRules {
rules = append(rules, "||"+domain)
}
// 添加正则表达式规则
for _, re := range m.regexRules {
rules = append(rules, re.original)
}
// 添加排除规则
for domain := range m.domainExceptions {
rules = append(rules, "@@||"+domain)
}
// 添加正则表达式排除规则
for _, re := range m.regexExceptions {
rules = append(rules, re.original)
}
// 写入文件
content := strings.Join(rules, "\n")
return ioutil.WriteFile(m.config.LocalRulesFile, []byte(content), 0644)
}
// AddHostsEntry 添加hosts条目
func (m *ShieldManager) AddHostsEntry(ip, domain string) error {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
m.hostsMap[domain] = ip
// 持久化保存到hosts文件
if m.config.HostsFile != "" {
if err := m.saveHostsToFile(); err != nil {
logger.Error("保存hosts到文件失败", "error", err)
return err
}
}
return nil
}
// RemoveHostsEntry 删除hosts条目
func (m *ShieldManager) RemoveHostsEntry(domain string) error {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
if _, exists := m.hostsMap[domain]; exists {
delete(m.hostsMap, domain)
// 持久化保存到hosts文件
if m.config.HostsFile != "" {
if err := m.saveHostsToFile(); err != nil {
logger.Error("保存hosts到文件失败", "error", err)
return err
}
}
}
return nil
}
// saveHostsToFile 保存hosts到文件
func (m *ShieldManager) saveHostsToFile() error {
var lines []string
// 添加hosts头部注释
lines = append(lines, "# DNS Server Hosts File")
lines = append(lines, "# Generated by DNS Server")
lines = append(lines, "")
// 添加localhost条目如果不存在
if _, exists := m.hostsMap["localhost"]; !exists {
lines = append(lines, "127.0.0.1 localhost")
lines = append(lines, "::1 localhost")
lines = append(lines, "")
}
// 添加所有hosts条目
for domain, ip := range m.hostsMap {
lines = append(lines, ip+"\t"+domain)
}
// 写入文件
content := strings.Join(lines, "\n")
return ioutil.WriteFile(m.config.HostsFile, []byte(content), 0644)
}
// GetStats 获取规则统计信息
func (m *ShieldManager) GetStats() map[string]interface{} {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
return map[string]interface{}{
"domainRules": len(m.domainRules),
"domainExceptions": len(m.domainExceptions),
"regexRules": len(m.regexRules),
"regexExceptions": len(m.regexExceptions),
"hostsRules": len(m.hostsMap),
"updateInterval": m.config.UpdateInterval,
}
}
// loadStatsData 从文件加载计数数据
func (m *ShieldManager) loadStatsData() {
if m.config.StatsFile == "" {
return
}
// 检查文件是否存在
data, err := ioutil.ReadFile(m.config.StatsFile)
if err != nil {
if !os.IsNotExist(err) {
logger.Error("读取Shield计数数据文件失败", "error", err)
}
return
}
var statsData ShieldStatsData
err = json.Unmarshal(data, &statsData)
if err != nil {
logger.Error("解析Shield计数数据失败", "error", err)
return
}
// 恢复计数数据
m.rulesMutex.Lock()
if statsData.BlockedDomainsCount != nil {
m.blockedDomainsCount = statsData.BlockedDomainsCount
}
if statsData.ResolvedDomainsCount != nil {
m.resolvedDomainsCount = statsData.ResolvedDomainsCount
}
m.rulesMutex.Unlock()
logger.Info("Shield计数数据加载成功")
}
// saveStatsData 保存计数数据到文件
func (m *ShieldManager) saveStatsData() {
if m.config.StatsFile == "" {
return
}
// 创建数据目录
statsDir := filepath.Dir(m.config.StatsFile)
err := os.MkdirAll(statsDir, 0755)
if err != nil {
logger.Error("创建Shield统计数据目录失败", "error", err)
return
}
// 收集计数数据
m.rulesMutex.RLock()
statsData := &ShieldStatsData{
BlockedDomainsCount: make(map[string]int),
ResolvedDomainsCount: make(map[string]int),
LastSaved: time.Now(),
}
// 复制数据
for k, v := range m.blockedDomainsCount {
statsData.BlockedDomainsCount[k] = v
}
for k, v := range m.resolvedDomainsCount {
statsData.ResolvedDomainsCount[k] = v
}
m.rulesMutex.RUnlock()
// 序列化数据
jsonData, err := json.MarshalIndent(statsData, "", " ")
if err != nil {
logger.Error("序列化Shield计数数据失败", "error", err)
return
}
// 写入文件
err = ioutil.WriteFile(m.config.StatsFile, jsonData, 0644)
if err != nil {
logger.Error("保存Shield计数数据到文件失败", "error", err)
return
}
logger.Info("Shield计数数据保存成功")
}
// startAutoSaveStats 启动计数数据自动保存功能
func (m *ShieldManager) startAutoSaveStats() {
if m.config.StatsFile == "" || m.config.StatsSaveInterval <= 0 {
return
}
ticker := time.NewTicker(time.Duration(m.config.StatsSaveInterval) * time.Second)
defer ticker.Stop()
logger.Info("启动Shield计数数据自动保存功能", "interval", m.config.StatsSaveInterval, "file", m.config.StatsFile)
// 定期保存数据
for {
select {
case <-ticker.C:
m.saveStatsData()
case <-m.updateCtx.Done():
return
}
}
}
// GetRules 获取所有规则
func (m *ShieldManager) GetRules() map[string]interface{} {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
// 转换map和slice为字符串列表
domainRulesList := make([]string, 0, len(m.domainRules))
for domain := range m.domainRules {
domainRulesList = append(domainRulesList, "||"+domain+"^")
}
domainExceptionsList := make([]string, 0, len(m.domainExceptions))
for domain := range m.domainExceptions {
domainExceptionsList = append(domainExceptionsList, "@@||"+domain+"^")
}
// 获取正则规则原始字符串
regexRulesList := make([]string, 0, len(m.regexRules))
for _, re := range m.regexRules {
regexRulesList = append(regexRulesList, re.original)
}
// 获取正则排除规则原始字符串
regexExceptionsList := make([]string, 0, len(m.regexExceptions))
for _, re := range m.regexExceptions {
regexExceptionsList = append(regexExceptionsList, re.original)
}
// 获取hosts规则
hostsRulesList := make([]string, 0, len(m.hostsMap))
for domain, ip := range m.hostsMap {
hostsRulesList = append(hostsRulesList, ip+"\t"+domain)
}
return map[string]interface{}{
"domainRules": domainRulesList,
"domainExceptions": domainExceptionsList,
"regexRules": regexRulesList,
"regexExceptions": regexExceptionsList,
"hostsRules": hostsRulesList,
}
}