package shield import ( "bufio" "context" "encoding/json" "fmt" "io/ioutil" "net/http" "os" "path/filepath" "regexp" "sort" "strings" "sync" "time" "dns-server/config" "dns-server/logger" ) // ShieldStatsData 用于持久化的Shield统计数据 type ShieldStatsData struct { BlockedDomainsCount map[string]int `json:"blockedDomainsCount"` ResolvedDomainsCount map[string]int `json:"resolvedDomainsCount"` LastSaved time.Time `json:"lastSaved"` } // regexRule 正则规则结构,包含编译后的表达式和原始字符串 type regexRule struct { pattern *regexp.Regexp original string isLocal bool // 是否为本地规则 source string // 规则来源 } // ShieldManager 屏蔽管理器 type ShieldManager struct { config *config.ShieldConfig domainRules map[string]bool domainExceptions map[string]bool domainRulesIsLocal map[string]bool // 标记域名规则是否为本地规则 domainExceptionsIsLocal map[string]bool // 标记域名排除规则是否为本地规则 domainRulesSource map[string]string // 标记域名规则来源 domainExceptionsSource map[string]string // 标记域名排除规则来源 regexRules []regexRule regexExceptions []regexRule hostsMap map[string]string blockedDomainsCount map[string]int resolvedDomainsCount map[string]int rulesMutex sync.RWMutex updateCtx context.Context updateCancel context.CancelFunc updateRunning bool localRulesCount int // 本地规则数量 remoteRulesCount int // 远程规则数量 } // NewShieldManager 创建屏蔽管理器实例 func NewShieldManager(config *config.ShieldConfig) *ShieldManager { ctx, cancel := context.WithCancel(context.Background()) manager := &ShieldManager{ config: config, domainRules: make(map[string]bool), domainExceptions: make(map[string]bool), domainRulesIsLocal: make(map[string]bool), domainExceptionsIsLocal: make(map[string]bool), domainRulesSource: make(map[string]string), domainExceptionsSource: make(map[string]string), regexRules: []regexRule{}, regexExceptions: []regexRule{}, hostsMap: make(map[string]string), blockedDomainsCount: make(map[string]int), resolvedDomainsCount: make(map[string]int), updateCtx: ctx, updateCancel: cancel, localRulesCount: 0, remoteRulesCount: 0, } // 加载已保存的计数数据 manager.loadStatsData() return manager } // LoadRules 加载屏蔽规则 func (m *ShieldManager) LoadRules() error { m.rulesMutex.Lock() defer m.rulesMutex.Unlock() // 清空现有规则 m.domainRules = make(map[string]bool) m.domainExceptions = make(map[string]bool) m.domainRulesIsLocal = make(map[string]bool) m.domainExceptionsIsLocal = make(map[string]bool) m.domainRulesSource = make(map[string]string) m.domainExceptionsSource = make(map[string]string) m.regexRules = []regexRule{} m.regexExceptions = []regexRule{} m.hostsMap = make(map[string]string) m.localRulesCount = 0 m.remoteRulesCount = 0 // 保留计数数据,不随规则重新加载而清空 // 加载本地规则文件 if err := m.loadLocalRules(); err != nil { logger.Error("加载本地规则失败", "error", err) // 继续执行,不返回错误 } // 加载远程规则 if err := m.loadRemoteRules(); err != nil { logger.Error("加载远程规则失败", "error", err) // 继续执行,不返回错误 } // 加载hosts文件 if err := m.loadHosts(); err != nil { logger.Error("加载hosts文件失败", "error", err) // 继续执行,不返回错误 } logger.Info(fmt.Sprintf("规则加载完成,域名规则: %d, 排除规则: %d, 正则规则: %d, hosts规则: %d", len(m.domainRules), len(m.domainExceptions), len(m.regexRules), len(m.hostsMap))) return nil } // loadLocalRules 加载本地规则文件 func (m *ShieldManager) loadLocalRules() error { if m.config.LocalRulesFile == "" { return nil } file, err := os.Open(m.config.LocalRulesFile) if err != nil { return err } defer file.Close() // 记录加载前的规则数量,用于计算本地规则数量 beforeDomainRules := len(m.domainRules) beforeRegexRules := len(m.regexRules) scanner := bufio.NewScanner(file) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if line == "" || strings.HasPrefix(line, "#") { continue } m.parseRule(line, true, "本地规则") // 本地规则,isLocal=true,来源为"本地规则" } // 更新本地规则计数 m.localRulesCount = (len(m.domainRules) - beforeDomainRules) + (len(m.regexRules) - beforeRegexRules) return scanner.Err() } // loadRemoteRules 加载远程规则 func (m *ShieldManager) loadRemoteRules() error { for _, blacklist := range m.config.Blacklists { if blacklist.Enabled { if err := m.fetchRemoteRules(blacklist.URL); err != nil { logger.Error("获取远程规则失败", "url", blacklist.URL, "error", err) continue } } } return nil } // getCacheFilePath 根据URL生成缓存文件路径 func (m *ShieldManager) getCacheFilePath(url string) string { // 使用URL的哈希值作为文件名,避免文件名冲突 hash := fmt.Sprintf("%x", url) // 简单处理,移除特殊字符,确保文件名合法 hash = strings.ReplaceAll(hash, "/", "_") hash = strings.ReplaceAll(hash, "\\", "_") return filepath.Join(m.config.RemoteRulesCacheDir, hash+".rules") } // shouldUpdateCache 检查缓存是否需要更新 func (m *ShieldManager) shouldUpdateCache(cacheFile string) bool { // 检查文件是否存在 if _, err := os.Stat(cacheFile); os.IsNotExist(err) { return true } // 检查文件修改时间 fileInfo, err := os.Stat(cacheFile) if err != nil { return true } // 如果缓存文件超过更新间隔时间,则需要更新 return time.Since(fileInfo.ModTime()) > time.Duration(m.config.UpdateInterval)*time.Second } // fetchRemoteRules 从远程URL获取规则 func (m *ShieldManager) fetchRemoteRules(url string) error { // 获取缓存文件路径 cacheFile := m.getCacheFilePath(url) // 尝试从缓存加载 hasLoadedFromCache := false if !m.shouldUpdateCache(cacheFile) { if err := m.loadCachedRules(cacheFile, url); err == nil { logger.Info("从缓存加载远程规则", "url", url) hasLoadedFromCache = true } } // 从远程获取规则 resp, err := http.Get(url) if err != nil { // 如果从远程获取失败,但已经从缓存加载成功,则返回nil if hasLoadedFromCache { logger.Warn("远程规则更新失败,使用缓存版本", "url", url, "error", err) return nil } return err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { // 如果状态码不正确,但已经从缓存加载成功,则返回nil if hasLoadedFromCache { logger.Warn("远程规则更新失败,使用缓存版本", "url", url, "statusCode", resp.StatusCode) return nil } return fmt.Errorf("远程服务器返回错误状态码: %d", resp.StatusCode) } body, err := ioutil.ReadAll(resp.Body) if err != nil { return err } // 保存规则到缓存 if err := m.saveRemoteRulesToCache(cacheFile, body); err != nil { logger.Warn("保存远程规则缓存失败", "url", url, "error", err) // 继续处理,不返回错误 } // 解析并加载规则 lines := strings.Split(string(body), "\n") for _, line := range lines { line = strings.TrimSpace(line) if line == "" || strings.HasPrefix(line, "#") { continue } m.parseRule(line, false, url) // 远程规则,isLocal=false,来源为URL } return nil } // loadCachedRules 从缓存文件加载规则 func (m *ShieldManager) loadCachedRules(filePath string, source string) error { file, err := os.Open(filePath) if err != nil { return err } defer file.Close() // 记录加载前的规则数量,用于计算远程规则数量 beforeDomainRules := len(m.domainRules) beforeRegexRules := len(m.regexRules) body, err := ioutil.ReadAll(file) if err != nil { return err } lines := strings.Split(string(body), "\n") for _, line := range lines { line = strings.TrimSpace(line) if line == "" || strings.HasPrefix(line, "#") { continue } m.parseRule(line, false, source) // 远程规则,isLocal=false,来源为URL } // 更新远程规则计数 remoteRulesAdded := (len(m.domainRules) - beforeDomainRules) + (len(m.regexRules) - beforeRegexRules) m.remoteRulesCount += remoteRulesAdded return nil } // saveRemoteRulesToCache 保存远程规则到缓存文件 func (m *ShieldManager) saveRemoteRulesToCache(filePath string, data []byte) error { // 确保缓存目录存在 if err := os.MkdirAll(m.config.RemoteRulesCacheDir, 0755); err != nil { return err } // 写入文件 return ioutil.WriteFile(filePath, data, 0644) } // loadHosts 加载hosts文件 func (m *ShieldManager) loadHosts() error { if m.config.HostsFile == "" { return nil } file, err := os.Open(m.config.HostsFile) if err != nil { return err } defer file.Close() scanner := bufio.NewScanner(file) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if line == "" || strings.HasPrefix(line, "#") { continue } parts := strings.Fields(line) if len(parts) >= 2 { ip := parts[0] for i := 1; i < len(parts); i++ { m.hostsMap[parts[i]] = ip } } } return scanner.Err() } // parseRule 解析规则行 func (m *ShieldManager) parseRule(line string, isLocal bool, source string) { // 处理注释 if strings.HasPrefix(line, "!") || strings.HasPrefix(line, "#") || line == "" { return } // 移除规则选项部分(暂时不处理规则选项) if strings.Contains(line, "$") { parts := strings.SplitN(line, "$", 2) line = parts[0] // 规则选项暂时不处理 } // 处理排除规则 (@@前缀表示取消屏蔽) isException := false if strings.HasPrefix(line, "@@") { isException = true line = strings.TrimPrefix(line, "@@") } // 处理不同类型的规则 switch { case strings.HasPrefix(line, "||") && strings.HasSuffix(line, "^"): // AdGuardHome域名规则格式: ||example.com^ domain := strings.TrimSuffix(strings.TrimPrefix(line, "||"), "^") m.addDomainRule(domain, !isException, isLocal, source) case strings.HasPrefix(line, "||"): // 精确域名匹配规则 domain := strings.TrimPrefix(line, "||") m.addDomainRule(domain, !isException, isLocal, source) case strings.HasPrefix(line, "*"): // 通配符规则,转换为正则表达式 pattern := strings.ReplaceAll(line, "*", ".*") pattern = "^" + pattern + "$" if re, err := regexp.Compile(pattern); err == nil { // 保存原始规则字符串 m.addRegexRule(re, line, !isException, isLocal, source) } case strings.HasPrefix(line, "/") && strings.HasSuffix(line, "/"): // 关键字匹配规则:/keyword/ 格式,不区分大小写,字面量匹配特殊字符 keyword := strings.TrimPrefix(strings.TrimSuffix(line, "/"), "/") // 转义特殊字符,确保字面量匹配 quotedKeyword := regexp.QuoteMeta(keyword) // 编译为不区分大小写的正则表达式,匹配域名中任意位置 if re, err := regexp.Compile("(?i)" + quotedKeyword); err == nil { // 保存原始规则字符串 m.addRegexRule(re, line, !isException, isLocal, source) } case strings.HasPrefix(line, "|") && strings.HasSuffix(line, "|"): // 完整URL匹配规则 urlPattern := strings.TrimPrefix(strings.TrimSuffix(line, "|"), "|") // 将URL模式转换为正则表达式 pattern := "^" + regexp.QuoteMeta(urlPattern) + "$" if re, err := regexp.Compile(pattern); err == nil { m.addRegexRule(re, line, !isException, isLocal, source) } case strings.HasPrefix(line, "|"): // URL开头匹配规则 urlPattern := strings.TrimPrefix(line, "|") pattern := "^" + regexp.QuoteMeta(urlPattern) if re, err := regexp.Compile(pattern); err == nil { m.addRegexRule(re, line, !isException, isLocal, source) } case strings.HasSuffix(line, "|"): // URL结尾匹配规则 urlPattern := strings.TrimSuffix(line, "|") pattern := regexp.QuoteMeta(urlPattern) + "$" if re, err := regexp.Compile(pattern); err == nil { m.addRegexRule(re, line, !isException, isLocal, source) } default: // 默认作为普通域名规则 m.addDomainRule(line, !isException, isLocal, source) } } // parseRuleOptions 解析规则选项 func (m *ShieldManager) parseRuleOptions(optionsStr string) map[string]string { options := make(map[string]string) parts := strings.Split(optionsStr, ",") for _, part := range parts { part = strings.TrimSpace(part) if part == "" { continue } if strings.Contains(part, "=") { kv := strings.SplitN(part, "=", 2) options[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1]) } else { options[part] = "" } } return options } // addDomainRule 添加域名规则,支持是否为阻止规则 func (m *ShieldManager) addDomainRule(domain string, block bool, isLocal bool, source string) { if block { // 如果是远程规则,检查是否已经存在本地规则,如果存在则不覆盖 if !isLocal { if _, exists := m.domainRulesIsLocal[domain]; exists && m.domainRulesIsLocal[domain] { // 已经存在本地规则,不覆盖 return } } m.domainRules[domain] = true m.domainRulesIsLocal[domain] = isLocal m.domainRulesSource[domain] = source // 添加所有子域名的匹配支持 parts := strings.Split(domain, ".") if len(parts) > 1 { // 为二级域名和顶级域名添加规则 for i := 0; i < len(parts)-1; i++ { subdomain := strings.Join(parts[i:], ".") // 如果是远程规则,检查是否已经存在本地规则,如果存在则不覆盖 if !isLocal { if _, exists := m.domainRulesIsLocal[subdomain]; exists && m.domainRulesIsLocal[subdomain] { // 已经存在本地规则,不覆盖 continue } } m.domainRules[subdomain] = true m.domainRulesIsLocal[subdomain] = isLocal m.domainRulesSource[subdomain] = source } } } else { // 添加到排除规则 // 如果是远程规则,检查是否已经存在本地规则,如果存在则不覆盖 if !isLocal { if _, exists := m.domainExceptionsIsLocal[domain]; exists && m.domainExceptionsIsLocal[domain] { // 已经存在本地规则,不覆盖 return } } m.domainExceptions[domain] = true m.domainExceptionsIsLocal[domain] = isLocal m.domainExceptionsSource[domain] = source // 为子域名也添加排除规则 parts := strings.Split(domain, ".") if len(parts) > 1 { for i := 0; i < len(parts)-1; i++ { subdomain := strings.Join(parts[i:], ".") // 如果是远程规则,检查是否已经存在本地规则,如果存在则不覆盖 if !isLocal { if _, exists := m.domainExceptionsIsLocal[subdomain]; exists && m.domainExceptionsIsLocal[subdomain] { // 已经存在本地规则,不覆盖 continue } } m.domainExceptions[subdomain] = true m.domainExceptionsIsLocal[subdomain] = isLocal m.domainExceptionsSource[subdomain] = source } } } } // addRegexRule 添加正则表达式规则,支持是否为阻止规则 func (m *ShieldManager) addRegexRule(re *regexp.Regexp, original string, block bool, isLocal bool, source string) { rule := regexRule{ pattern: re, original: original, isLocal: isLocal, source: source, } if block { // 如果是远程规则,检查是否已经存在相同的本地规则,如果存在则不添加 if !isLocal { for _, existingRule := range m.regexRules { if existingRule.original == original && existingRule.isLocal { // 已经存在相同的本地规则,不添加 return } } } m.regexRules = append(m.regexRules, rule) } else { // 添加到排除规则 // 如果是远程规则,检查是否已经存在相同的本地规则,如果存在则不添加 if !isLocal { for _, existingRule := range m.regexExceptions { if existingRule.original == original && existingRule.isLocal { // 已经存在相同的本地规则,不添加 return } } } m.regexExceptions = append(m.regexExceptions, rule) } } // IsBlocked 检查域名是否被屏蔽 // CheckDomainBlockDetails 检查域名是否被屏蔽,并返回详细信息 func (m *ShieldManager) CheckDomainBlockDetails(domain string) map[string]interface{} { m.rulesMutex.RLock() defer m.rulesMutex.RUnlock() // 预处理域名,去除可能的端口号 if strings.Contains(domain, ":") { parts := strings.Split(domain, ":") domain = parts[0] } result := map[string]interface{}{ "domain": domain, "blocked": false, "blockRule": "", "blockRuleType": "", "blocksource": "", "excluded": false, "excludeRule": "", "excludeRuleType": "", "hasHosts": false, "hostsIP": "", } // 检查hosts记录 hostsIP, hasHosts := m.GetHostsIP(domain) result["hasHosts"] = hasHosts result["hostsIP"] = hostsIP // 检查排除规则(优先级最高) // 检查域名排除规则 if m.domainExceptions[domain] { result["excluded"] = true result["excludeRule"] = domain result["excludeRuleType"] = "exact_domain" result["blocksource"] = m.domainExceptionsSource[domain] return result } // 检查子域名排除规则 parts := strings.Split(domain, ".") for i := 0; i < len(parts)-1; i++ { subdomain := strings.Join(parts[i:], ".") if m.domainExceptions[subdomain] { result["excluded"] = true result["excludeRule"] = subdomain result["excludeRuleType"] = "subdomain" result["blocksource"] = m.domainExceptionsSource[subdomain] return result } } // 检查正则表达式排除规则 for _, re := range m.regexExceptions { if re.pattern.MatchString(domain) { result["excluded"] = true result["excludeRule"] = re.original result["excludeRuleType"] = "regex" result["blocksource"] = re.source return result } } // 检查阻止规则 // 检查精确域名匹配 if m.domainRules[domain] { result["blocked"] = true result["blockRule"] = domain result["blockRuleType"] = "exact_domain" result["blocksource"] = m.domainRulesSource[domain] return result } // 检查子域名匹配(AdGuardHome风格) // 从最长的子域名开始匹配,确保优先级正确 for i := 0; i < len(parts)-1; i++ { subdomain := strings.Join(parts[i:], ".") if m.domainRules[subdomain] { result["blocked"] = true result["blockRule"] = subdomain result["blockRuleType"] = "subdomain" result["blocksource"] = m.domainRulesSource[subdomain] return result } } // 检查正则表达式匹配 for _, re := range m.regexRules { if re.pattern.MatchString(domain) { result["blocked"] = true result["blockRule"] = re.original result["blockRuleType"] = "regex" result["blocksource"] = re.source return result } } return result } // IsBlocked 检查域名是否被屏蔽(保留原有方法以保持兼容性) func (m *ShieldManager) IsBlocked(domain string) bool { details := m.CheckDomainBlockDetails(domain) return details["blocked"].(bool) } // RecordBlockedDomain 记录被屏蔽的域名 func (m *ShieldManager) RecordBlockedDomain(domain string) { m.rulesMutex.Lock() defer m.rulesMutex.Unlock() m.blockedDomainsCount[domain]++ } // RecordResolvedDomain 记录被解析的域名 func (m *ShieldManager) RecordResolvedDomain(domain string) { m.rulesMutex.Lock() defer m.rulesMutex.Unlock() m.resolvedDomainsCount[domain]++ } // GetTopBlockedDomains 获取最常被屏蔽的域名 func (m *ShieldManager) GetTopBlockedDomains(limit int) []map[string]interface{} { m.rulesMutex.RLock() defer m.rulesMutex.RUnlock() // 如果没有数据,返回空数组 if len(m.blockedDomainsCount) == 0 { return []map[string]interface{}{} } // 转换为切片以便排序 type domainCount struct { Domain string Count int } var domains []domainCount for domain, count := range m.blockedDomainsCount { domains = append(domains, domainCount{Domain: domain, Count: count}) } // 按计数降序排序 sort.Slice(domains, func(i, j int) bool { return domains[i].Count > domains[j].Count }) // 限制返回数量 if len(domains) > limit { domains = domains[:limit] } // 转换为API响应格式 result := make([]map[string]interface{}, len(domains)) for i, item := range domains { result[i] = map[string]interface{}{ "domain": item.Domain, "count": item.Count, } } return result } // GetTopResolvedDomains 获取最常被解析的域名 func (m *ShieldManager) GetTopResolvedDomains(limit int) []map[string]interface{} { m.rulesMutex.RLock() defer m.rulesMutex.RUnlock() // 如果没有数据,返回空数组 if len(m.resolvedDomainsCount) == 0 { return []map[string]interface{}{} } // 转换为切片以便排序 type domainCount struct { Domain string Count int } var domains []domainCount for domain, count := range m.resolvedDomainsCount { domains = append(domains, domainCount{Domain: domain, Count: count}) } // 按计数降序排序 sort.Slice(domains, func(i, j int) bool { return domains[i].Count > domains[j].Count }) // 限制返回数量 if len(domains) > limit { domains = domains[:limit] } // 转换为API响应格式 result := make([]map[string]interface{}, len(domains)) for i, item := range domains { result[i] = map[string]interface{}{ "domain": item.Domain, "count": item.Count, } } return result } // GetHostsIP 获取hosts文件中的IP映射 func (m *ShieldManager) GetHostsIP(domain string) (string, bool) { m.rulesMutex.RLock() defer m.rulesMutex.RUnlock() ip, exists := m.hostsMap[domain] return ip, exists } // AddRule 添加屏蔽规则,用户添加的规则是本地规则 func (m *ShieldManager) AddRule(rule string) error { m.rulesMutex.Lock() defer m.rulesMutex.Unlock() // 解析并添加规则到内存,isLocal=true表示本地规则,来源为"本地规则" m.parseRule(rule, true, "本地规则") // 持久化保存规则到文件 if m.config.LocalRulesFile != "" { if err := m.saveRulesToFile(); err != nil { logger.Error("保存规则到文件失败", "error", err) return err } } return nil } // RemoveRule 删除屏蔽规则 func (m *ShieldManager) RemoveRule(rule string) error { m.rulesMutex.Lock() defer m.rulesMutex.Unlock() removed := false // 清理规则,移除可能的修饰符 cleanRule := rule // 移除规则结束符 cleanRule = strings.TrimSuffix(cleanRule, "^") // 尝试多种可能的规则格式 formatsToTry := []string{cleanRule} // 根据规则类型添加可能的格式变体 if strings.HasPrefix(cleanRule, "@@||") { // 已有的排除规则格式,也尝试去掉前缀 formatsToTry = append(formatsToTry, strings.TrimPrefix(cleanRule, "@@||")) } else if strings.HasPrefix(cleanRule, "||") { // 已有的域名规则格式,也尝试去掉前缀 formatsToTry = append(formatsToTry, strings.TrimPrefix(cleanRule, "||")) } else { // 可能是裸域名,尝试添加前缀 formatsToTry = append(formatsToTry, "||"+cleanRule, "@@||"+cleanRule) } // 尝试所有可能的格式变体来删除规则 for _, format := range formatsToTry { if removed { break } // 尝试删除排除规则 if strings.HasPrefix(format, "@@||") { domain := strings.TrimPrefix(format, "@@||") if _, exists := m.domainExceptions[domain]; exists { delete(m.domainExceptions, domain) removed = true break } } else if strings.HasPrefix(format, "||") { // 尝试删除域名规则 domain := strings.TrimPrefix(format, "||") if _, exists := m.domainRules[domain]; exists { delete(m.domainRules, domain) removed = true break } } else { // 尝试直接作为域名删除 if _, exists := m.domainRules[format]; exists { delete(m.domainRules, format) removed = true break } if _, exists := m.domainExceptions[format]; exists { delete(m.domainExceptions, format) removed = true break } } } // 处理正则表达式规则 if !removed && strings.HasPrefix(cleanRule, "/") && strings.HasSuffix(cleanRule, "/") { pattern := strings.TrimPrefix(strings.TrimSuffix(cleanRule, "/"), "/") // 检查是否在正则表达式规则中 newRegexRules := []regexRule{} for _, re := range m.regexRules { if re.pattern.String() != pattern { newRegexRules = append(newRegexRules, re) } else { removed = true } } m.regexRules = newRegexRules // 如果没有从正则规则中找到,检查是否在正则排除规则中 if !removed { newRegexExceptions := []regexRule{} for _, re := range m.regexExceptions { if re.pattern.String() != pattern { newRegexExceptions = append(newRegexExceptions, re) } else { removed = true } } m.regexExceptions = newRegexExceptions } } // 处理通配符和URL匹配规则 if !removed && (strings.HasPrefix(cleanRule, "*") || strings.HasSuffix(cleanRule, "*") || strings.HasPrefix(cleanRule, "|") || strings.HasSuffix(cleanRule, "|")) { // 遍历所有规则,找到匹配的规则进行删除 for domain := range m.domainRules { if domain == cleanRule || domain == rule { delete(m.domainRules, domain) removed = true break } } if !removed { for domain := range m.domainExceptions { if domain == cleanRule || domain == rule { delete(m.domainExceptions, domain) removed = true break } } } } // 如果有规则被删除,持久化保存更改 if removed && m.config.LocalRulesFile != "" { if err := m.saveRulesToFile(); err != nil { logger.Error("保存规则到文件失败", "error", err) return err } } return nil } // StartAutoUpdate 启动自动更新 func (m *ShieldManager) StartAutoUpdate() { if m.updateRunning { return } m.updateRunning = true go func() { ticker := time.NewTicker(time.Duration(m.config.UpdateInterval) * time.Second) defer ticker.Stop() // 启动自动保存计数数据 go m.startAutoSaveStats() for { select { case <-ticker.C: logger.Info("开始自动更新规则") if err := m.LoadRules(); err != nil { logger.Error("自动更新规则失败", "error", err) } else { logger.Info("自动更新规则成功") } case <-m.updateCtx.Done(): // 保存计数数据 m.saveStatsData() m.updateRunning = false return } } }() logger.Info("规则自动更新已启动", "interval", m.config.UpdateInterval) // 如果是首次启动,先保存一次数据确保目录存在 go m.saveStatsData() } // StopAutoUpdate 停止自动更新 func (m *ShieldManager) StopAutoUpdate() { m.updateRunning = false m.updateCancel() // 保存计数数据 m.saveStatsData() logger.Info("规则自动更新已停止") } // saveRulesToFile 保存规则到文件,只保存本地规则 func (m *ShieldManager) saveRulesToFile() error { var rules []string // 添加本地域名规则 for domain, isLocal := range m.domainRulesIsLocal { if isLocal { rules = append(rules, "||"+domain) } } // 添加本地正则表达式规则 for _, re := range m.regexRules { if re.isLocal { rules = append(rules, re.original) } } // 添加本地排除规则 for domain, isLocal := range m.domainExceptionsIsLocal { if isLocal { rules = append(rules, "@@||"+domain) } } // 添加本地正则表达式排除规则 for _, re := range m.regexExceptions { if re.isLocal { rules = append(rules, re.original) } } // 写入文件 content := strings.Join(rules, "\n") return ioutil.WriteFile(m.config.LocalRulesFile, []byte(content), 0644) } // AddHostsEntry 添加hosts条目 func (m *ShieldManager) AddHostsEntry(ip, domain string) error { m.rulesMutex.Lock() defer m.rulesMutex.Unlock() m.hostsMap[domain] = ip // 持久化保存到hosts文件 if m.config.HostsFile != "" { if err := m.saveHostsToFile(); err != nil { logger.Error("保存hosts到文件失败", "error", err) return err } } return nil } // RemoveHostsEntry 删除hosts条目 func (m *ShieldManager) RemoveHostsEntry(domain string) error { m.rulesMutex.Lock() defer m.rulesMutex.Unlock() if _, exists := m.hostsMap[domain]; exists { delete(m.hostsMap, domain) // 持久化保存到hosts文件 if m.config.HostsFile != "" { if err := m.saveHostsToFile(); err != nil { logger.Error("保存hosts到文件失败", "error", err) return err } } } return nil } // saveHostsToFile 保存hosts到文件 func (m *ShieldManager) saveHostsToFile() error { var lines []string // 添加hosts头部注释 lines = append(lines, "# DNS Server Hosts File") lines = append(lines, "# Generated by DNS Server") lines = append(lines, "") // 添加localhost条目(如果不存在) if _, exists := m.hostsMap["localhost"]; !exists { lines = append(lines, "127.0.0.1 localhost") lines = append(lines, "::1 localhost") lines = append(lines, "") } // 添加所有hosts条目 for domain, ip := range m.hostsMap { lines = append(lines, ip+"\t"+domain) } // 写入文件 content := strings.Join(lines, "\n") return ioutil.WriteFile(m.config.HostsFile, []byte(content), 0644) } // GetStats 获取规则统计信息 func (m *ShieldManager) GetStats() map[string]interface{} { m.rulesMutex.RLock() defer m.rulesMutex.RUnlock() return map[string]interface{}{ "domainRules": len(m.domainRules), "domainExceptions": len(m.domainExceptions), "regexRules": len(m.regexRules), "regexExceptions": len(m.regexExceptions), "hostsRules": len(m.hostsMap), "updateInterval": m.config.UpdateInterval, } } // loadStatsData 从文件加载计数数据 func (m *ShieldManager) loadStatsData() { if m.config.StatsFile == "" { logger.Info("Shield统计文件路径未配置,跳过加载") return } // 获取绝对路径以避免工作目录问题 statsFilePath, err := filepath.Abs(m.config.StatsFile) if err != nil { logger.Error("获取Shield统计文件绝对路径失败", "path", m.config.StatsFile, "error", err) return } logger.Debug("尝试加载Shield统计数据", "file", statsFilePath) // 检查文件是否存在 fileInfo, err := os.Stat(statsFilePath) if err != nil { if os.IsNotExist(err) { logger.Info("Shield统计文件不存在,将创建新文件", "file", statsFilePath) // 初始化空的计数数据 m.rulesMutex.Lock() m.blockedDomainsCount = make(map[string]int) m.resolvedDomainsCount = make(map[string]int) m.rulesMutex.Unlock() // 尝试立即保存一个有效的空文件 m.saveStatsData() } else { logger.Error("检查Shield统计文件失败", "file", statsFilePath, "error", err) } return } // 检查文件大小 if fileInfo.Size() == 0 { logger.Warn("Shield统计文件为空,将重新初始化", "file", statsFilePath) m.rulesMutex.Lock() m.blockedDomainsCount = make(map[string]int) m.resolvedDomainsCount = make(map[string]int) m.rulesMutex.Unlock() m.saveStatsData() return } // 读取文件内容 data, err := ioutil.ReadFile(statsFilePath) if err != nil { logger.Error("读取Shield计数数据文件失败", "file", statsFilePath, "error", err) return } // 检查数据长度 if len(data) == 0 { logger.Warn("读取到的Shield统计数据为空", "file", statsFilePath) return } // 尝试解析JSON var statsData ShieldStatsData err = json.Unmarshal(data, &statsData) if err != nil { // 记录更详细的错误信息,包括数据前50个字符 dataSample := string(data) if len(dataSample) > 50 { dataSample = dataSample[:50] + "..." } logger.Error("解析Shield计数数据失败", "file", statsFilePath, "error", err, "data_length", len(data), "data_sample", dataSample) // 重置为默认空数据 m.rulesMutex.Lock() m.blockedDomainsCount = make(map[string]int) m.resolvedDomainsCount = make(map[string]int) m.rulesMutex.Unlock() // 尝试保存一个有效的空文件 m.saveStatsData() return } // 恢复计数数据 m.rulesMutex.Lock() if statsData.BlockedDomainsCount != nil { m.blockedDomainsCount = statsData.BlockedDomainsCount } else { m.blockedDomainsCount = make(map[string]int) } if statsData.ResolvedDomainsCount != nil { m.resolvedDomainsCount = statsData.ResolvedDomainsCount } else { m.resolvedDomainsCount = make(map[string]int) } m.rulesMutex.Unlock() logger.Info("Shield计数数据加载成功", "blocked_entries", len(m.blockedDomainsCount), "resolved_entries", len(m.resolvedDomainsCount)) } // saveStatsData 保存计数数据到文件 func (m *ShieldManager) saveStatsData() { if m.config.StatsFile == "" { logger.Debug("Shield统计文件路径未配置,跳过保存") return } // 获取绝对路径以避免工作目录问题 statsFilePath, err := filepath.Abs(m.config.StatsFile) if err != nil { logger.Error("获取Shield统计文件绝对路径失败", "path", m.config.StatsFile, "error", err) return } // 创建数据目录 statsDir := filepath.Dir(statsFilePath) err = os.MkdirAll(statsDir, 0755) if err != nil { logger.Error("创建Shield统计数据目录失败", "dir", statsDir, "error", err) return } // 收集计数数据 m.rulesMutex.RLock() statsData := &ShieldStatsData{ BlockedDomainsCount: make(map[string]int), ResolvedDomainsCount: make(map[string]int), LastSaved: time.Now(), } // 复制数据 for k, v := range m.blockedDomainsCount { statsData.BlockedDomainsCount[k] = v } for k, v := range m.resolvedDomainsCount { statsData.ResolvedDomainsCount[k] = v } m.rulesMutex.RUnlock() // 序列化数据 jsonData, err := json.MarshalIndent(statsData, "", " ") if err != nil { logger.Error("序列化Shield计数数据失败", "error", err) return } // 使用临时文件先写入,然后重命名,避免文件损坏 tempFilePath := statsFilePath + ".tmp" err = ioutil.WriteFile(tempFilePath, jsonData, 0644) if err != nil { logger.Error("写入临时Shield统计文件失败", "file", tempFilePath, "error", err) return } // 原子操作重命名文件 err = os.Rename(tempFilePath, statsFilePath) if err != nil { logger.Error("重命名Shield统计文件失败", "temp", tempFilePath, "dest", statsFilePath, "error", err) // 尝试清理临时文件 os.Remove(tempFilePath) return } logger.Info("Shield计数数据保存成功", "file", statsFilePath, "blocked_entries", len(statsData.BlockedDomainsCount), "resolved_entries", len(statsData.ResolvedDomainsCount)) } // startAutoSaveStats 启动计数数据自动保存功能 func (m *ShieldManager) startAutoSaveStats() { if m.config.StatsFile == "" || m.config.StatsSaveInterval <= 0 { return } ticker := time.NewTicker(time.Duration(m.config.StatsSaveInterval) * time.Second) defer ticker.Stop() logger.Info("启动Shield计数数据自动保存功能", "interval", m.config.StatsSaveInterval, "file", m.config.StatsFile) // 定期保存数据 for { select { case <-ticker.C: m.saveStatsData() case <-m.updateCtx.Done(): return } } } // GetBlacklists 获取所有黑名单配置 func (m *ShieldManager) GetBlacklists() []config.BlacklistEntry { return m.config.Blacklists } // UpdateBlacklist 更新黑名单配置 func (m *ShieldManager) UpdateBlacklist(blacklists []config.BlacklistEntry) { m.config.Blacklists = blacklists } // GetAllHosts 获取所有hosts条目 func (m *ShieldManager) GetAllHosts() map[string]string { m.rulesMutex.RLock() defer m.rulesMutex.RUnlock() // 返回hostsMap的副本,避免并发问题 hostsCopy := make(map[string]string, len(m.hostsMap)) for domain, ip := range m.hostsMap { hostsCopy[domain] = ip } return hostsCopy } // GetHostsCount 获取hosts条目数量 func (m *ShieldManager) GetHostsCount() int { m.rulesMutex.RLock() defer m.rulesMutex.RUnlock() return len(m.hostsMap) } // GetRules 获取所有规则 func (m *ShieldManager) GetRules() map[string]interface{} { m.rulesMutex.RLock() defer m.rulesMutex.RUnlock() // 转换map和slice为字符串列表 domainRulesList := make([]string, 0, len(m.domainRules)) for domain := range m.domainRules { domainRulesList = append(domainRulesList, "||"+domain+"^") } domainExceptionsList := make([]string, 0, len(m.domainExceptions)) for domain := range m.domainExceptions { domainExceptionsList = append(domainExceptionsList, "@@||"+domain+"^") } // 获取正则规则原始字符串 regexRulesList := make([]string, 0, len(m.regexRules)) for _, re := range m.regexRules { regexRulesList = append(regexRulesList, re.original) } // 获取正则排除规则原始字符串 regexExceptionsList := make([]string, 0, len(m.regexExceptions)) for _, re := range m.regexExceptions { regexExceptionsList = append(regexExceptionsList, re.original) } // 获取hosts规则 hostsRulesList := make([]string, 0, len(m.hostsMap)) for domain, ip := range m.hostsMap { hostsRulesList = append(hostsRulesList, ip+"\t"+domain) } // 计算总规则数量 totalRulesCount := len(m.domainRules) + len(m.regexRules) return map[string]interface{}{ "domainRules": domainRulesList, "domainExceptions": domainExceptionsList, "regexRules": regexRulesList, "regexExceptions": regexExceptionsList, "hostsRules": hostsRulesList, "blacklists": m.config.Blacklists, "localRulesCount": m.localRulesCount, "remoteRulesCount": m.remoteRulesCount, "totalRulesCount": totalRulesCount, } }