Files
dns-server/shield/manager.go
2025-11-30 02:25:36 +08:00

1470 lines
42 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package shield
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"sync"
"time"
"dns-server/config"
"dns-server/logger"
)
// ShieldStatsData 用于持久化的Shield统计数据
type ShieldStatsData struct {
BlockedDomainsCount map[string]int `json:"blockedDomainsCount"`
ResolvedDomainsCount map[string]int `json:"resolvedDomainsCount"`
LastSaved time.Time `json:"lastSaved"`
}
// regexRule 正则规则结构,包含编译后的表达式和原始字符串
type regexRule struct {
pattern *regexp.Regexp
original string
isLocal bool // 是否为本地规则
source string // 规则来源
}
// ShieldManager 屏蔽管理器
type ShieldManager struct {
config *config.ShieldConfig
domainRules map[string]bool
domainExceptions map[string]bool
domainRulesIsLocal map[string]bool // 标记域名规则是否为本地规则
domainExceptionsIsLocal map[string]bool // 标记域名排除规则是否为本地规则
domainRulesSource map[string]string // 标记域名规则来源
domainExceptionsSource map[string]string // 标记域名排除规则来源
domainRulesOriginal map[string]string // 存储域名规则的原始字符串
domainExceptionsOriginal map[string]string // 存储域名排除规则的原始字符串
regexRules []regexRule
regexExceptions []regexRule
hostsMap map[string]string
blockedDomainsCount map[string]int
resolvedDomainsCount map[string]int
rulesMutex sync.RWMutex
updateCtx context.Context
updateCancel context.CancelFunc
updateRunning bool
localRulesCount int // 本地规则数量
remoteRulesCount int // 远程规则数量
}
// NewShieldManager 创建屏蔽管理器实例
func NewShieldManager(config *config.ShieldConfig) *ShieldManager {
ctx, cancel := context.WithCancel(context.Background())
manager := &ShieldManager{
config: config,
domainRules: make(map[string]bool),
domainExceptions: make(map[string]bool),
domainRulesIsLocal: make(map[string]bool),
domainExceptionsIsLocal: make(map[string]bool),
domainRulesSource: make(map[string]string),
domainExceptionsSource: make(map[string]string),
domainRulesOriginal: make(map[string]string),
domainExceptionsOriginal: make(map[string]string),
regexRules: []regexRule{},
regexExceptions: []regexRule{},
hostsMap: make(map[string]string),
blockedDomainsCount: make(map[string]int),
resolvedDomainsCount: make(map[string]int),
updateCtx: ctx,
updateCancel: cancel,
localRulesCount: 0,
remoteRulesCount: 0,
}
// 加载已保存的计数数据
manager.loadStatsData()
return manager
}
// LoadRules 加载屏蔽规则
func (m *ShieldManager) LoadRules() error {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
// 清空现有规则
m.domainRules = make(map[string]bool)
m.domainExceptions = make(map[string]bool)
m.domainRulesIsLocal = make(map[string]bool)
m.domainExceptionsIsLocal = make(map[string]bool)
m.domainRulesSource = make(map[string]string)
m.domainExceptionsSource = make(map[string]string)
m.domainRulesOriginal = make(map[string]string)
m.domainExceptionsOriginal = make(map[string]string)
m.regexRules = []regexRule{}
m.regexExceptions = []regexRule{}
m.hostsMap = make(map[string]string)
m.localRulesCount = 0
m.remoteRulesCount = 0
// 保留计数数据,不随规则重新加载而清空
// 加载本地规则文件
if err := m.loadLocalRules(); err != nil {
logger.Error("加载本地规则失败", "error", err)
// 继续执行,不返回错误
}
// 加载远程规则
if err := m.loadRemoteRules(); err != nil {
logger.Error("加载远程规则失败", "error", err)
// 继续执行,不返回错误
}
// 加载hosts文件
if err := m.loadHosts(); err != nil {
logger.Error("加载hosts文件失败", "error", err)
// 继续执行,不返回错误
}
logger.Info(fmt.Sprintf("规则加载完成,域名规则: %d, 排除规则: %d, 正则规则: %d, hosts规则: %d",
len(m.domainRules), len(m.domainExceptions), len(m.regexRules), len(m.hostsMap)))
return nil
}
// loadLocalRules 加载本地规则文件
func (m *ShieldManager) loadLocalRules() error {
if m.config.LocalRulesFile == "" {
return nil
}
file, err := os.Open(m.config.LocalRulesFile)
if err != nil {
return err
}
defer file.Close()
// 记录加载前的规则数量,用于计算本地规则数量
beforeDomainRules := len(m.domainRules)
beforeRegexRules := len(m.regexRules)
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
m.parseRule(line, true, "本地规则") // 本地规则isLocal=true来源为"本地规则"
}
// 更新本地规则计数
m.localRulesCount = (len(m.domainRules) - beforeDomainRules) + (len(m.regexRules) - beforeRegexRules)
return scanner.Err()
}
// loadRemoteRules 加载远程规则
func (m *ShieldManager) loadRemoteRules() error {
for _, blacklist := range m.config.Blacklists {
if blacklist.Enabled {
if err := m.fetchRemoteRules(blacklist.URL); err != nil {
logger.Error("获取远程规则失败", "url", blacklist.URL, "error", err)
continue
}
}
}
return nil
}
// getCacheFilePath 根据URL生成缓存文件路径
func (m *ShieldManager) getCacheFilePath(url string) string {
// 使用URL的哈希值作为文件名避免文件名冲突
hash := fmt.Sprintf("%x", url)
// 简单处理,移除特殊字符,确保文件名合法
hash = strings.ReplaceAll(hash, "/", "_")
hash = strings.ReplaceAll(hash, "\\", "_")
return filepath.Join(m.config.RemoteRulesCacheDir, hash+".rules")
}
// shouldUpdateCache 检查缓存是否需要更新
func (m *ShieldManager) shouldUpdateCache(cacheFile string) bool {
// 检查文件是否存在
if _, err := os.Stat(cacheFile); os.IsNotExist(err) {
return true
}
// 检查文件修改时间
fileInfo, err := os.Stat(cacheFile)
if err != nil {
return true
}
// 如果缓存文件超过更新间隔时间,则需要更新
return time.Since(fileInfo.ModTime()) > time.Duration(m.config.UpdateInterval)*time.Second
}
// fetchRemoteRules 从远程URL获取规则
func (m *ShieldManager) fetchRemoteRules(url string) error {
// 获取缓存文件路径
cacheFile := m.getCacheFilePath(url)
// 尝试从缓存加载
hasLoadedFromCache := false
if !m.shouldUpdateCache(cacheFile) {
if err := m.loadCachedRules(cacheFile, url); err == nil {
logger.Info("从缓存加载远程规则", "url", url)
hasLoadedFromCache = true
}
}
// 从远程获取规则
resp, err := http.Get(url)
if err != nil {
// 如果从远程获取失败但已经从缓存加载成功则返回nil
if hasLoadedFromCache {
logger.Warn("远程规则更新失败,使用缓存版本", "url", url, "error", err)
return nil
}
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
// 如果状态码不正确但已经从缓存加载成功则返回nil
if hasLoadedFromCache {
logger.Warn("远程规则更新失败,使用缓存版本", "url", url, "statusCode", resp.StatusCode)
return nil
}
return fmt.Errorf("远程服务器返回错误状态码: %d", resp.StatusCode)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
// 保存规则到缓存
if err := m.saveRemoteRulesToCache(cacheFile, body); err != nil {
logger.Warn("保存远程规则缓存失败", "url", url, "error", err)
// 继续处理,不返回错误
}
// 解析并加载规则
lines := strings.Split(string(body), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
m.parseRule(line, false, url) // 远程规则isLocal=false来源为URL
}
return nil
}
// loadCachedRules 从缓存文件加载规则
func (m *ShieldManager) loadCachedRules(filePath string, source string) error {
file, err := os.Open(filePath)
if err != nil {
return err
}
defer file.Close()
// 记录加载前的规则数量,用于计算远程规则数量
beforeDomainRules := len(m.domainRules)
beforeRegexRules := len(m.regexRules)
body, err := ioutil.ReadAll(file)
if err != nil {
return err
}
lines := strings.Split(string(body), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
m.parseRule(line, false, source) // 远程规则isLocal=false来源为URL
}
// 更新远程规则计数
remoteRulesAdded := (len(m.domainRules) - beforeDomainRules) + (len(m.regexRules) - beforeRegexRules)
m.remoteRulesCount += remoteRulesAdded
return nil
}
// saveRemoteRulesToCache 保存远程规则到缓存文件
func (m *ShieldManager) saveRemoteRulesToCache(filePath string, data []byte) error {
// 确保缓存目录存在
if err := os.MkdirAll(m.config.RemoteRulesCacheDir, 0755); err != nil {
return err
}
// 写入文件
return ioutil.WriteFile(filePath, data, 0644)
}
// loadHosts 加载hosts文件
func (m *ShieldManager) loadHosts() error {
if m.config.HostsFile == "" {
return nil
}
file, err := os.Open(m.config.HostsFile)
if err != nil {
return err
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
parts := strings.Fields(line)
if len(parts) >= 2 {
ip := parts[0]
for i := 1; i < len(parts); i++ {
m.hostsMap[parts[i]] = ip
}
}
}
return scanner.Err()
}
// parseRule 解析规则行
func (m *ShieldManager) parseRule(line string, isLocal bool, source string) {
// 保存原始规则用于后续使用
originalLine := line
// 处理注释
if strings.HasPrefix(line, "!") || strings.HasPrefix(line, "#") || line == "" {
return
}
// 移除规则选项部分(暂时不处理规则选项)
if strings.Contains(line, "$") {
parts := strings.SplitN(line, "$", 2)
line = parts[0]
// 规则选项暂时不处理
}
// 处理排除规则 (@@前缀表示取消屏蔽)
isException := false
if strings.HasPrefix(line, "@@") {
isException = true
line = strings.TrimPrefix(line, "@@")
}
// 处理不同类型的规则
switch {
case strings.HasPrefix(line, "||") && strings.HasSuffix(line, "^"):
// AdGuardHome域名规则格式: ||example.com^
domain := strings.TrimSuffix(strings.TrimPrefix(line, "||"), "^")
m.addDomainRule(domain, !isException, isLocal, source, originalLine)
case strings.HasPrefix(line, "||"):
// 精确域名匹配规则
domain := strings.TrimPrefix(line, "||")
m.addDomainRule(domain, !isException, isLocal, source, originalLine)
case strings.HasPrefix(line, "*"):
// 通配符规则,转换为正则表达式
pattern := strings.ReplaceAll(line, "*", ".*")
pattern = "^" + pattern + "$"
if re, err := regexp.Compile(pattern); err == nil {
// 保存原始规则字符串
m.addRegexRule(re, originalLine, !isException, isLocal, source)
}
case strings.HasPrefix(line, "/") && strings.HasSuffix(line, "/"):
// 正则表达式匹配规则:/regex/ 格式,不区分大小写
pattern := strings.TrimPrefix(strings.TrimSuffix(line, "/"), "/")
// 编译为不区分大小写的正则表达式,确保能匹配域名中任意位置
// 对于像 /domain/ 这样的规则,应该匹配包含 domain 字符串的任何域名
if re, err := regexp.Compile("(?i).*" + regexp.QuoteMeta(pattern) + ".*"); err == nil {
// 保存原始规则字符串
m.addRegexRule(re, originalLine, !isException, isLocal, source)
}
case strings.HasPrefix(line, "|") && strings.HasSuffix(line, "|"):
// 完整URL匹配规则
urlPattern := strings.TrimPrefix(strings.TrimSuffix(line, "|"), "|")
// 将URL模式转换为正则表达式
pattern := "^" + regexp.QuoteMeta(urlPattern) + "$"
if re, err := regexp.Compile(pattern); err == nil {
m.addRegexRule(re, originalLine, !isException, isLocal, source)
}
case strings.HasPrefix(line, "|"):
// URL开头匹配规则
urlPattern := strings.TrimPrefix(line, "|")
pattern := "^" + regexp.QuoteMeta(urlPattern)
if re, err := regexp.Compile(pattern); err == nil {
m.addRegexRule(re, originalLine, !isException, isLocal, source)
}
case strings.HasSuffix(line, "|"):
// URL结尾匹配规则
urlPattern := strings.TrimSuffix(line, "|")
pattern := regexp.QuoteMeta(urlPattern) + "$"
if re, err := regexp.Compile(pattern); err == nil {
m.addRegexRule(re, originalLine, !isException, isLocal, source)
}
default:
// 默认作为普通域名规则
m.addDomainRule(line, !isException, isLocal, source, originalLine)
}
}
// parseRuleOptions 解析规则选项
func (m *ShieldManager) parseRuleOptions(optionsStr string) map[string]string {
options := make(map[string]string)
parts := strings.Split(optionsStr, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
if strings.Contains(part, "=") {
kv := strings.SplitN(part, "=", 2)
options[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1])
} else {
options[part] = ""
}
}
return options
}
// addDomainRule 添加域名规则,支持是否为阻止规则
func (m *ShieldManager) addDomainRule(domain string, block bool, isLocal bool, source string, original string) {
if block {
// 如果是远程规则,检查是否已经存在本地规则,如果存在则不覆盖
if !isLocal {
if _, exists := m.domainRulesIsLocal[domain]; exists && m.domainRulesIsLocal[domain] {
// 已经存在本地规则,不覆盖
return
}
}
m.domainRules[domain] = true
m.domainRulesIsLocal[domain] = isLocal
m.domainRulesSource[domain] = source
m.domainRulesOriginal[domain] = original
} else {
// 添加到排除规则
// 如果是远程规则,检查是否已经存在本地规则,如果存在则不覆盖
if !isLocal {
if _, exists := m.domainExceptionsIsLocal[domain]; exists && m.domainExceptionsIsLocal[domain] {
// 已经存在本地规则,不覆盖
return
}
}
m.domainExceptions[domain] = true
m.domainExceptionsIsLocal[domain] = isLocal
m.domainExceptionsSource[domain] = source
m.domainExceptionsOriginal[domain] = original
}
}
// addRegexRule 添加正则表达式规则,支持是否为阻止规则
func (m *ShieldManager) addRegexRule(re *regexp.Regexp, original string, block bool, isLocal bool, source string) {
rule := regexRule{
pattern: re,
original: original,
isLocal: isLocal,
source: source,
}
if block {
// 如果是远程规则,检查是否已经存在相同的本地规则,如果存在则不添加
if !isLocal {
for _, existingRule := range m.regexRules {
if existingRule.original == original && existingRule.isLocal {
// 已经存在相同的本地规则,不添加
return
}
}
}
m.regexRules = append(m.regexRules, rule)
} else {
// 添加到排除规则
// 如果是远程规则,检查是否已经存在相同的本地规则,如果存在则不添加
if !isLocal {
for _, existingRule := range m.regexExceptions {
if existingRule.original == original && existingRule.isLocal {
// 已经存在相同的本地规则,不添加
return
}
}
}
m.regexExceptions = append(m.regexExceptions, rule)
}
}
// IsBlocked 检查域名是否被屏蔽
// CheckDomainBlockDetails 检查域名是否被屏蔽,并返回详细信息
func (m *ShieldManager) CheckDomainBlockDetails(domain string) map[string]interface{} {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
// 预处理域名,去除可能的端口号
if strings.Contains(domain, ":") {
parts := strings.Split(domain, ":")
domain = parts[0]
}
result := map[string]interface{}{
"domain": domain,
"blocked": false,
"blockRule": "",
"blockRuleType": "",
"blocksource": "",
"excluded": false,
"excludeRule": "",
"excludeRuleType": "",
"hasHosts": false,
"hostsIP": "",
}
// 检查hosts记录
hostsIP, hasHosts := m.GetHostsIP(domain)
result["hasHosts"] = hasHosts
result["hostsIP"] = hostsIP
// 检查排除规则(优先级最高)
// 检查域名排除规则
if m.domainExceptions[domain] {
result["excluded"] = true
result["excludeRule"] = m.domainExceptionsOriginal[domain]
result["excludeRuleType"] = "exact_domain"
result["blocksource"] = m.domainExceptionsSource[domain]
return result
}
// 检查子域名排除规则
parts := strings.Split(domain, ".")
for i := 0; i < len(parts)-1; i++ {
subdomain := strings.Join(parts[i:], ".")
if m.domainExceptions[subdomain] {
result["excluded"] = true
result["excludeRule"] = m.domainExceptionsOriginal[subdomain]
result["excludeRuleType"] = "subdomain"
result["blocksource"] = m.domainExceptionsSource[subdomain]
return result
}
}
// 检查正则表达式排除规则
for _, re := range m.regexExceptions {
if re.pattern.MatchString(domain) {
result["excluded"] = true
result["excludeRule"] = re.original
result["excludeRuleType"] = "regex"
result["blocksource"] = re.source
return result
}
}
// 检查阻止规则 - 先检查精确域名匹配,再检查子域名匹配
// 检查精确域名匹配
if m.domainRules[domain] {
result["blocked"] = true
result["blockRule"] = m.domainRulesOriginal[domain]
result["blockRuleType"] = "exact_domain"
result["blocksource"] = m.domainRulesSource[domain]
return result
}
// 检查子域名匹配AdGuardHome风格
// 从最长的子域名开始匹配,确保优先级正确
for i := 0; i < len(parts)-1; i++ {
subdomain := strings.Join(parts[i:], ".")
if m.domainRules[subdomain] {
result["blocked"] = true
result["blockRule"] = m.domainRulesOriginal[subdomain]
result["blockRuleType"] = "subdomain"
result["blocksource"] = m.domainRulesSource[subdomain]
return result
}
}
// 检查正则表达式匹配
for _, re := range m.regexRules {
if re.pattern.MatchString(domain) {
result["blocked"] = true
result["blockRule"] = re.original
result["blockRuleType"] = "regex"
result["blocksource"] = re.source
return result
}
}
return result
}
// IsBlocked 检查域名是否被屏蔽(保留原有方法以保持兼容性)
func (m *ShieldManager) IsBlocked(domain string) bool {
details := m.CheckDomainBlockDetails(domain)
return details["blocked"].(bool)
}
// RecordBlockedDomain 记录被屏蔽的域名
func (m *ShieldManager) RecordBlockedDomain(domain string) {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
m.blockedDomainsCount[domain]++
}
// RecordResolvedDomain 记录被解析的域名
func (m *ShieldManager) RecordResolvedDomain(domain string) {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
m.resolvedDomainsCount[domain]++
}
// GetTopBlockedDomains 获取最常被屏蔽的域名
func (m *ShieldManager) GetTopBlockedDomains(limit int) []map[string]interface{} {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
// 如果没有数据,返回空数组
if len(m.blockedDomainsCount) == 0 {
return []map[string]interface{}{}
}
// 转换为切片以便排序
type domainCount struct {
Domain string
Count int
}
var domains []domainCount
for domain, count := range m.blockedDomainsCount {
domains = append(domains, domainCount{Domain: domain, Count: count})
}
// 按计数降序排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].Count > domains[j].Count
})
// 限制返回数量
if len(domains) > limit {
domains = domains[:limit]
}
// 转换为API响应格式
result := make([]map[string]interface{}, len(domains))
for i, item := range domains {
result[i] = map[string]interface{}{
"domain": item.Domain,
"count": item.Count,
}
}
return result
}
// GetTopResolvedDomains 获取最常被解析的域名
func (m *ShieldManager) GetTopResolvedDomains(limit int) []map[string]interface{} {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
// 如果没有数据,返回空数组
if len(m.resolvedDomainsCount) == 0 {
return []map[string]interface{}{}
}
// 转换为切片以便排序
type domainCount struct {
Domain string
Count int
}
var domains []domainCount
for domain, count := range m.resolvedDomainsCount {
domains = append(domains, domainCount{Domain: domain, Count: count})
}
// 按计数降序排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].Count > domains[j].Count
})
// 限制返回数量
if len(domains) > limit {
domains = domains[:limit]
}
// 转换为API响应格式
result := make([]map[string]interface{}, len(domains))
for i, item := range domains {
result[i] = map[string]interface{}{
"domain": item.Domain,
"count": item.Count,
}
}
return result
}
// GetHostsIP 获取hosts文件中的IP映射
func (m *ShieldManager) GetHostsIP(domain string) (string, bool) {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
ip, exists := m.hostsMap[domain]
return ip, exists
}
// AddRule 添加屏蔽规则,用户添加的规则是本地规则
func (m *ShieldManager) AddRule(rule string) error {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
// 解析并添加规则到内存isLocal=true表示本地规则来源为"本地规则"
m.parseRule(rule, true, "本地规则")
// 持久化保存规则到文件
if m.config.LocalRulesFile != "" {
if err := m.saveRulesToFile(); err != nil {
logger.Error("保存规则到文件失败", "error", err)
return err
}
}
return nil
}
// RemoveRule 删除屏蔽规则
func (m *ShieldManager) RemoveRule(rule string) error {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
removed := false
// 清理规则,移除可能的修饰符
cleanRule := rule
// 移除规则结束符
cleanRule = strings.TrimSuffix(cleanRule, "^")
// 尝试多种可能的规则格式
formatsToTry := []string{cleanRule}
// 根据规则类型添加可能的格式变体
if strings.HasPrefix(cleanRule, "@@||") {
// 已有的排除规则格式,也尝试去掉前缀
formatsToTry = append(formatsToTry, strings.TrimPrefix(cleanRule, "@@||"))
} else if strings.HasPrefix(cleanRule, "||") {
// 已有的域名规则格式,也尝试去掉前缀
formatsToTry = append(formatsToTry, strings.TrimPrefix(cleanRule, "||"))
} else {
// 可能是裸域名,尝试添加前缀
formatsToTry = append(formatsToTry, "||"+cleanRule, "@@||"+cleanRule)
}
// 尝试所有可能的格式变体来删除规则
for _, format := range formatsToTry {
if removed {
break
}
// 尝试删除排除规则
if strings.HasPrefix(format, "@@||") {
domain := strings.TrimPrefix(format, "@@||")
if _, exists := m.domainExceptions[domain]; exists {
delete(m.domainExceptions, domain)
delete(m.domainExceptionsIsLocal, domain)
delete(m.domainExceptionsSource, domain)
removed = true
break
}
} else if strings.HasPrefix(format, "||") {
// 尝试删除域名规则
domain := strings.TrimPrefix(format, "||")
if _, exists := m.domainRules[domain]; exists {
// 删除主域名规则
delete(m.domainRules, domain)
delete(m.domainRulesIsLocal, domain)
delete(m.domainRulesSource, domain)
removed = true
break
}
} else {
// 尝试直接作为域名删除
if _, exists := m.domainRules[format]; exists {
// 删除主域名规则
delete(m.domainRules, format)
delete(m.domainRulesIsLocal, format)
delete(m.domainRulesSource, format)
removed = true
break
}
if _, exists := m.domainExceptions[format]; exists {
// 删除主排除规则
delete(m.domainExceptions, format)
delete(m.domainExceptionsIsLocal, format)
delete(m.domainExceptionsSource, format)
removed = true
break
}
}
}
// 处理正则表达式规则
if !removed && strings.HasPrefix(cleanRule, "/") && strings.HasSuffix(cleanRule, "/") {
// 检查是否在正则表达式规则中
newRegexRules := []regexRule{}
for _, re := range m.regexRules {
if re.original != rule && re.original != cleanRule {
newRegexRules = append(newRegexRules, re)
} else {
removed = true
}
}
m.regexRules = newRegexRules
// 如果没有从正则规则中找到,检查是否在正则排除规则中
if !removed {
newRegexExceptions := []regexRule{}
for _, re := range m.regexExceptions {
if re.original != rule && re.original != cleanRule {
newRegexExceptions = append(newRegexExceptions, re)
} else {
removed = true
}
}
m.regexExceptions = newRegexExceptions
}
}
// 处理通配符和URL匹配规则
if !removed && (strings.HasPrefix(cleanRule, "*") || strings.HasSuffix(cleanRule, "*") || strings.HasPrefix(cleanRule, "|") || strings.HasSuffix(cleanRule, "|")) {
// 遍历所有规则,找到匹配的规则进行删除
for domain := range m.domainRules {
if domain == cleanRule || domain == rule {
delete(m.domainRules, domain)
delete(m.domainRulesIsLocal, domain)
delete(m.domainRulesSource, domain)
removed = true
break
}
}
if !removed {
for domain := range m.domainExceptions {
if domain == cleanRule || domain == rule {
delete(m.domainExceptions, domain)
delete(m.domainExceptionsIsLocal, domain)
delete(m.domainExceptionsSource, domain)
removed = true
break
}
}
}
}
// 如果没有删除任何规则,尝试删除可能的子域名规则
if !removed {
// 解析原始规则,提取可能的主域名
originalRule := cleanRule
// 移除可能的前缀
originalRule = strings.TrimPrefix(originalRule, "@@||")
originalRule = strings.TrimPrefix(originalRule, "||")
// 检查是否有子域名规则需要删除
// 遍历所有域名规则,删除包含原始规则作为后缀的子域名规则
for domain := range m.domainRules {
if strings.HasSuffix(domain, "."+originalRule) || domain == originalRule {
delete(m.domainRules, domain)
delete(m.domainRulesIsLocal, domain)
delete(m.domainRulesSource, domain)
removed = true
}
}
// 遍历所有排除规则,删除包含原始规则作为后缀的子域名规则
for domain := range m.domainExceptions {
if strings.HasSuffix(domain, "."+originalRule) || domain == originalRule {
delete(m.domainExceptions, domain)
delete(m.domainExceptionsIsLocal, domain)
delete(m.domainExceptionsSource, domain)
removed = true
}
}
}
// 如果有规则被删除,持久化保存更改
if removed && m.config.LocalRulesFile != "" {
if err := m.saveRulesToFile(); err != nil {
logger.Error("保存规则到文件失败", "error", err)
return err
}
}
return nil
}
// StartAutoUpdate 启动自动更新
func (m *ShieldManager) StartAutoUpdate() {
if m.updateRunning {
return
}
m.updateRunning = true
go func() {
ticker := time.NewTicker(time.Duration(m.config.UpdateInterval) * time.Second)
defer ticker.Stop()
// 启动自动保存计数数据
go m.startAutoSaveStats()
for {
select {
case <-ticker.C:
logger.Info("开始自动更新规则")
if err := m.LoadRules(); err != nil {
logger.Error("自动更新规则失败", "error", err)
} else {
logger.Info("自动更新规则成功")
}
case <-m.updateCtx.Done():
// 保存计数数据
m.saveStatsData()
m.updateRunning = false
return
}
}
}()
logger.Info("规则自动更新已启动", "interval", m.config.UpdateInterval)
// 如果是首次启动,先保存一次数据确保目录存在
go m.saveStatsData()
}
// StopAutoUpdate 停止自动更新
func (m *ShieldManager) StopAutoUpdate() {
m.updateRunning = false
m.updateCancel()
// 保存计数数据
m.saveStatsData()
logger.Info("规则自动更新已停止")
}
// saveRulesToFile 保存规则到文件,只保存本地规则
func (m *ShieldManager) saveRulesToFile() error {
var rules []string
// 添加本地域名规则
for domain, isLocal := range m.domainRulesIsLocal {
if isLocal {
rules = append(rules, "||"+domain)
}
}
// 添加本地正则表达式规则
for _, re := range m.regexRules {
if re.isLocal {
rules = append(rules, re.original)
}
}
// 添加本地排除规则
for domain, isLocal := range m.domainExceptionsIsLocal {
if isLocal {
rules = append(rules, "@@||"+domain)
}
}
// 添加本地正则表达式排除规则
for _, re := range m.regexExceptions {
if re.isLocal {
rules = append(rules, re.original)
}
}
// 写入文件
content := strings.Join(rules, "\n")
return ioutil.WriteFile(m.config.LocalRulesFile, []byte(content), 0644)
}
// AddHostsEntry 添加hosts条目
func (m *ShieldManager) AddHostsEntry(ip, domain string) error {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
m.hostsMap[domain] = ip
// 持久化保存到hosts文件
if m.config.HostsFile != "" {
if err := m.saveHostsToFile(); err != nil {
logger.Error("保存hosts到文件失败", "error", err)
return err
}
}
return nil
}
// RemoveHostsEntry 删除hosts条目
func (m *ShieldManager) RemoveHostsEntry(domain string) error {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
if _, exists := m.hostsMap[domain]; exists {
delete(m.hostsMap, domain)
// 持久化保存到hosts文件
if m.config.HostsFile != "" {
if err := m.saveHostsToFile(); err != nil {
logger.Error("保存hosts到文件失败", "error", err)
return err
}
}
}
return nil
}
// saveHostsToFile 保存hosts到文件
func (m *ShieldManager) saveHostsToFile() error {
var lines []string
// 添加hosts头部注释
lines = append(lines, "# DNS Server Hosts File")
lines = append(lines, "# Generated by DNS Server")
lines = append(lines, "")
// 添加localhost条目如果不存在
if _, exists := m.hostsMap["localhost"]; !exists {
lines = append(lines, "127.0.0.1 localhost")
lines = append(lines, "::1 localhost")
lines = append(lines, "")
}
// 添加所有hosts条目
for domain, ip := range m.hostsMap {
lines = append(lines, ip+"\t"+domain)
}
// 写入文件
content := strings.Join(lines, "\n")
return ioutil.WriteFile(m.config.HostsFile, []byte(content), 0644)
}
// GetStats 获取规则统计信息
func (m *ShieldManager) GetStats() map[string]interface{} {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
return map[string]interface{}{
"domainRules": len(m.domainRules),
"domainExceptions": len(m.domainExceptions),
"regexRules": len(m.regexRules),
"regexExceptions": len(m.regexExceptions),
"hostsRules": len(m.hostsMap),
"updateInterval": m.config.UpdateInterval,
}
}
// loadStatsData 从文件加载计数数据
func (m *ShieldManager) loadStatsData() {
if m.config.StatsFile == "" {
logger.Info("Shield统计文件路径未配置跳过加载")
return
}
// 获取绝对路径以避免工作目录问题
statsFilePath, err := filepath.Abs(m.config.StatsFile)
if err != nil {
logger.Error("获取Shield统计文件绝对路径失败", "path", m.config.StatsFile, "error", err)
return
}
logger.Debug("尝试加载Shield统计数据", "file", statsFilePath)
// 检查文件是否存在
fileInfo, err := os.Stat(statsFilePath)
if err != nil {
if os.IsNotExist(err) {
logger.Info("Shield统计文件不存在将创建新文件", "file", statsFilePath)
// 初始化空的计数数据
m.rulesMutex.Lock()
m.blockedDomainsCount = make(map[string]int)
m.resolvedDomainsCount = make(map[string]int)
m.rulesMutex.Unlock()
// 尝试立即保存一个有效的空文件
m.saveStatsData()
} else {
logger.Error("检查Shield统计文件失败", "file", statsFilePath, "error", err)
}
return
}
// 检查文件大小
if fileInfo.Size() == 0 {
logger.Warn("Shield统计文件为空将重新初始化", "file", statsFilePath)
m.rulesMutex.Lock()
m.blockedDomainsCount = make(map[string]int)
m.resolvedDomainsCount = make(map[string]int)
m.rulesMutex.Unlock()
m.saveStatsData()
return
}
// 读取文件内容
data, err := ioutil.ReadFile(statsFilePath)
if err != nil {
logger.Error("读取Shield计数数据文件失败", "file", statsFilePath, "error", err)
return
}
// 检查数据长度
if len(data) == 0 {
logger.Warn("读取到的Shield统计数据为空", "file", statsFilePath)
return
}
// 尝试解析JSON
var statsData ShieldStatsData
err = json.Unmarshal(data, &statsData)
if err != nil {
// 记录更详细的错误信息包括数据前50个字符
dataSample := string(data)
if len(dataSample) > 50 {
dataSample = dataSample[:50] + "..."
}
logger.Error("解析Shield计数数据失败",
"file", statsFilePath,
"error", err,
"data_length", len(data),
"data_sample", dataSample)
// 重置为默认空数据
m.rulesMutex.Lock()
m.blockedDomainsCount = make(map[string]int)
m.resolvedDomainsCount = make(map[string]int)
m.rulesMutex.Unlock()
// 尝试保存一个有效的空文件
m.saveStatsData()
return
}
// 恢复计数数据
m.rulesMutex.Lock()
if statsData.BlockedDomainsCount != nil {
m.blockedDomainsCount = statsData.BlockedDomainsCount
} else {
m.blockedDomainsCount = make(map[string]int)
}
if statsData.ResolvedDomainsCount != nil {
m.resolvedDomainsCount = statsData.ResolvedDomainsCount
} else {
m.resolvedDomainsCount = make(map[string]int)
}
m.rulesMutex.Unlock()
logger.Info("Shield计数数据加载成功", "blocked_entries", len(m.blockedDomainsCount), "resolved_entries", len(m.resolvedDomainsCount))
}
// saveStatsData 保存计数数据到文件
func (m *ShieldManager) saveStatsData() {
if m.config.StatsFile == "" {
logger.Debug("Shield统计文件路径未配置跳过保存")
return
}
// 获取绝对路径以避免工作目录问题
statsFilePath, err := filepath.Abs(m.config.StatsFile)
if err != nil {
logger.Error("获取Shield统计文件绝对路径失败", "path", m.config.StatsFile, "error", err)
return
}
// 创建数据目录
statsDir := filepath.Dir(statsFilePath)
err = os.MkdirAll(statsDir, 0755)
if err != nil {
logger.Error("创建Shield统计数据目录失败", "dir", statsDir, "error", err)
return
}
// 收集计数数据
m.rulesMutex.RLock()
statsData := &ShieldStatsData{
BlockedDomainsCount: make(map[string]int),
ResolvedDomainsCount: make(map[string]int),
LastSaved: time.Now(),
}
// 复制数据
for k, v := range m.blockedDomainsCount {
statsData.BlockedDomainsCount[k] = v
}
for k, v := range m.resolvedDomainsCount {
statsData.ResolvedDomainsCount[k] = v
}
m.rulesMutex.RUnlock()
// 序列化数据
jsonData, err := json.MarshalIndent(statsData, "", " ")
if err != nil {
logger.Error("序列化Shield计数数据失败", "error", err)
return
}
// 使用临时文件先写入,然后重命名,避免文件损坏
tempFilePath := statsFilePath + ".tmp"
err = ioutil.WriteFile(tempFilePath, jsonData, 0644)
if err != nil {
logger.Error("写入临时Shield统计文件失败", "file", tempFilePath, "error", err)
return
}
// 原子操作重命名文件
err = os.Rename(tempFilePath, statsFilePath)
if err != nil {
logger.Error("重命名Shield统计文件失败", "temp", tempFilePath, "dest", statsFilePath, "error", err)
// 尝试清理临时文件
os.Remove(tempFilePath)
return
}
logger.Info("Shield计数数据保存成功", "file", statsFilePath, "blocked_entries", len(statsData.BlockedDomainsCount), "resolved_entries", len(statsData.ResolvedDomainsCount))
}
// startAutoSaveStats 启动计数数据自动保存功能
func (m *ShieldManager) startAutoSaveStats() {
if m.config.StatsFile == "" || m.config.StatsSaveInterval <= 0 {
return
}
ticker := time.NewTicker(time.Duration(m.config.StatsSaveInterval) * time.Second)
defer ticker.Stop()
logger.Info("启动Shield计数数据自动保存功能", "interval", m.config.StatsSaveInterval, "file", m.config.StatsFile)
// 定期保存数据
for {
select {
case <-ticker.C:
m.saveStatsData()
case <-m.updateCtx.Done():
return
}
}
}
// GetBlacklists 获取所有黑名单配置
func (m *ShieldManager) GetBlacklists() []config.BlacklistEntry {
return m.config.Blacklists
}
// UpdateBlacklist 更新黑名单配置
func (m *ShieldManager) UpdateBlacklist(blacklists []config.BlacklistEntry) {
m.config.Blacklists = blacklists
}
// GetAllHosts 获取所有hosts条目
func (m *ShieldManager) GetAllHosts() map[string]string {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
// 返回hostsMap的副本避免并发问题
hostsCopy := make(map[string]string, len(m.hostsMap))
for domain, ip := range m.hostsMap {
hostsCopy[domain] = ip
}
return hostsCopy
}
// GetHostsCount 获取hosts条目数量
func (m *ShieldManager) GetHostsCount() int {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
return len(m.hostsMap)
}
// GetLocalRules 获取仅本地规则
func (m *ShieldManager) GetLocalRules() map[string]interface{} {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
// 转换map和slice为字符串列表只包含本地规则
domainRulesList := make([]string, 0)
for domain, isLocal := range m.domainRulesIsLocal {
if isLocal && m.domainRules[domain] {
domainRulesList = append(domainRulesList, "||"+domain+"^")
}
}
domainExceptionsList := make([]string, 0)
for domain, isLocal := range m.domainExceptionsIsLocal {
if isLocal && m.domainExceptions[domain] {
domainExceptionsList = append(domainExceptionsList, "@@||"+domain+"^")
}
}
// 获取本地正则规则原始字符串
regexRulesList := make([]string, 0)
for _, re := range m.regexRules {
if re.isLocal {
regexRulesList = append(regexRulesList, re.original)
}
}
// 获取本地正则排除规则原始字符串
regexExceptionsList := make([]string, 0)
for _, re := range m.regexExceptions {
if re.isLocal {
regexExceptionsList = append(regexExceptionsList, re.original)
}
}
// 计算本地规则数量
localDomainRulesCount := 0
for _, isLocal := range m.domainRulesIsLocal {
if isLocal {
localDomainRulesCount++
}
}
localRegexRulesCount := 0
for _, re := range m.regexRules {
if re.isLocal {
localRegexRulesCount++
}
}
localRulesCount := localDomainRulesCount + localRegexRulesCount
return map[string]interface{}{
"domainRules": domainRulesList,
"domainExceptions": domainExceptionsList,
"regexRules": regexRulesList,
"regexExceptions": regexExceptionsList,
"localRulesCount": localRulesCount,
"localDomainRulesCount": localDomainRulesCount,
"localRegexRulesCount": localRegexRulesCount,
}
}
// GetRemoteRules 获取仅远程规则
func (m *ShieldManager) GetRemoteRules() map[string]interface{} {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
// 转换map和slice为字符串列表只包含远程规则
domainRulesList := make([]string, 0)
for domain, isLocal := range m.domainRulesIsLocal {
if !isLocal && m.domainRules[domain] {
domainRulesList = append(domainRulesList, "||"+domain+"^")
}
}
domainExceptionsList := make([]string, 0)
for domain, isLocal := range m.domainExceptionsIsLocal {
if !isLocal && m.domainExceptions[domain] {
domainExceptionsList = append(domainExceptionsList, "@@||"+domain+"^")
}
}
// 获取远程正则规则原始字符串
regexRulesList := make([]string, 0)
for _, re := range m.regexRules {
if !re.isLocal {
regexRulesList = append(regexRulesList, re.original)
}
}
// 获取远程正则排除规则原始字符串
regexExceptionsList := make([]string, 0)
for _, re := range m.regexExceptions {
if !re.isLocal {
regexExceptionsList = append(regexExceptionsList, re.original)
}
}
// 计算远程规则数量
remoteDomainRulesCount := 0
for _, isLocal := range m.domainRulesIsLocal {
if !isLocal {
remoteDomainRulesCount++
}
}
remoteRegexRulesCount := 0
for _, re := range m.regexRules {
if !re.isLocal {
remoteRegexRulesCount++
}
}
remoteRulesCount := remoteDomainRulesCount + remoteRegexRulesCount
return map[string]interface{}{
"domainRules": domainRulesList,
"domainExceptions": domainExceptionsList,
"regexRules": regexRulesList,
"regexExceptions": regexExceptionsList,
"remoteRulesCount": remoteRulesCount,
"remoteDomainRulesCount": remoteDomainRulesCount,
"remoteRegexRulesCount": remoteRegexRulesCount,
"blacklists": m.config.Blacklists,
}
}
// GetRules 获取所有规则
func (m *ShieldManager) GetRules() map[string]interface{} {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
// 转换map和slice为字符串列表
domainRulesList := make([]string, 0, len(m.domainRules))
for domain := range m.domainRules {
domainRulesList = append(domainRulesList, "||"+domain+"^")
}
domainExceptionsList := make([]string, 0, len(m.domainExceptions))
for domain := range m.domainExceptions {
domainExceptionsList = append(domainExceptionsList, "@@||"+domain+"^")
}
// 获取正则规则原始字符串
regexRulesList := make([]string, 0, len(m.regexRules))
for _, re := range m.regexRules {
regexRulesList = append(regexRulesList, re.original)
}
// 获取正则排除规则原始字符串
regexExceptionsList := make([]string, 0, len(m.regexExceptions))
for _, re := range m.regexExceptions {
regexExceptionsList = append(regexExceptionsList, re.original)
}
// 获取hosts规则
hostsRulesList := make([]string, 0, len(m.hostsMap))
for domain, ip := range m.hostsMap {
hostsRulesList = append(hostsRulesList, ip+"\t"+domain)
}
// 计算总规则数量
totalRulesCount := len(m.domainRules) + len(m.regexRules)
return map[string]interface{}{
"domainRules": domainRulesList,
"domainExceptions": domainExceptionsList,
"regexRules": regexRulesList,
"regexExceptions": regexExceptionsList,
"hostsRules": hostsRulesList,
"blacklists": m.config.Blacklists,
"localRulesCount": m.localRulesCount,
"remoteRulesCount": m.remoteRulesCount,
"totalRulesCount": totalRulesCount,
}
}