Files
dns-server/shield/manager.go
2025-11-23 18:21:29 +08:00

720 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package shield
import (
"bufio"
"context"
"fmt"
"io/ioutil"
"net/http"
"os"
"regexp"
"strings"
"sync"
"time"
"dns-server/config"
"dns-server/logger"
)
// regexRule pairs a compiled regular expression with the rule text it was
// created from.
type regexRule struct {
// pattern is the compiled expression that names are matched against.
pattern *regexp.Regexp
// original is the rule text handed to addRegexRule; it is what gets listed
// by GetRules and written back by saveRulesToFile.
original string
}
// ShieldManager holds the blocking rules, exception rules and hosts mappings
// and answers lookups against them.
type ShieldManager struct {
// config supplies the local rules file, remote rule URLs, hosts file and
// update interval read by the methods of this type.
config *config.ShieldConfig
// domainRules is the set of blocked domain names.
domainRules map[string]bool
// domainExceptions is the set of domain names exempt from blocking.
domainExceptions map[string]bool
// regexRules are blocking rules matched by regular expression.
regexRules []regexRule
// regexExceptions are exception rules matched by regular expression.
regexExceptions []regexRule
// hostsMap maps a host name to the IP address listed for it.
hostsMap map[string]string
// rulesMutex is taken by the exported methods before they touch the rule
// sets or hostsMap.
rulesMutex sync.RWMutex
// updateCtx is cancelled through updateCancel to stop the auto-update
// goroutine; both are created in NewShieldManager.
updateCtx context.Context
updateCancel context.CancelFunc
// updateRunning is true while the goroutine started by StartAutoUpdate is
// active.
updateRunning bool
}
// NewShieldManager returns a ShieldManager bound to config with empty rule
// sets and a fresh cancellable context for the auto-update loop. No rules are
// loaded; call LoadRules for that.
func NewShieldManager(config *config.ShieldConfig) *ShieldManager {
	m := &ShieldManager{config: config}

	m.domainRules = map[string]bool{}
	m.domainExceptions = map[string]bool{}
	m.regexRules = []regexRule{}
	m.regexExceptions = []regexRule{}
	m.hostsMap = map[string]string{}

	m.updateCtx, m.updateCancel = context.WithCancel(context.Background())
	return m
}
// LoadRules rebuilds every rule set from the local rules file, the remote
// rule URLs and the hosts file.
//
// Remote lists are downloaded into a scratch manager before the write lock is
// taken, so lookups through IsBlocked and GetHostsIP are not stalled for the
// duration of the network transfers. The local files are still read while the
// lock is held, which keeps the reload atomic with respect to AddRule,
// RemoveRule and the hosts mutators that rewrite those files under the same
// lock.
//
// Loading is best-effort: failures of individual sources are logged and the
// remaining sources are still applied. The returned error is always nil.
func (m *ShieldManager) LoadRules() error {
	// The scratch manager is private to this call, so its unlocked helper
	// methods can fill it without any synchronization.
	remote := &ShieldManager{
		config:           m.config,
		domainRules:      make(map[string]bool),
		domainExceptions: make(map[string]bool),
		hostsMap:         make(map[string]string),
	}
	if err := remote.loadRemoteRules(); err != nil {
		logger.Error("加载远程规则失败", "error", err)
		// Keep going; remote rules are optional.
	}

	m.rulesMutex.Lock()
	defer m.rulesMutex.Unlock()

	// Start from empty sets.
	m.domainRules = make(map[string]bool)
	m.domainExceptions = make(map[string]bool)
	m.regexRules = []regexRule{}
	m.regexExceptions = []regexRule{}
	m.hostsMap = make(map[string]string)

	// Local rules first, so regex rules keep the local-then-remote order.
	if err := m.loadLocalRules(); err != nil {
		logger.Error("加载本地规则失败", "error", err)
		// Keep going; the other sources may still load.
	}

	// Merge what was downloaded earlier.
	for domain := range remote.domainRules {
		m.domainRules[domain] = true
	}
	for domain := range remote.domainExceptions {
		m.domainExceptions[domain] = true
	}
	m.regexRules = append(m.regexRules, remote.regexRules...)
	m.regexExceptions = append(m.regexExceptions, remote.regexExceptions...)

	if err := m.loadHosts(); err != nil {
		logger.Error("加载hosts文件失败", "error", err)
		// Keep going; a missing hosts file is not fatal.
	}

	logger.Info(fmt.Sprintf("规则加载完成,域名规则: %d, 排除规则: %d, 正则规则: %d, hosts规则: %d",
		len(m.domainRules), len(m.domainExceptions), len(m.regexRules), len(m.hostsMap)))
	return nil
}
// loadLocalRules feeds every rule line of the configured local rules file to
// parseRule. Blank lines and lines starting with "#" are skipped. It does
// nothing when no local rules file is configured and returns any error from
// opening or scanning the file. The caller is expected to hold the write lock.
func (m *ShieldManager) loadLocalRules() error {
	path := m.config.LocalRulesFile
	if path == "" {
		return nil
	}

	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()

	sc := bufio.NewScanner(f)
	for sc.Scan() {
		rule := strings.TrimSpace(sc.Text())
		if rule != "" && !strings.HasPrefix(rule, "#") {
			m.parseRule(rule)
		}
	}
	return sc.Err()
}
// loadRemoteRules downloads and parses each configured remote rule list in
// turn. A list that fails to load is logged and skipped, so the returned error
// is always nil.
func (m *ShieldManager) loadRemoteRules() error {
	for _, source := range m.config.RemoteRules {
		err := m.fetchRemoteRules(source)
		if err == nil {
			continue
		}
		logger.Error("获取远程规则失败", "url", source, "error", err)
	}
	return nil
}
// fetchRemoteRules downloads the rule list at url and feeds every rule line
// to parseRule. Blank lines and lines starting with "#" are skipped.
//
// The whole body is read before any line is parsed, so a transfer that fails
// midway contributes no rules. It returns an error when the request fails,
// the server answers with a status other than 200, or the body cannot be read.
func (m *ShieldManager) fetchRemoteRules(url string) error {
	// http.Get goes through http.DefaultClient, which never times out; a
	// stalled server would block rule loading indefinitely. The timeout covers
	// the whole exchange, including reading the body.
	client := &http.Client{Timeout: 60 * time.Second}

	resp, err := client.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("远程服务器返回错误状态码: %d", resp.StatusCode)
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return err
	}

	for _, line := range strings.Split(string(body), "\n") {
		line = strings.TrimSpace(line)
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		m.parseRule(line)
	}
	return nil
}
// loadHosts reads the configured hosts file into hostsMap. Each useful line
// has the form "IP name [name...]"; every name on the line is mapped to the
// IP, and a later line overrides an earlier one for the same name.
//
// Text from "#" to the end of a line is treated as a comment, so both whole
// comment lines and trailing comments are ignored. It does nothing when no
// hosts file is configured and returns any error from opening or scanning the
// file. The caller is expected to hold the write lock.
func (m *ShieldManager) loadHosts() error {
	if m.config.HostsFile == "" {
		return nil
	}

	file, err := os.Open(m.config.HostsFile)
	if err != nil {
		return err
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		line := scanner.Text()
		// Without this cut, "1.2.3.4 host # note" would register "#" and
		// "note" as host names.
		if i := strings.Index(line, "#"); i >= 0 {
			line = line[:i]
		}

		fields := strings.Fields(line)
		if len(fields) < 2 {
			continue
		}
		ip := fields[0]
		for _, name := range fields[1:] {
			m.hostsMap[name] = ip
		}
	}
	return scanner.Err()
}
// parseRule parses a single rule line and stores it in the matching rule set.
// The caller is expected to hold the write lock.
//
// Recognized forms, after an optional "@@" prefix that turns the rule into an
// exception:
//
//	||example.com^   domain rule
//	||example.com    domain rule
//	*pattern         wildcard, "*" matches any run of characters
//	/regex/          regular expression
//	|text|           exact match
//	|text            prefix match
//	text|            suffix match
//	anything else    plain domain rule
//
// A trailing "$options" part is dropped; options are not interpreted. Lines
// starting with "!" or "#", empty rules and expressions that fail to compile
// are ignored.
func (m *ShieldManager) parseRule(line string) {
	if line == "" || strings.HasPrefix(line, "!") || strings.HasPrefix(line, "#") {
		return
	}

	isException := strings.HasPrefix(line, "@@")
	line = strings.TrimPrefix(line, "@@")

	// Drop the "$options" suffix. A rule that is wrapped in slashes is a
	// regular expression whose "$" anchors belong to the expression, so it is
	// left intact; for "/regex/$options" only the text after the closing
	// slash is removed.
	isRegex := len(line) > 1 && strings.HasPrefix(line, "/") && strings.HasSuffix(line, "/")
	if !isRegex {
		cut := strings.Index(line, "$")
		if strings.HasPrefix(line, "/") {
			if i := strings.LastIndex(line, "/$"); i > 0 {
				cut = i + 1
			}
		}
		if cut >= 0 {
			line = line[:cut]
		}
	}
	if line == "" {
		return
	}

	block := !isException

	var pattern string
	switch {
	case strings.HasPrefix(line, "||"):
		// "||example.com^" and "||example.com" both name a domain.
		domain := strings.TrimSuffix(strings.TrimPrefix(line, "||"), "^")
		if domain != "" {
			m.addDomainRule(domain, block)
		}
		return
	case strings.HasPrefix(line, "*"):
		// Escape everything first so that "." stays a literal dot, then turn
		// each escaped "*" into "match anything".
		pattern = "^" + strings.ReplaceAll(regexp.QuoteMeta(line), `\*`, ".*") + "$"
	case strings.HasPrefix(line, "/") && strings.HasSuffix(line, "/"):
		pattern = strings.TrimPrefix(strings.TrimSuffix(line, "/"), "/")
		if pattern == "" {
			// An empty expression matches every name; never store it.
			return
		}
	case strings.HasPrefix(line, "|") && strings.HasSuffix(line, "|"):
		pattern = "^" + regexp.QuoteMeta(strings.TrimPrefix(strings.TrimSuffix(line, "|"), "|")) + "$"
	case strings.HasPrefix(line, "|"):
		pattern = "^" + regexp.QuoteMeta(strings.TrimPrefix(line, "|"))
	case strings.HasSuffix(line, "|"):
		pattern = regexp.QuoteMeta(strings.TrimSuffix(line, "|")) + "$"
	default:
		m.addDomainRule(line, block)
		return
	}

	re, err := regexp.Compile(pattern)
	if err != nil {
		// Best-effort: one malformed rule must not abort a whole list.
		return
	}
	// The rule text (without "@@" and options) is kept for listing and saving.
	m.addRegexRule(re, line, block)
}
// parseRuleOptions splits a comma-separated option string into a map. An
// entry of the form "key=value" is stored under key with surrounding
// whitespace removed from both sides; an entry without "=" is stored with an
// empty value. Empty entries are skipped.
func (m *ShieldManager) parseRuleOptions(optionsStr string) map[string]string {
	result := map[string]string{}
	for _, raw := range strings.Split(optionsStr, ",") {
		opt := strings.TrimSpace(raw)
		if opt == "" {
			continue
		}

		eq := strings.Index(opt, "=")
		if eq < 0 {
			result[opt] = ""
			continue
		}
		key := strings.TrimSpace(opt[:eq])
		result[key] = strings.TrimSpace(opt[eq+1:])
	}
	return result
}
// addDomainRule records domain in the block set when block is true and in the
// exception set otherwise. Empty names are ignored. The caller is expected to
// hold the write lock.
//
// Only the domain itself is stored. IsBlocked tests every parent suffix of a
// queried name against these sets, so one entry already covers the domain and
// all of its subdomains. Registering the parent domains here as well would
// widen the rule: blocking "ads.example.com" would then block "example.com"
// and every other name beneath it, and an exception for one host would
// whitelist its whole parent domain.
func (m *ShieldManager) addDomainRule(domain string, block bool) {
	if domain == "" {
		return
	}
	if block {
		m.domainRules[domain] = true
		return
	}
	m.domainExceptions[domain] = true
}
// addRegexRule appends a compiled expression, together with the rule text it
// came from, to the blocking rules when block is true and to the exception
// rules otherwise.
func (m *ShieldManager) addRegexRule(re *regexp.Regexp, original string, block bool) {
	target := &m.regexExceptions
	if block {
		target = &m.regexRules
	}
	*target = append(*target, regexRule{pattern: re, original: original})
}
// IsBlocked reports whether domain is blocked.
//
// Everything from the first ":" onward (for example a port) is ignored. The
// name itself and each of its parent suffixes that still contains a dot are
// looked up in the domain sets; regular-expression rules are matched against
// the full name. Exceptions are consulted first and always win: if any
// exception matches, the result is false regardless of the blocking rules.
func (m *ShieldManager) IsBlocked(domain string) bool {
	m.rulesMutex.RLock()
	defer m.rulesMutex.RUnlock()

	if colon := strings.Index(domain, ":"); colon >= 0 {
		domain = domain[:colon]
	}

	// Collect the name plus every parent suffix with at least two labels,
	// e.g. "a.b.example.com" -> "a.b.example.com", "b.example.com",
	// "example.com".
	names := []string{domain}
	rest := domain
	for {
		dot := strings.Index(rest, ".")
		if dot < 0 {
			break
		}
		rest = rest[dot+1:]
		if strings.Contains(rest, ".") {
			names = append(names, rest)
		}
	}

	matches := func(set map[string]bool, rules []regexRule) bool {
		for _, name := range names {
			if set[name] {
				return true
			}
		}
		for _, rule := range rules {
			if rule.pattern.MatchString(domain) {
				return true
			}
		}
		return false
	}

	if matches(m.domainExceptions, m.regexExceptions) {
		return false
	}
	return matches(m.domainRules, m.regexRules)
}
// GetHostsIP returns the IP address mapped to domain in the hosts table and
// whether such a mapping exists. The lookup is an exact match on the name.
func (m *ShieldManager) GetHostsIP(domain string) (string, bool) {
	m.rulesMutex.RLock()
	addr, found := m.hostsMap[domain]
	m.rulesMutex.RUnlock()
	return addr, found
}
// AddRule parses rule into the in-memory rule sets and, when a local rules
// file is configured, rewrites that file from the current rule sets. A
// failure to write the file is logged and returned; the rule remains in
// memory in that case.
func (m *ShieldManager) AddRule(rule string) error {
	m.rulesMutex.Lock()
	defer m.rulesMutex.Unlock()

	m.parseRule(rule)

	if m.config.LocalRulesFile == "" {
		return nil
	}
	err := m.saveRulesToFile()
	if err != nil {
		logger.Error("保存规则到文件失败", "error", err)
	}
	return err
}
// RemoveRule deletes the rule identified by rule from the in-memory rule sets
// and, if something was deleted and a local rules file is configured,
// rewrites that file. Removing a rule that does not exist is not an error.
//
// rule may be written the way rules are listed ("||example.com^",
// "@@||example.com^", "/regex/", "*.example.com", "|text|") or as a bare
// domain. A trailing "^" is ignored for domain rules. Lookup order:
//
//  1. Domain sets. "@@||name" tries the exceptions first and then the block
//     rules; "||name" and bare names try the block rules first and then the
//     exceptions. At most one entry is deleted.
//  2. For "/regex/", every blocking rule whose compiled expression equals the
//     text between the slashes, or failing that every such exception rule.
//  3. The rule exactly as given, including its "^", in the domain sets.
//  4. Regex-backed rules (wildcard, "|" anchored, "/regex/") whose stored
//     rule text equals rule. Exceptions are stored without their "@@" marker,
//     so it is stripped before comparing against them.
//
// The only error returned is a failure to write the rules file.
func (m *ShieldManager) RemoveRule(rule string) error {
	m.rulesMutex.Lock()
	defer m.rulesMutex.Unlock()

	cleanRule := strings.TrimSuffix(rule, "^")

	// dropDomain deletes name from set and reports whether it was present.
	dropDomain := func(set map[string]bool, name string) bool {
		if _, ok := set[name]; !ok {
			return false
		}
		delete(set, name)
		return true
	}
	// dropRegex returns list without the rules accepted by match and reports
	// whether anything was dropped.
	dropRegex := func(list []regexRule, match func(regexRule) bool) ([]regexRule, bool) {
		kept := make([]regexRule, 0, len(list))
		for _, r := range list {
			if !match(r) {
				kept = append(kept, r)
			}
		}
		return kept, len(kept) != len(list)
	}

	// 1. Domain sets.
	var removed bool
	switch {
	case strings.HasPrefix(cleanRule, "@@||"):
		name := strings.TrimPrefix(cleanRule, "@@||")
		removed = dropDomain(m.domainExceptions, name) || dropDomain(m.domainRules, name)
	case strings.HasPrefix(cleanRule, "||"):
		name := strings.TrimPrefix(cleanRule, "||")
		removed = dropDomain(m.domainRules, name) || dropDomain(m.domainExceptions, name)
	default:
		removed = dropDomain(m.domainRules, cleanRule) || dropDomain(m.domainExceptions, cleanRule)
	}

	// 2. Regular expressions, compared by compiled pattern.
	if !removed && strings.HasPrefix(cleanRule, "/") && strings.HasSuffix(cleanRule, "/") {
		pattern := strings.TrimPrefix(strings.TrimSuffix(cleanRule, "/"), "/")
		byPattern := func(r regexRule) bool { return r.pattern.String() == pattern }

		m.regexRules, removed = dropRegex(m.regexRules, byPattern)
		if !removed {
			m.regexExceptions, removed = dropRegex(m.regexExceptions, byPattern)
		}
	}

	// 3. A plain rule that was stored together with its "^".
	if !removed && rule != cleanRule {
		removed = dropDomain(m.domainRules, rule) || dropDomain(m.domainExceptions, rule)
	}

	// 4. Wildcard and "|" rules live in the regex slices, not in the domain
	// maps, so they have to be found by the rule text they were created from.
	if !removed {
		bare := strings.TrimPrefix(rule, "@@")
		bareClean := strings.TrimPrefix(cleanRule, "@@")

		m.regexRules, removed = dropRegex(m.regexRules, func(r regexRule) bool {
			return r.original == rule || r.original == cleanRule
		})
		if !removed {
			m.regexExceptions, removed = dropRegex(m.regexExceptions, func(r regexRule) bool {
				return r.original == rule || r.original == cleanRule ||
					r.original == bare || r.original == bareClean
			})
		}
	}

	if !removed || m.config.LocalRulesFile == "" {
		return nil
	}
	if err := m.saveRulesToFile(); err != nil {
		logger.Error("保存规则到文件失败", "error", err)
		return err
	}
	return nil
}
// StartAutoUpdate starts a background goroutine that calls LoadRules every
// UpdateInterval seconds until the manager's update context is cancelled (see
// StopAutoUpdate). Calling it while the goroutine is already running is a
// no-op.
//
// A non-positive interval is rejected with a log entry instead of being
// handed to time.NewTicker, which panics on such values. Because the update
// context is only created in NewShieldManager, a goroutine started after
// StopAutoUpdate exits immediately.
func (m *ShieldManager) StartAutoUpdate() {
	interval := time.Duration(m.config.UpdateInterval) * time.Second
	if interval <= 0 {
		logger.Error("自动更新间隔无效", "interval", m.config.UpdateInterval)
		return
	}

	// updateRunning is written by the goroutine below as well, so it is
	// guarded by the same mutex as the rest of the manager state.
	m.rulesMutex.Lock()
	if m.updateRunning {
		m.rulesMutex.Unlock()
		return
	}
	m.updateRunning = true
	ctx := m.updateCtx
	m.rulesMutex.Unlock()

	go func() {
		defer func() {
			m.rulesMutex.Lock()
			m.updateRunning = false
			m.rulesMutex.Unlock()
		}()

		ticker := time.NewTicker(interval)
		defer ticker.Stop()

		for {
			select {
			case <-ticker.C:
				logger.Info("开始自动更新规则")
				if err := m.LoadRules(); err != nil {
					logger.Error("自动更新规则失败", "error", err)
				} else {
					logger.Info("自动更新规则成功")
				}
			case <-ctx.Done():
				return
			}
		}
	}()
}
// StopAutoUpdate cancels the manager's update context, which makes the
// goroutine started by StartAutoUpdate return.
func (m *ShieldManager) StopAutoUpdate() {
	cancel := m.updateCancel
	cancel()
}
// saveRulesToFile overwrites the local rules file with the current in-memory
// rules, one per line: domain rules as "||domain", domain exceptions as
// "@@||domain", and regex-backed rules as their stored rule text. The caller
// is expected to hold the lock.
//
// Regex-backed exceptions are stored without their "@@" marker, so it is
// added back here. Without it they would be read back as blocking rules on
// the next load.
func (m *ShieldManager) saveRulesToFile() error {
	total := len(m.domainRules) + len(m.regexRules) + len(m.domainExceptions) + len(m.regexExceptions)
	rules := make([]string, 0, total)

	for domain := range m.domainRules {
		rules = append(rules, "||"+domain)
	}
	for _, re := range m.regexRules {
		rules = append(rules, re.original)
	}
	for domain := range m.domainExceptions {
		rules = append(rules, "@@||"+domain)
	}
	for _, re := range m.regexExceptions {
		text := re.original
		if !strings.HasPrefix(text, "@@") {
			text = "@@" + text
		}
		rules = append(rules, text)
	}

	content := strings.Join(rules, "\n")
	return ioutil.WriteFile(m.config.LocalRulesFile, []byte(content), 0644)
}
// AddHostsEntry maps domain to ip in the hosts table, replacing any existing
// mapping, and rewrites the hosts file when one is configured. A failure to
// write the file is logged and returned; the in-memory mapping is kept.
func (m *ShieldManager) AddHostsEntry(ip, domain string) error {
	m.rulesMutex.Lock()
	defer m.rulesMutex.Unlock()

	m.hostsMap[domain] = ip

	if m.config.HostsFile == "" {
		return nil
	}
	err := m.saveHostsToFile()
	if err != nil {
		logger.Error("保存hosts到文件失败", "error", err)
	}
	return err
}
// RemoveHostsEntry deletes the mapping for domain from the hosts table and
// rewrites the hosts file when one is configured. It does nothing, and
// returns nil, when domain has no mapping. A failure to write the file is
// logged and returned.
func (m *ShieldManager) RemoveHostsEntry(domain string) error {
	m.rulesMutex.Lock()
	defer m.rulesMutex.Unlock()

	_, known := m.hostsMap[domain]
	if !known {
		return nil
	}
	delete(m.hostsMap, domain)

	if m.config.HostsFile == "" {
		return nil
	}
	err := m.saveHostsToFile()
	if err != nil {
		logger.Error("保存hosts到文件失败", "error", err)
	}
	return err
}
// saveHostsToFile overwrites the hosts file with a two-line comment header,
// default IPv4 and IPv6 localhost lines when the table has no "localhost"
// entry, and then one "IP<TAB>name" line per mapping. The caller is expected
// to hold the lock.
func (m *ShieldManager) saveHostsToFile() error {
	out := []string{
		"# DNS Server Hosts File",
		"# Generated by DNS Server",
		"",
	}

	_, hasLocalhost := m.hostsMap["localhost"]
	if !hasLocalhost {
		out = append(out, "127.0.0.1 localhost", "::1 localhost", "")
	}

	for name, addr := range m.hostsMap {
		out = append(out, addr+"\t"+name)
	}

	data := []byte(strings.Join(out, "\n"))
	return ioutil.WriteFile(m.config.HostsFile, data, 0644)
}
// GetStats returns the number of entries in each rule set and in the hosts
// table, together with the configured update interval.
func (m *ShieldManager) GetStats() map[string]interface{} {
	m.rulesMutex.RLock()
	defer m.rulesMutex.RUnlock()

	stats := make(map[string]interface{}, 6)
	stats["domainRules"] = len(m.domainRules)
	stats["domainExceptions"] = len(m.domainExceptions)
	stats["regexRules"] = len(m.regexRules)
	stats["regexExceptions"] = len(m.regexExceptions)
	stats["hostsRules"] = len(m.hostsMap)
	stats["updateInterval"] = m.config.UpdateInterval
	return stats
}
// GetRules returns every rule as display text, grouped by kind. Domain rules
// are rendered as "||domain^", domain exceptions as "@@||domain^", regex
// rules and regex exceptions as their stored rule text, and hosts entries as
// "IP<TAB>name". Each group is a []string; the order within the map-backed
// groups is unspecified.
func (m *ShieldManager) GetRules() map[string]interface{} {
	m.rulesMutex.RLock()
	defer m.rulesMutex.RUnlock()

	// renderDomains formats each name in set as prefix + name + "^".
	renderDomains := func(set map[string]bool, prefix string) []string {
		out := make([]string, 0, len(set))
		for name := range set {
			out = append(out, prefix+name+"^")
		}
		return out
	}
	// renderRegex lists the stored rule text of each rule.
	renderRegex := func(rules []regexRule) []string {
		out := make([]string, 0, len(rules))
		for i := range rules {
			out = append(out, rules[i].original)
		}
		return out
	}

	hosts := make([]string, 0, len(m.hostsMap))
	for name, addr := range m.hostsMap {
		hosts = append(hosts, addr+"\t"+name)
	}

	return map[string]interface{}{
		"domainRules":      renderDomains(m.domainRules, "||"),
		"domainExceptions": renderDomains(m.domainExceptions, "@@||"),
		"regexRules":       renderRegex(m.regexRules),
		"regexExceptions":  renderRegex(m.regexExceptions),
		"hostsRules":       hosts,
	}
}