diff --git a/http/server.go b/http/server.go index 2b50314..5b9dfc9 100644 --- a/http/server.go +++ b/http/server.go @@ -596,9 +596,15 @@ func (s *Server) handleTopDomains(w http.ResponseWriter, r *http.Request) { func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - // 返回屏蔽规则的基本配置信息和统计数据,不返回完整规则列表 switch r.Method { case http.MethodGet: + // 检查是否需要返回完整规则列表 + if r.URL.Query().Get("all") == "true" { + // 返回完整规则数据 + rules := s.shieldManager.GetRules() + json.NewEncoder(w).Encode(rules) + return + } // 获取规则统计信息 stats := s.shieldManager.GetStats() shieldInfo := map[string]interface{}{ diff --git a/shield/manager.go b/shield/manager.go index 356cbd6..56513f1 100644 --- a/shield/manager.go +++ b/shield/manager.go @@ -19,12 +19,6 @@ import ( "dns-server/logger" ) -// regexRule 正则规则结构,包含编译后的表达式和原始字符串 -type regexRule struct { - pattern *regexp.Regexp - original string -} - // ShieldStatsData 用于持久化的Shield统计数据 type ShieldStatsData struct { BlockedDomainsCount map[string]int `json:"blockedDomainsCount"` @@ -32,45 +26,61 @@ type ShieldStatsData struct { LastSaved time.Time `json:"lastSaved"` } +// regexRule 正则规则结构,包含编译后的表达式和原始字符串 +type regexRule struct { + pattern *regexp.Regexp + original string + isLocal bool // 是否为本地规则 + source string // 规则来源 +} + // ShieldManager 屏蔽管理器 type ShieldManager struct { - config *config.ShieldConfig - domainRules map[string]bool - domainExceptions map[string]bool - regexRules []regexRule - regexExceptions []regexRule - hostsMap map[string]string - blockedDomainsCount map[string]int - resolvedDomainsCount map[string]int - rulesMutex sync.RWMutex - updateCtx context.Context - updateCancel context.CancelFunc - updateRunning bool - localRulesCount int // 本地规则数量 - remoteRulesCount int // 远程规则数量 + config *config.ShieldConfig + domainRules map[string]bool + domainExceptions map[string]bool + domainRulesIsLocal map[string]bool // 标记域名规则是否为本地规则 + domainExceptionsIsLocal map[string]bool // 标记域名排除规则是否为本地规则 + domainRulesSource map[string]string // 标记域名规则来源 + domainExceptionsSource map[string]string // 标记域名排除规则来源 + regexRules []regexRule + regexExceptions []regexRule + hostsMap map[string]string + blockedDomainsCount map[string]int + resolvedDomainsCount map[string]int + rulesMutex sync.RWMutex + updateCtx context.Context + updateCancel context.CancelFunc + updateRunning bool + localRulesCount int // 本地规则数量 + remoteRulesCount int // 远程规则数量 } // NewShieldManager 创建屏蔽管理器实例 func NewShieldManager(config *config.ShieldConfig) *ShieldManager { ctx, cancel := context.WithCancel(context.Background()) manager := &ShieldManager{ - config: config, - domainRules: make(map[string]bool), - domainExceptions: make(map[string]bool), - regexRules: []regexRule{}, - regexExceptions: []regexRule{}, - hostsMap: make(map[string]string), - blockedDomainsCount: make(map[string]int), - resolvedDomainsCount: make(map[string]int), - updateCtx: ctx, - updateCancel: cancel, - localRulesCount: 0, - remoteRulesCount: 0, + config: config, + domainRules: make(map[string]bool), + domainExceptions: make(map[string]bool), + domainRulesIsLocal: make(map[string]bool), + domainExceptionsIsLocal: make(map[string]bool), + domainRulesSource: make(map[string]string), + domainExceptionsSource: make(map[string]string), + regexRules: []regexRule{}, + regexExceptions: []regexRule{}, + hostsMap: make(map[string]string), + blockedDomainsCount: make(map[string]int), + resolvedDomainsCount: make(map[string]int), + updateCtx: ctx, + updateCancel: cancel, + localRulesCount: 0, + remoteRulesCount: 0, } - + // 加载已保存的计数数据 manager.loadStatsData() - + return manager } @@ -82,6 +92,10 @@ func (m *ShieldManager) LoadRules() error { // 清空现有规则 m.domainRules = make(map[string]bool) m.domainExceptions = make(map[string]bool) + m.domainRulesIsLocal = make(map[string]bool) + m.domainExceptionsIsLocal = make(map[string]bool) + m.domainRulesSource = make(map[string]string) + m.domainExceptionsSource = make(map[string]string) m.regexRules = []regexRule{} m.regexExceptions = []regexRule{} m.hostsMap = make(map[string]string) @@ -134,7 +148,7 @@ func (m *ShieldManager) loadLocalRules() error { if line == "" || strings.HasPrefix(line, "#") { continue } - m.parseRule(line) + m.parseRule(line, true, "本地规则") // 本地规则,isLocal=true,来源为"本地规则" } // 更新本地规则计数 @@ -187,11 +201,11 @@ func (m *ShieldManager) shouldUpdateCache(cacheFile string) bool { func (m *ShieldManager) fetchRemoteRules(url string) error { // 获取缓存文件路径 cacheFile := m.getCacheFilePath(url) - + // 尝试从缓存加载 hasLoadedFromCache := false if !m.shouldUpdateCache(cacheFile) { - if err := m.loadCachedRules(cacheFile); err == nil { + if err := m.loadCachedRules(cacheFile, url); err == nil { logger.Info("从缓存加载远程规则", "url", url) hasLoadedFromCache = true } @@ -236,14 +250,14 @@ func (m *ShieldManager) fetchRemoteRules(url string) error { if line == "" || strings.HasPrefix(line, "#") { continue } - m.parseRule(line) + m.parseRule(line, false, url) // 远程规则,isLocal=false,来源为URL } return nil } // loadCachedRules 从缓存文件加载规则 -func (m *ShieldManager) loadCachedRules(filePath string) error { +func (m *ShieldManager) loadCachedRules(filePath string, source string) error { file, err := os.Open(filePath) if err != nil { return err @@ -265,7 +279,7 @@ func (m *ShieldManager) loadCachedRules(filePath string) error { if line == "" || strings.HasPrefix(line, "#") { continue } - m.parseRule(line) + m.parseRule(line, false, source) // 远程规则,isLocal=false,来源为URL } // 更新远程规则计数 @@ -318,7 +332,7 @@ func (m *ShieldManager) loadHosts() error { } // parseRule 解析规则行 -func (m *ShieldManager) parseRule(line string) { +func (m *ShieldManager) parseRule(line string, isLocal bool, source string) { // 处理注释 if strings.HasPrefix(line, "!") || strings.HasPrefix(line, "#") || line == "" { return @@ -343,12 +357,12 @@ func (m *ShieldManager) parseRule(line string) { case strings.HasPrefix(line, "||") && strings.HasSuffix(line, "^"): // AdGuardHome域名规则格式: ||example.com^ domain := strings.TrimSuffix(strings.TrimPrefix(line, "||"), "^") - m.addDomainRule(domain, !isException) + m.addDomainRule(domain, !isException, isLocal, source) case strings.HasPrefix(line, "||"): // 精确域名匹配规则 domain := strings.TrimPrefix(line, "||") - m.addDomainRule(domain, !isException) + m.addDomainRule(domain, !isException, isLocal, source) case strings.HasPrefix(line, "*"): // 通配符规则,转换为正则表达式 @@ -356,15 +370,18 @@ func (m *ShieldManager) parseRule(line string) { pattern = "^" + pattern + "$" if re, err := regexp.Compile(pattern); err == nil { // 保存原始规则字符串 - m.addRegexRule(re, line, !isException) + m.addRegexRule(re, line, !isException, isLocal, source) } case strings.HasPrefix(line, "/") && strings.HasSuffix(line, "/"): - // 正则表达式规则 - pattern := strings.TrimPrefix(strings.TrimSuffix(line, "/"), "/") - if re, err := regexp.Compile(pattern); err == nil { + // 关键字匹配规则:/keyword/ 格式,不区分大小写,字面量匹配特殊字符 + keyword := strings.TrimPrefix(strings.TrimSuffix(line, "/"), "/") + // 转义特殊字符,确保字面量匹配 + quotedKeyword := regexp.QuoteMeta(keyword) + // 编译为不区分大小写的正则表达式,匹配域名中任意位置 + if re, err := regexp.Compile("(?i)" + quotedKeyword); err == nil { // 保存原始规则字符串 - m.addRegexRule(re, line, !isException) + m.addRegexRule(re, line, !isException, isLocal, source) } case strings.HasPrefix(line, "|") && strings.HasSuffix(line, "|"): @@ -373,7 +390,7 @@ func (m *ShieldManager) parseRule(line string) { // 将URL模式转换为正则表达式 pattern := "^" + regexp.QuoteMeta(urlPattern) + "$" if re, err := regexp.Compile(pattern); err == nil { - m.addRegexRule(re, line, !isException) + m.addRegexRule(re, line, !isException, isLocal, source) } case strings.HasPrefix(line, "|"): @@ -381,7 +398,7 @@ func (m *ShieldManager) parseRule(line string) { urlPattern := strings.TrimPrefix(line, "|") pattern := "^" + regexp.QuoteMeta(urlPattern) if re, err := regexp.Compile(pattern); err == nil { - m.addRegexRule(re, line, !isException) + m.addRegexRule(re, line, !isException, isLocal, source) } case strings.HasSuffix(line, "|"): @@ -389,12 +406,12 @@ func (m *ShieldManager) parseRule(line string) { urlPattern := strings.TrimSuffix(line, "|") pattern := regexp.QuoteMeta(urlPattern) + "$" if re, err := regexp.Compile(pattern); err == nil { - m.addRegexRule(re, line, !isException) + m.addRegexRule(re, line, !isException, isLocal, source) } default: // 默认作为普通域名规则 - m.addDomainRule(line, !isException) + m.addDomainRule(line, !isException, isLocal, source) } } @@ -418,42 +435,98 @@ func (m *ShieldManager) parseRuleOptions(optionsStr string) map[string]string { } // addDomainRule 添加域名规则,支持是否为阻止规则 -func (m *ShieldManager) addDomainRule(domain string, block bool) { +func (m *ShieldManager) addDomainRule(domain string, block bool, isLocal bool, source string) { if block { + // 如果是远程规则,检查是否已经存在本地规则,如果存在则不覆盖 + if !isLocal { + if _, exists := m.domainRulesIsLocal[domain]; exists && m.domainRulesIsLocal[domain] { + // 已经存在本地规则,不覆盖 + return + } + } m.domainRules[domain] = true + m.domainRulesIsLocal[domain] = isLocal + m.domainRulesSource[domain] = source // 添加所有子域名的匹配支持 parts := strings.Split(domain, ".") if len(parts) > 1 { // 为二级域名和顶级域名添加规则 for i := 0; i < len(parts)-1; i++ { subdomain := strings.Join(parts[i:], ".") + // 如果是远程规则,检查是否已经存在本地规则,如果存在则不覆盖 + if !isLocal { + if _, exists := m.domainRulesIsLocal[subdomain]; exists && m.domainRulesIsLocal[subdomain] { + // 已经存在本地规则,不覆盖 + continue + } + } m.domainRules[subdomain] = true + m.domainRulesIsLocal[subdomain] = isLocal + m.domainRulesSource[subdomain] = source } } } else { // 添加到排除规则 + // 如果是远程规则,检查是否已经存在本地规则,如果存在则不覆盖 + if !isLocal { + if _, exists := m.domainExceptionsIsLocal[domain]; exists && m.domainExceptionsIsLocal[domain] { + // 已经存在本地规则,不覆盖 + return + } + } m.domainExceptions[domain] = true + m.domainExceptionsIsLocal[domain] = isLocal + m.domainExceptionsSource[domain] = source // 为子域名也添加排除规则 parts := strings.Split(domain, ".") if len(parts) > 1 { for i := 0; i < len(parts)-1; i++ { subdomain := strings.Join(parts[i:], ".") + // 如果是远程规则,检查是否已经存在本地规则,如果存在则不覆盖 + if !isLocal { + if _, exists := m.domainExceptionsIsLocal[subdomain]; exists && m.domainExceptionsIsLocal[subdomain] { + // 已经存在本地规则,不覆盖 + continue + } + } m.domainExceptions[subdomain] = true + m.domainExceptionsIsLocal[subdomain] = isLocal + m.domainExceptionsSource[subdomain] = source } } } } // addRegexRule 添加正则表达式规则,支持是否为阻止规则 -func (m *ShieldManager) addRegexRule(re *regexp.Regexp, original string, block bool) { +func (m *ShieldManager) addRegexRule(re *regexp.Regexp, original string, block bool, isLocal bool, source string) { rule := regexRule{ pattern: re, original: original, + isLocal: isLocal, + source: source, } if block { + // 如果是远程规则,检查是否已经存在相同的本地规则,如果存在则不添加 + if !isLocal { + for _, existingRule := range m.regexRules { + if existingRule.original == original && existingRule.isLocal { + // 已经存在相同的本地规则,不添加 + return + } + } + } m.regexRules = append(m.regexRules, rule) } else { // 添加到排除规则 + // 如果是远程规则,检查是否已经存在相同的本地规则,如果存在则不添加 + if !isLocal { + for _, existingRule := range m.regexExceptions { + if existingRule.original == original && existingRule.isLocal { + // 已经存在相同的本地规则,不添加 + return + } + } + } m.regexExceptions = append(m.regexExceptions, rule) } } @@ -471,15 +544,16 @@ func (m *ShieldManager) CheckDomainBlockDetails(domain string) map[string]interf } result := map[string]interface{}{ - "domain": domain, - "blocked": false, - "blockRule": "", - "blockRuleType": "", - "excluded": false, - "excludeRule": "", + "domain": domain, + "blocked": false, + "blockRule": "", + "blockRuleType": "", + "blocksource": "", + "excluded": false, + "excludeRule": "", "excludeRuleType": "", - "hasHosts": false, - "hostsIP": "", + "hasHosts": false, + "hostsIP": "", } // 检查hosts记录 @@ -493,6 +567,7 @@ func (m *ShieldManager) CheckDomainBlockDetails(domain string) map[string]interf result["excluded"] = true result["excludeRule"] = domain result["excludeRuleType"] = "exact_domain" + result["blocksource"] = m.domainExceptionsSource[domain] return result } @@ -504,6 +579,7 @@ func (m *ShieldManager) CheckDomainBlockDetails(domain string) map[string]interf result["excluded"] = true result["excludeRule"] = subdomain result["excludeRuleType"] = "subdomain" + result["blocksource"] = m.domainExceptionsSource[subdomain] return result } } @@ -514,6 +590,7 @@ func (m *ShieldManager) CheckDomainBlockDetails(domain string) map[string]interf result["excluded"] = true result["excludeRule"] = re.original result["excludeRuleType"] = "regex" + result["blocksource"] = re.source return result } } @@ -524,6 +601,7 @@ func (m *ShieldManager) CheckDomainBlockDetails(domain string) map[string]interf result["blocked"] = true result["blockRule"] = domain result["blockRuleType"] = "exact_domain" + result["blocksource"] = m.domainRulesSource[domain] return result } @@ -535,6 +613,7 @@ func (m *ShieldManager) CheckDomainBlockDetails(domain string) map[string]interf result["blocked"] = true result["blockRule"] = subdomain result["blockRuleType"] = "subdomain" + result["blocksource"] = m.domainRulesSource[subdomain] return result } } @@ -545,6 +624,7 @@ func (m *ShieldManager) CheckDomainBlockDetails(domain string) map[string]interf result["blocked"] = true result["blockRule"] = re.original result["blockRuleType"] = "regex" + result["blocksource"] = re.source return result } } @@ -667,13 +747,13 @@ func (m *ShieldManager) GetHostsIP(domain string) (string, bool) { return ip, exists } -// AddRule 添加屏蔽规则 +// AddRule 添加屏蔽规则,用户添加的规则是本地规则 func (m *ShieldManager) AddRule(rule string) error { m.rulesMutex.Lock() defer m.rulesMutex.Unlock() - // 解析并添加规则到内存 - m.parseRule(rule) + // 解析并添加规则到内存,isLocal=true表示本地规则,来源为"本地规则" + m.parseRule(rule, true, "本地规则") // 持久化保存规则到文件 if m.config.LocalRulesFile != "" { @@ -843,7 +923,7 @@ func (m *ShieldManager) StartAutoUpdate() { } } }() - + logger.Info("规则自动更新已启动", "interval", m.config.UpdateInterval) // 如果是首次启动,先保存一次数据确保目录存在 @@ -859,28 +939,36 @@ func (m *ShieldManager) StopAutoUpdate() { logger.Info("规则自动更新已停止") } -// saveRulesToFile 保存规则到文件 +// saveRulesToFile 保存规则到文件,只保存本地规则 func (m *ShieldManager) saveRulesToFile() error { var rules []string - // 添加域名规则 - for domain := range m.domainRules { - rules = append(rules, "||"+domain) + // 添加本地域名规则 + for domain, isLocal := range m.domainRulesIsLocal { + if isLocal { + rules = append(rules, "||"+domain) + } } - // 添加正则表达式规则 + // 添加本地正则表达式规则 for _, re := range m.regexRules { - rules = append(rules, re.original) + if re.isLocal { + rules = append(rules, re.original) + } } - // 添加排除规则 - for domain := range m.domainExceptions { - rules = append(rules, "@@||"+domain) + // 添加本地排除规则 + for domain, isLocal := range m.domainExceptionsIsLocal { + if isLocal { + rules = append(rules, "@@||"+domain) + } } - // 添加正则表达式排除规则 + // 添加本地正则表达式排除规则 for _, re := range m.regexExceptions { - rules = append(rules, re.original) + if re.isLocal { + rules = append(rules, re.original) + } } // 写入文件 @@ -1033,18 +1121,18 @@ func (m *ShieldManager) loadStatsData() { if len(dataSample) > 50 { dataSample = dataSample[:50] + "..." } - logger.Error("解析Shield计数数据失败", - "file", statsFilePath, - "error", err, + logger.Error("解析Shield计数数据失败", + "file", statsFilePath, + "error", err, "data_length", len(data), "data_sample", dataSample) - + // 重置为默认空数据 m.rulesMutex.Lock() m.blockedDomainsCount = make(map[string]int) m.resolvedDomainsCount = make(map[string]int) m.rulesMutex.Unlock() - + // 尝试保存一个有效的空文件 m.saveStatsData() return diff --git a/shield/rule_test.go b/shield/rule_test.go new file mode 100644 index 0000000..ce10eaa --- /dev/null +++ b/shield/rule_test.go @@ -0,0 +1,62 @@ +package shield + +import ( + "testing" + + "dns-server/config" +) + +func TestRuleParsing(t *testing.T) { + // 创建一个简单的配置 + cfg := &config.ShieldConfig{ + LocalRulesFile: "", + RemoteRulesCacheDir: ".", + UpdateInterval: 3600, + StatsFile: "", + StatsSaveInterval: 300, + HostsFile: "", + Blacklists: []config.BlacklistEntry{}, + } + + // 测试规则 + testCases := []struct { + rule string + domain string + blocked bool + desc string + }{ + // 测试关键字匹配规则 + {"/ad.qq.com/", "ad.qq.com", true, "精确匹配"}, + {"/ad.qq.com/", "sub.ad.qq.com", true, "子域名包含匹配"}, + {"/ad/", "ad.example.com", true, "开头匹配"}, + {"/ad/", "example.ad.com", true, "中间匹配"}, + {"/ad/", "example.com.ad", true, "结尾匹配"}, + {"/AD/", "ad.example.com", true, "不区分大小写匹配"}, + {"/example.com/", "example.com", true, "特殊字符转义匹配"}, + {"/ad/", "example.com", false, "不包含关键字,不应匹配"}, + {"/test/", "example.com", false, "不同关键字,不应匹配"}, + + // 测试排除规则 + {"@@/ad/", "ad.example.com", false, "排除规则,不应匹配"}, + } + + // 运行测试 + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + // 为每个测试用例创建一个新的屏蔽管理器实例 + manager := NewShieldManager(cfg) + + // 添加规则 + manager.AddRule(tc.rule) + + // 检查域名是否被屏蔽 + result := manager.CheckDomainBlockDetails(tc.domain) + blocked := result["blocked"].(bool) + + // 验证结果 + if blocked != tc.blocked { + t.Errorf("Rule %q: Domain %q expected %t, got %t", tc.rule, tc.domain, tc.blocked, blocked) + } + }) + } +} diff --git a/static/index.html b/static/index.html index ced944e..1b08f82 100644 --- a/static/index.html +++ b/static/index.html @@ -687,23 +687,15 @@
-
-

本地规则管理

- -
+

本地规则管理

-