新建DNS服务器

This commit is contained in:
Alex Yang
2025-11-23 18:21:29 +08:00
commit 0072e8a5c2
15 changed files with 5372 additions and 0 deletions

719
shield/manager.go Normal file
View File

@@ -0,0 +1,719 @@
package shield
import (
"bufio"
"context"
"fmt"
"io/ioutil"
"net/http"
"os"
"regexp"
"strings"
"sync"
"time"
"dns-server/config"
"dns-server/logger"
)
// regexRule 正则规则结构,包含编译后的表达式和原始字符串
type regexRule struct {
pattern *regexp.Regexp
original string
}
// ShieldManager 屏蔽管理器
type ShieldManager struct {
config *config.ShieldConfig
domainRules map[string]bool
domainExceptions map[string]bool
regexRules []regexRule
regexExceptions []regexRule
hostsMap map[string]string
rulesMutex sync.RWMutex
updateCtx context.Context
updateCancel context.CancelFunc
updateRunning bool
}
// NewShieldManager 创建屏蔽管理器实例
func NewShieldManager(config *config.ShieldConfig) *ShieldManager {
ctx, cancel := context.WithCancel(context.Background())
return &ShieldManager{
config: config,
domainRules: make(map[string]bool),
domainExceptions: make(map[string]bool),
regexRules: []regexRule{},
regexExceptions: []regexRule{},
hostsMap: make(map[string]string),
updateCtx: ctx,
updateCancel: cancel,
}
}
// LoadRules 加载屏蔽规则
func (m *ShieldManager) LoadRules() error {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
// 清空现有规则
m.domainRules = make(map[string]bool)
m.domainExceptions = make(map[string]bool)
m.regexRules = []regexRule{}
m.regexExceptions = []regexRule{}
m.hostsMap = make(map[string]string)
// 加载本地规则文件
if err := m.loadLocalRules(); err != nil {
logger.Error("加载本地规则失败", "error", err)
// 继续执行,不返回错误
}
// 加载远程规则
if err := m.loadRemoteRules(); err != nil {
logger.Error("加载远程规则失败", "error", err)
// 继续执行,不返回错误
}
// 加载hosts文件
if err := m.loadHosts(); err != nil {
logger.Error("加载hosts文件失败", "error", err)
// 继续执行,不返回错误
}
logger.Info(fmt.Sprintf("规则加载完成,域名规则: %d, 排除规则: %d, 正则规则: %d, hosts规则: %d",
len(m.domainRules), len(m.domainExceptions), len(m.regexRules), len(m.hostsMap)))
return nil
}
// loadLocalRules 加载本地规则文件
func (m *ShieldManager) loadLocalRules() error {
if m.config.LocalRulesFile == "" {
return nil
}
file, err := os.Open(m.config.LocalRulesFile)
if err != nil {
return err
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
m.parseRule(line)
}
return scanner.Err()
}
// loadRemoteRules 加载远程规则
func (m *ShieldManager) loadRemoteRules() error {
for _, url := range m.config.RemoteRules {
if err := m.fetchRemoteRules(url); err != nil {
logger.Error("获取远程规则失败", "url", url, "error", err)
continue
}
}
return nil
}
// fetchRemoteRules 从远程URL获取规则
func (m *ShieldManager) fetchRemoteRules(url string) error {
resp, err := http.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("远程服务器返回错误状态码: %d", resp.StatusCode)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
lines := strings.Split(string(body), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
m.parseRule(line)
}
return nil
}
// loadHosts 加载hosts文件
func (m *ShieldManager) loadHosts() error {
if m.config.HostsFile == "" {
return nil
}
file, err := os.Open(m.config.HostsFile)
if err != nil {
return err
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
parts := strings.Fields(line)
if len(parts) >= 2 {
ip := parts[0]
for i := 1; i < len(parts); i++ {
m.hostsMap[parts[i]] = ip
}
}
}
return scanner.Err()
}
// parseRule 解析规则行
func (m *ShieldManager) parseRule(line string) {
// 处理注释
if strings.HasPrefix(line, "!") || strings.HasPrefix(line, "#") || line == "" {
return
}
// 移除规则选项部分(暂时不处理规则选项)
if strings.Contains(line, "$") {
parts := strings.SplitN(line, "$", 2)
line = parts[0]
// 规则选项暂时不处理
}
// 处理排除规则 (@@前缀表示取消屏蔽)
isException := false
if strings.HasPrefix(line, "@@") {
isException = true
line = strings.TrimPrefix(line, "@@")
}
// 处理不同类型的规则
switch {
case strings.HasPrefix(line, "||") && strings.HasSuffix(line, "^"):
// AdGuardHome域名规则格式: ||example.com^
domain := strings.TrimSuffix(strings.TrimPrefix(line, "||"), "^")
m.addDomainRule(domain, !isException)
case strings.HasPrefix(line, "||"):
// 精确域名匹配规则
domain := strings.TrimPrefix(line, "||")
m.addDomainRule(domain, !isException)
case strings.HasPrefix(line, "*"):
// 通配符规则,转换为正则表达式
pattern := strings.ReplaceAll(line, "*", ".*")
pattern = "^" + pattern + "$"
if re, err := regexp.Compile(pattern); err == nil {
// 保存原始规则字符串
m.addRegexRule(re, line, !isException)
}
case strings.HasPrefix(line, "/") && strings.HasSuffix(line, "/"):
// 正则表达式规则
pattern := strings.TrimPrefix(strings.TrimSuffix(line, "/"), "/")
if re, err := regexp.Compile(pattern); err == nil {
// 保存原始规则字符串
m.addRegexRule(re, line, !isException)
}
case strings.HasPrefix(line, "|") && strings.HasSuffix(line, "|"):
// 完整URL匹配规则
urlPattern := strings.TrimPrefix(strings.TrimSuffix(line, "|"), "|")
// 将URL模式转换为正则表达式
pattern := "^" + regexp.QuoteMeta(urlPattern) + "$"
if re, err := regexp.Compile(pattern); err == nil {
m.addRegexRule(re, line, !isException)
}
case strings.HasPrefix(line, "|"):
// URL开头匹配规则
urlPattern := strings.TrimPrefix(line, "|")
pattern := "^" + regexp.QuoteMeta(urlPattern)
if re, err := regexp.Compile(pattern); err == nil {
m.addRegexRule(re, line, !isException)
}
case strings.HasSuffix(line, "|"):
// URL结尾匹配规则
urlPattern := strings.TrimSuffix(line, "|")
pattern := regexp.QuoteMeta(urlPattern) + "$"
if re, err := regexp.Compile(pattern); err == nil {
m.addRegexRule(re, line, !isException)
}
default:
// 默认作为普通域名规则
m.addDomainRule(line, !isException)
}
}
// parseRuleOptions 解析规则选项
func (m *ShieldManager) parseRuleOptions(optionsStr string) map[string]string {
options := make(map[string]string)
parts := strings.Split(optionsStr, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
if strings.Contains(part, "=") {
kv := strings.SplitN(part, "=", 2)
options[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1])
} else {
options[part] = ""
}
}
return options
}
// addDomainRule 添加域名规则,支持是否为阻止规则
func (m *ShieldManager) addDomainRule(domain string, block bool) {
if block {
m.domainRules[domain] = true
// 添加所有子域名的匹配支持
parts := strings.Split(domain, ".")
if len(parts) > 1 {
// 为二级域名和顶级域名添加规则
for i := 0; i < len(parts)-1; i++ {
subdomain := strings.Join(parts[i:], ".")
m.domainRules[subdomain] = true
}
}
} else {
// 添加到排除规则
m.domainExceptions[domain] = true
// 为子域名也添加排除规则
parts := strings.Split(domain, ".")
if len(parts) > 1 {
for i := 0; i < len(parts)-1; i++ {
subdomain := strings.Join(parts[i:], ".")
m.domainExceptions[subdomain] = true
}
}
}
}
// addRegexRule 添加正则表达式规则,支持是否为阻止规则
func (m *ShieldManager) addRegexRule(re *regexp.Regexp, original string, block bool) {
rule := regexRule{
pattern: re,
original: original,
}
if block {
m.regexRules = append(m.regexRules, rule)
} else {
// 添加到排除规则
m.regexExceptions = append(m.regexExceptions, rule)
}
}
// IsBlocked 检查域名是否被屏蔽
func (m *ShieldManager) IsBlocked(domain string) bool {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
// 预处理域名,去除可能的端口号
if strings.Contains(domain, ":") {
parts := strings.Split(domain, ":")
domain = parts[0]
}
// 首先检查排除规则(优先级最高)
// 检查域名排除规则
if m.domainExceptions[domain] {
return false
}
// 检查子域名排除规则
parts := strings.Split(domain, ".")
for i := 0; i < len(parts)-1; i++ {
subdomain := strings.Join(parts[i:], ".")
if m.domainExceptions[subdomain] {
return false
}
}
// 检查正则表达式排除规则
for _, re := range m.regexExceptions {
if re.pattern.MatchString(domain) {
return false
}
}
// 然后检查阻止规则
// 检查精确域名匹配
if m.domainRules[domain] {
return true
}
// 检查子域名匹配AdGuardHome风格
// 从最长的子域名开始匹配,确保优先级正确
for i := 0; i < len(parts)-1; i++ {
subdomain := strings.Join(parts[i:], ".")
if m.domainRules[subdomain] {
return true
}
}
// 检查正则表达式匹配
for _, re := range m.regexRules {
if re.pattern.MatchString(domain) {
return true
}
}
return false
}
// GetHostsIP 获取hosts文件中的IP映射
func (m *ShieldManager) GetHostsIP(domain string) (string, bool) {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
ip, exists := m.hostsMap[domain]
return ip, exists
}
// AddRule 添加屏蔽规则
func (m *ShieldManager) AddRule(rule string) error {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
// 解析并添加规则到内存
m.parseRule(rule)
// 持久化保存规则到文件
if m.config.LocalRulesFile != "" {
if err := m.saveRulesToFile(); err != nil {
logger.Error("保存规则到文件失败", "error", err)
return err
}
}
return nil
}
// RemoveRule 删除屏蔽规则
func (m *ShieldManager) RemoveRule(rule string) error {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
removed := false
// 清理规则,移除可能的修饰符
cleanRule := rule
// 移除规则结束符
cleanRule = strings.TrimSuffix(cleanRule, "^")
// 尝试多种可能的规则格式
formatsToTry := []string{cleanRule}
// 根据规则类型添加可能的格式变体
if strings.HasPrefix(cleanRule, "@@||") {
// 已有的排除规则格式,也尝试去掉前缀
formatsToTry = append(formatsToTry, strings.TrimPrefix(cleanRule, "@@||"))
} else if strings.HasPrefix(cleanRule, "||") {
// 已有的域名规则格式,也尝试去掉前缀
formatsToTry = append(formatsToTry, strings.TrimPrefix(cleanRule, "||"))
} else {
// 可能是裸域名,尝试添加前缀
formatsToTry = append(formatsToTry, "||"+cleanRule, "@@||"+cleanRule)
}
// 尝试所有可能的格式变体来删除规则
for _, format := range formatsToTry {
if removed {
break
}
// 尝试删除排除规则
if strings.HasPrefix(format, "@@||") {
domain := strings.TrimPrefix(format, "@@||")
if _, exists := m.domainExceptions[domain]; exists {
delete(m.domainExceptions, domain)
removed = true
break
}
} else if strings.HasPrefix(format, "||") {
// 尝试删除域名规则
domain := strings.TrimPrefix(format, "||")
if _, exists := m.domainRules[domain]; exists {
delete(m.domainRules, domain)
removed = true
break
}
} else {
// 尝试直接作为域名删除
if _, exists := m.domainRules[format]; exists {
delete(m.domainRules, format)
removed = true
break
}
if _, exists := m.domainExceptions[format]; exists {
delete(m.domainExceptions, format)
removed = true
break
}
}
}
// 处理正则表达式规则
if !removed && strings.HasPrefix(cleanRule, "/") && strings.HasSuffix(cleanRule, "/") {
pattern := strings.TrimPrefix(strings.TrimSuffix(cleanRule, "/"), "/")
// 检查是否在正则表达式规则中
newRegexRules := []regexRule{}
for _, re := range m.regexRules {
if re.pattern.String() != pattern {
newRegexRules = append(newRegexRules, re)
} else {
removed = true
}
}
m.regexRules = newRegexRules
// 如果没有从正则规则中找到,检查是否在正则排除规则中
if !removed {
newRegexExceptions := []regexRule{}
for _, re := range m.regexExceptions {
if re.pattern.String() != pattern {
newRegexExceptions = append(newRegexExceptions, re)
} else {
removed = true
}
}
m.regexExceptions = newRegexExceptions
}
}
// 处理通配符和URL匹配规则
if !removed && (strings.HasPrefix(cleanRule, "*") || strings.HasSuffix(cleanRule, "*") || strings.HasPrefix(cleanRule, "|") || strings.HasSuffix(cleanRule, "|")) {
// 遍历所有规则,找到匹配的规则进行删除
for domain := range m.domainRules {
if domain == cleanRule || domain == rule {
delete(m.domainRules, domain)
removed = true
break
}
}
if !removed {
for domain := range m.domainExceptions {
if domain == cleanRule || domain == rule {
delete(m.domainExceptions, domain)
removed = true
break
}
}
}
}
// 如果有规则被删除,持久化保存更改
if removed && m.config.LocalRulesFile != "" {
if err := m.saveRulesToFile(); err != nil {
logger.Error("保存规则到文件失败", "error", err)
return err
}
}
return nil
}
// StartAutoUpdate 启动自动更新
func (m *ShieldManager) StartAutoUpdate() {
if m.updateRunning {
return
}
m.updateRunning = true
go func() {
ticker := time.NewTicker(time.Duration(m.config.UpdateInterval) * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
logger.Info("开始自动更新规则")
if err := m.LoadRules(); err != nil {
logger.Error("自动更新规则失败", "error", err)
} else {
logger.Info("自动更新规则成功")
}
case <-m.updateCtx.Done():
m.updateRunning = false
return
}
}
}()
}
// StopAutoUpdate 停止自动更新
func (m *ShieldManager) StopAutoUpdate() {
m.updateCancel()
}
// saveRulesToFile 保存规则到文件
func (m *ShieldManager) saveRulesToFile() error {
var rules []string
// 添加域名规则
for domain := range m.domainRules {
rules = append(rules, "||"+domain)
}
// 添加正则表达式规则
for _, re := range m.regexRules {
rules = append(rules, re.original)
}
// 添加排除规则
for domain := range m.domainExceptions {
rules = append(rules, "@@||"+domain)
}
// 添加正则表达式排除规则
for _, re := range m.regexExceptions {
rules = append(rules, re.original)
}
// 写入文件
content := strings.Join(rules, "\n")
return ioutil.WriteFile(m.config.LocalRulesFile, []byte(content), 0644)
}
// AddHostsEntry 添加hosts条目
func (m *ShieldManager) AddHostsEntry(ip, domain string) error {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
m.hostsMap[domain] = ip
// 持久化保存到hosts文件
if m.config.HostsFile != "" {
if err := m.saveHostsToFile(); err != nil {
logger.Error("保存hosts到文件失败", "error", err)
return err
}
}
return nil
}
// RemoveHostsEntry 删除hosts条目
func (m *ShieldManager) RemoveHostsEntry(domain string) error {
m.rulesMutex.Lock()
defer m.rulesMutex.Unlock()
if _, exists := m.hostsMap[domain]; exists {
delete(m.hostsMap, domain)
// 持久化保存到hosts文件
if m.config.HostsFile != "" {
if err := m.saveHostsToFile(); err != nil {
logger.Error("保存hosts到文件失败", "error", err)
return err
}
}
}
return nil
}
// saveHostsToFile 保存hosts到文件
func (m *ShieldManager) saveHostsToFile() error {
var lines []string
// 添加hosts头部注释
lines = append(lines, "# DNS Server Hosts File")
lines = append(lines, "# Generated by DNS Server")
lines = append(lines, "")
// 添加localhost条目如果不存在
if _, exists := m.hostsMap["localhost"]; !exists {
lines = append(lines, "127.0.0.1 localhost")
lines = append(lines, "::1 localhost")
lines = append(lines, "")
}
// 添加所有hosts条目
for domain, ip := range m.hostsMap {
lines = append(lines, ip+"\t"+domain)
}
// 写入文件
content := strings.Join(lines, "\n")
return ioutil.WriteFile(m.config.HostsFile, []byte(content), 0644)
}
// GetStats 获取规则统计信息
func (m *ShieldManager) GetStats() map[string]interface{} {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
return map[string]interface{}{
"domainRules": len(m.domainRules),
"domainExceptions": len(m.domainExceptions),
"regexRules": len(m.regexRules),
"regexExceptions": len(m.regexExceptions),
"hostsRules": len(m.hostsMap),
"updateInterval": m.config.UpdateInterval,
}
}
// GetRules 获取所有规则的详细列表
func (m *ShieldManager) GetRules() map[string]interface{} {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()
// 转换map和slice为字符串列表
domainRulesList := make([]string, 0, len(m.domainRules))
for domain := range m.domainRules {
domainRulesList = append(domainRulesList, "||"+domain+"^")
}
domainExceptionsList := make([]string, 0, len(m.domainExceptions))
for domain := range m.domainExceptions {
domainExceptionsList = append(domainExceptionsList, "@@||"+domain+"^")
}
// 获取正则规则原始字符串
regexRulesList := make([]string, 0, len(m.regexRules))
for _, re := range m.regexRules {
regexRulesList = append(regexRulesList, re.original)
}
// 获取正则排除规则原始字符串
regexExceptionsList := make([]string, 0, len(m.regexExceptions))
for _, re := range m.regexExceptions {
regexExceptionsList = append(regexExceptionsList, re.original)
}
// 获取hosts规则
hostsRulesList := make([]string, 0, len(m.hostsMap))
for domain, ip := range m.hostsMap {
hostsRulesList = append(hostsRulesList, ip+"\t"+domain)
}
return map[string]interface{}{
"domainRules": domainRulesList,
"domainExceptions": domainExceptionsList,
"regexRules": regexRulesList,
"regexExceptions": regexExceptionsList,
"hostsRules": hostsRulesList,
}
}