diff --git a/dns-server b/dns-server index f6dcc91..6fa4c89 100755 Binary files a/dns-server and b/dns-server differ diff --git a/dns/server.go b/dns/server.go index cd253c8..3f52934 100644 --- a/dns/server.go +++ b/dns/server.go @@ -1556,17 +1556,35 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg dnssecServerForResponse = resp.server } - // 如果响应成功或为NXDOMAIN + // 如果响应成功或为 NXDOMAIN if resp.response.Rcode == dns.RcodeSuccess || resp.response.Rcode == dns.RcodeNameError { - // 按Rcode分类添加到不同列表 + // 按 Rcode 分类添加到不同列表 if resp.response.Rcode == dns.RcodeSuccess { successResponses = append(successResponses, resp.response) } else { nxdomainResponses = append(nxdomainResponses, resp.response) } - // 快速返回逻辑:找到第一个有效响应或更快的响应 - if resp.response.Rcode == dns.RcodeSuccess { + // 简化的快速返回逻辑:找到第一个成功响应或更快的响应 + // 对于不验证 DNSSEC 的域名,直接返回第一个成功响应 + if noDNSSEC { + // 不验证 DNSSEC 的域名:直接返回第一个成功响应 + if !hasBestResponse || resp.rtt < fastestRtt { + fastestResponse = resp.response + fastestRtt = resp.rtt + fastestServer = resp.server + fastestDnssecServer = dnssecServerForResponse + fastestHasDnssec = false + + // 立即发送结果,快速返回 + resultChan <- struct { + response *dns.Msg + rtt time.Duration + usedServer string + usedDnssecServer string + }{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer} + } + } else if resp.response.Rcode == dns.RcodeSuccess { // 优先选择带有DNSSEC的响应 if containsDNSSEC { // 如果这是第一个DNSSEC响应,或者比当前最快的DNSSEC响应更快 diff --git a/dns/server.go.orig b/dns/server.go.orig new file mode 100644 index 0000000..dc0edbe --- /dev/null +++ b/dns/server.go.orig @@ -0,0 +1,3293 @@ +package dns + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "math" + "math/rand" + "net" + "os" + "path/filepath" + "runtime" + "sort" + "strings" + "sync" + "sync/atomic" + "time" + + "dns-server/config" + "dns-server/gfw" + "dns-server/log" + "dns-server/logger" + "dns-server/shield" + "dns-server/threat" + + "github.com/miekg/dns" +) + +// 确保DNS服务器地址包含端口号,默认添加53端口 +func normalizeDNSServerAddress(address string) string { + // 检查地址是否已经包含端口号 + if _, _, err := net.SplitHostPort(address); err != nil { + // 如果没有端口号,添加默认的53端口 + return net.JoinHostPort(address, "53") + } + // 已经有端口号,直接返回 + return address +} + +// BlockedDomain 屏蔽域名统计 +type BlockedDomain struct { + Domain string + Count int64 + LastSeen int64 + DNSSEC bool // 是否使用了DNSSEC +} + +// ClientStats 客户端统计 + +type ClientStats struct { + IP string + Count int64 + LastSeen int64 +} + +// DNSAnswer DNS解析记录 +type DNSAnswer struct { + Type string `json:"type"` // 记录类型 + Value string `json:"value"` // 记录值 + TTL uint32 `json:"ttl"` // 生存时间 +} + +// QueryLog 查询日志记录 +type QueryLog struct { + Timestamp time.Time `json:"timestamp"` // 查询时间 + ClientIP string `json:"clientIP"` // 客户端IP + Location string `json:"location"` // IP地理位置(国家 城市) + Domain string `json:"domain"` // 查询域名 + QueryType string `json:"queryType"` // 查询类型 + ResponseTime int64 `json:"responseTime"` // 响应时间(ms) + Result string `json:"result"` // 查询结果(allowed, blocked, error) + BlockRule string `json:"blockRule"` // 屏蔽规则(如果被屏蔽) + BlockType string `json:"blockType"` // 屏蔽类型(如果被屏蔽) + FromCache bool `json:"fromCache"` // 是否来自缓存 + DNSSEC bool `json:"dnssec"` // 是否使用了DNSSEC + EDNS bool `json:"edns"` // 是否使用了EDNS + DNSServer string `json:"dnsServer"` // 使用的DNS服务器 + DNSSECServer string `json:"dnssecServer"` // 使用的DNSSEC专用服务器 + Answers []DNSAnswer `json:"answers"` // 解析记录 + ResponseCode int `json:"responseCode"` // DNS响应代码 +} + +// StatsData 用于持久化的统计数据结构 +type StatsData struct { + Stats *Stats `json:"stats"` + BlockedDomains map[string]*BlockedDomain `json:"blockedDomains"` + ResolvedDomains map[string]*BlockedDomain `json:"resolvedDomains"` + ClientStats map[string]*ClientStats `json:"clientStats"` + HourlyStats map[string]int64 `json:"hourlyStats"` + DailyStats map[string]int64 `json:"dailyStats"` + MonthlyStats map[string]int64 `json:"monthlyStats"` + LastSaved time.Time `json:"lastSaved"` +} + +// ServerStats 服务器统计信息 +type ServerStats struct { + SuccessCount int64 // 成功查询次数 + FailureCount int64 // 失败查询次数 + LastResponse time.Time // 最后响应时间 + ResponseTime time.Duration // 平均响应时间 + ConnectionSpeed time.Duration // TCP连接速度 +} + +// Server DNS服务器 +type Server struct { + globalConfig *config.Config + config *config.DNSConfig + shieldConfig *config.ShieldConfig + shieldManager *shield.ShieldManager + gfwConfig *config.GFWListConfig + gfwManager *gfw.GFWListManager + server *dns.Server + tcpServer *dns.Server + resolver *dns.Client + ctx context.Context + cancel context.CancelFunc + statsMutex sync.Mutex + stats *Stats + blockedDomainsMutex sync.RWMutex + blockedDomains map[string]*BlockedDomain + resolvedDomainsMutex sync.RWMutex + resolvedDomains map[string]*BlockedDomain // 用于记录解析的域名 + clientStatsMutex sync.RWMutex + clientStats map[string]*ClientStats // 用于记录客户端统计 + hourlyStatsMutex sync.RWMutex + hourlyStats map[string]int64 // 按小时统计屏蔽数量 + dailyStatsMutex sync.RWMutex + dailyStats map[string]int64 // 按天统计屏蔽数量 + monthlyStatsMutex sync.RWMutex + monthlyStats map[string]int64 // 按月统计屏蔽数量 + + // 新日志系统 + logManager *log.LogManager // 日志管理器 + archiveManager *log.ArchiveManager // 归档管理器 + archiveQueryEngine *log.ArchiveQueryEngine // 归档查询引擎 + + // DNS查询缓存 + DnsCache *DNSCache // DNS响应缓存 + + // 域名DNSSEC状态映射表 + domainDNSSECStatus map[string]bool // 域名到DNSSEC状态的映射 + domainDNSSECStatusMutex sync.RWMutex // 保护域名DNSSEC状态映射的互斥锁 + + // 上游服务器状态跟踪 + serverStats map[string]*ServerStats // 服务器地址到状态的映射 + serverStatsMutex sync.RWMutex // 保护服务器状态的互斥锁 + + // DNSSEC专用服务器映射,用于快速查找 + dnssecServerMap map[string]bool // DNSSEC专用服务器地址到布尔值的映射 + + // DNS客户端实例池,用于并行查询 + clientPool sync.Pool // 存储*dns.Client实例 + + // 威胁检测相关 + threatEngine *threat.ThreatEngine + alertManager *threat.AlertManager + dbManager *threat.ThreatDatabaseManager +} + +// Stats DNS服务器统计信息 +type Stats struct { + Queries int64 + Blocked int64 + Allowed int64 + Errors int64 + LastQuery time.Time + AvgResponseTime float64 // 平均响应时间(ms) + TotalResponseTime int64 // 总响应时间 + QueryTypes map[string]int64 // 查询类型统计 + SourceIPs map[string]bool // 活跃来源IP + CpuUsage float64 // CPU使用率(%) + DNSSECQueries int64 // DNSSEC查询总数 + DNSSECSuccess int64 // DNSSEC验证成功数 + DNSSECFailed int64 // DNSSEC验证失败数 + DNSSECEnabled bool // 是否启用了DNSSEC +} + +// NewServer 创建DNS服务器实例 +func NewServer(globalConfig *config.Config, shieldManager *shield.ShieldManager, gfwManager *gfw.GFWListManager) *Server { + ctx, cancel := context.WithCancel(context.Background()) + + // 从配置中读取DNS缓存TTL值(分钟) + cacheTTL := time.Duration(globalConfig.DNS.CacheTTL) * time.Minute + // 保存间隔(秒) + saveInterval := time.Duration(globalConfig.DNS.SaveInterval) * time.Second + // 最大和最小缓存TTL(分钟) + maxCacheTTL := time.Duration(globalConfig.DNS.MaxCacheTTL) * time.Minute + minCacheTTL := time.Duration(globalConfig.DNS.MinCacheTTL) * time.Minute + + server := &Server{ + globalConfig: globalConfig, + config: &globalConfig.DNS, + shieldConfig: &globalConfig.Shield, + shieldManager: shieldManager, + gfwConfig: &globalConfig.GFWList, + gfwManager: gfwManager, + resolver: &dns.Client{ + Net: "udp", + UDPSize: 4096, // 增加UDP缓冲区大小,支持更大的DNSSEC响应 + }, + ctx: ctx, + cancel: cancel, + stats: &Stats{ + Queries: 0, + Blocked: 0, + Allowed: 0, + Errors: 0, + AvgResponseTime: 0, + TotalResponseTime: 0, + QueryTypes: make(map[string]int64), + SourceIPs: make(map[string]bool), + CpuUsage: 0, + DNSSECQueries: 0, + DNSSECSuccess: 0, + DNSSECFailed: 0, + DNSSECEnabled: globalConfig.DNS.EnableDNSSEC, + }, + blockedDomains: make(map[string]*BlockedDomain), + resolvedDomains: make(map[string]*BlockedDomain), + clientStats: make(map[string]*ClientStats), + hourlyStats: make(map[string]int64), + dailyStats: make(map[string]int64), + monthlyStats: make(map[string]int64), + + // DNS 查询缓存初始化 + DnsCache: NewDNSCache(cacheTTL, globalConfig.DNS.CacheMode, globalConfig.DNS.CacheSize, globalConfig.DNS.CacheFilePath, saveInterval, maxCacheTTL, minCacheTTL), + // 初始化域名 DNSSEC 状态映射表 + domainDNSSECStatus: make(map[string]bool), + // 初始化服务器状态跟踪 + serverStats: make(map[string]*ServerStats), + // 初始化 DNSSEC 专用服务器映射 + dnssecServerMap: make(map[string]bool), + // 初始化 DNS 客户端实例池 + clientPool: sync.Pool{ + New: func() interface{} { + return &dns.Client{ + Net: "udp", + UDPSize: 4096, + Timeout: 2 * time.Second, // 默认超时时间,会在使用时覆盖(2 秒是合理的 DNS 查询超时) + } + }, + }, + } + + // 初始化新日志系统 + logManager, err := log.NewLogManager(log.DefaultConfig()) + if err != nil { + logger.Error("初始化日志管理器失败", "error", err) + } else { + server.logManager = logManager + logger.Info("新日志系统初始化成功", "ringBufferSize", log.DefaultConfig().RingBufferSize, "databasePath", log.DefaultConfig().DatabasePath) + } + + // 初始化归档管理器 + if globalConfig.QueryLog.ArchiveEnabled { + // 转换为 log.QueryLogConfig + logConfig := &log.QueryLogConfig{ + Enabled: globalConfig.QueryLog.Enabled, + RingBufferSize: globalConfig.QueryLog.RingBufferSize, + DatabasePath: globalConfig.QueryLog.DatabasePath, + MaxDatabaseSizeMB: globalConfig.QueryLog.MaxDatabaseSizeMB, + EnableWAL: globalConfig.QueryLog.EnableWAL, + ArchiveEnabled: globalConfig.QueryLog.ArchiveEnabled, + ArchiveDir: globalConfig.QueryLog.ArchiveDir, + ArchivePrefix: globalConfig.QueryLog.ArchivePrefix, + CompressionLevel: globalConfig.QueryLog.CompressionLevel, + RetentionDays: globalConfig.QueryLog.RetentionDays, + RetentionMonths: globalConfig.QueryLog.RetentionMonths, + QueryTimeout: globalConfig.QueryLog.QueryTimeout, + EnableCache: globalConfig.QueryLog.EnableCache, + CacheTTL: globalConfig.QueryLog.CacheTTL, + } + + archiveManager, err := log.NewArchiveManager(logConfig, globalConfig.QueryLog.DatabasePath) + if err != nil { + logger.Error("初始化归档管理器失败", "error", err) + } else { + server.archiveManager = archiveManager + logger.Info("归档管理器初始化成功", "archiveDir", globalConfig.QueryLog.ArchiveDir) + } + + // 初始化归档查询引擎 + if logManager != nil { + sqliteStore := logManager.GetSQLiteStore() + if sqliteStore != nil { + archiveQueryEngine, err := log.NewArchiveQueryEngine(sqliteStore, archiveManager, logConfig) + if err != nil { + logger.Error("初始化归档查询引擎失败", "error", err) + } else { + server.archiveQueryEngine = archiveQueryEngine + logger.Info("归档查询引擎初始化成功") + } + } + } + } + + // 加载已保存的统计数据 + server.loadStatsData() + + return server + +} + +// initThreatDetection 初始化威胁检测相关组件 +func (s *Server) initThreatDetection() { + // 从全局配置中获取威胁检测配置 + threatConfig := &s.globalConfig.Threat + + // 创建告警管理器 + s.alertManager = threat.NewAlertManager(threatConfig) + + // 加载已保存的告警 + s.alertManager.LoadAlerts() + + // 创建威胁域名数据库管理器 + s.dbManager = threat.NewThreatDatabaseManager(threatConfig.ThreatDatabasePath) + + // 加载威胁域名数据库 + s.dbManager.LoadDatabase() + + // 启动文件监听,自动检测数据库变更 + if err := s.dbManager.StartWatching(); err != nil { + logger.Warn("启动威胁域名数据库监听失败", "error", err) + } + + // 创建威胁检测引擎 + s.threatEngine = threat.NewThreatEngine(threatConfig, s.alertManager, s.dbManager) +} + +// Start 启动DNS服务器 +func (s *Server) Start() error { + // 重新初始化上下文和取消函数 + ctx, cancel := context.WithCancel(context.Background()) + s.ctx = ctx + s.cancel = cancel + + // 重新初始化saveDone通道 + + // 重置stopped标志 + + // 初始化威胁检测相关组件 + s.initThreatDetection() + + s.server = &dns.Server{ + Addr: fmt.Sprintf("0.0.0.0:%d", s.config.Port), + Net: "udp", + Handler: dns.HandlerFunc(s.handleDNSRequest), + } + + // 保存TCP服务器实例,以便在Stop方法中关闭 + s.tcpServer = &dns.Server{ + Addr: fmt.Sprintf("0.0.0.0:%d", s.config.Port), + Net: "tcp", + Handler: dns.HandlerFunc(s.handleDNSRequest), + } + + // 启动CPU使用率监控 + go s.startCpuUsageMonitor() + + // 启动自动保存功能 + go s.startAutoSave() + + // 更新DNSSEC专用服务器映射 + s.updateDNSSECServerMap() + + // 启动日志处理协程(已移除,新日志系统使用 SQLite 存储) + // go s.processLogs() + + // 启动统计数据定期重置功能(每 24 小时) + go func() { + ticker := time.NewTicker(24 * time.Hour) + defer ticker.Stop() + for { + select { + case <-ticker.C: + s.resetStats() + case <-s.ctx.Done(): + return + } + } + }() + + // 启动归档监控和清理任务 + if s.archiveManager != nil { + // 启动归档监控 + s.archiveManager.StartWatching() + + // 启动定期清理任务(每天执行) + go func() { + ticker := time.NewTicker(24 * time.Hour) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if deleted, err := s.archiveManager.CleanupOldArchives(); err != nil { + logger.Error("清理归档失败", "error", err) + } else if deleted > 0 { + logger.Info("清理归档完成", "deleted", deleted) + } + case <-s.ctx.Done(): + return + } + } + }() + + logger.Info("归档监控和清理任务已启动") + } + + // 启动UDP服务 + go func() { + logger.Info(fmt.Sprintf("DNS UDP服务器启动,监听端口: %d", s.config.Port)) + if err := s.server.ListenAndServe(); err != nil { + logger.Error("DNS UDP服务器启动失败", "error", err) + s.cancel() + } + }() + + // 启动TCP服务 + go func() { + logger.Info(fmt.Sprintf("DNS TCP服务器启动,监听端口: %d", s.config.Port)) + if err := s.tcpServer.ListenAndServe(); err != nil { + logger.Error("DNS TCP服务器启动失败", "error", err) + s.cancel() + } + }() + + // 等待停止信号 + <-s.ctx.Done() + return nil +} + +// resetStats 重置统计数据 +func (s *Server) resetStats() { + s.statsMutex.Lock() + defer s.statsMutex.Unlock() + + // 只重置累计值,保留配置相关值 + s.stats.TotalResponseTime = 0 + s.stats.AvgResponseTime = 0 + s.stats.Queries = 0 + s.stats.Blocked = 0 + s.stats.Allowed = 0 + s.stats.Errors = 0 + s.stats.DNSSECQueries = 0 + s.stats.DNSSECSuccess = 0 + s.stats.DNSSECFailed = 0 + s.stats.QueryTypes = make(map[string]int64) + s.stats.SourceIPs = make(map[string]bool) + + logger.Info("统计数据已重置") +} + +// Stop 停止DNS服务器 +func (s *Server) Stop() { + // 检查服务器是否已经停止 + // 标记服务器为已停止状态 + + // 停止威胁域名数据库文件监听 + if s.dbManager != nil { + s.dbManager.StopWatching() + } + + // 发送停止信号给保存协程 + + // 最后保存一次数据 + s.saveStatsData() + + // 停止服务器 + s.cancel() + if s.server != nil { + s.server.Shutdown() + } + if s.tcpServer != nil { + s.tcpServer.Shutdown() + } + logger.Info("DNS服务器已停止") +} + +// handleDNSRequest 处理DNS请求 +func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { + startTime := time.Now() + + // 1. 初始化请求信息 + reqInfo := s.initRequestInfo(w, r) + + // 2. 检查基本请求条件 + if earlyResponse := s.checkRequestConditions(w, r, startTime, reqInfo); earlyResponse { + return + } + + // 3. 检查本地处理规则 + if localHandled := s.handleLocalRules(w, r, startTime, reqInfo); localHandled { + return + } + + // 4. 尝试从缓存获取响应 + if cacheHandled := s.handleCacheResponse(w, r, startTime, reqInfo); cacheHandled { + return + } + + // 5. 威胁检测 + s.checkThreatDetection(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType) + + // 6. 转发请求到上游服务器 + s.handleUpstreamRequest(w, r, startTime, reqInfo) +} + +// requestInfo 封装请求相关信息 +type requestInfo struct { + sourceIP string + domain string + queryType string + qType uint16 + queryAttempts []string +} + +// initRequestInfo 初始化请求信息 +func (s *Server) initRequestInfo(w dns.ResponseWriter, r *dns.Msg) *requestInfo { + // 获取来源IP + sourceIP := w.RemoteAddr().String() + // 提取IP地址部分,去掉端口 + if strings.HasPrefix(sourceIP, "[") { + // IPv6地址格式: [::1]:53 + if idx := strings.Index(sourceIP, "]"); idx >= 0 { + sourceIP = sourceIP[1:idx] // 去掉方括号 + } + } else { + // IPv4地址格式: 127.0.0.1:53 + if idx := strings.LastIndex(sourceIP, ":"); idx >= 0 { + sourceIP = sourceIP[:idx] + } + } + + // 更新来源IP统计和Queries计数器 + s.updateStats(func(stats *Stats) { + stats.Queries++ + stats.LastQuery = time.Now() + stats.SourceIPs[sourceIP] = true + }) + + // 更新客户端统计 + s.updateClientStats(sourceIP) + + // 获取查询域名和类型 + var domain string + var queryType string + var qType uint16 + if len(r.Question) > 0 { + domain = r.Question[0].Name + // 移除末尾的点 + if len(domain) > 0 && domain[len(domain)-1] == '.' { + domain = domain[:len(domain)-1] + } + // 获取查询类型 + if t, ok := dns.TypeToString[r.Question[0].Qtype]; ok { + queryType = t + } else { + // 处理未知类型,使用数字表示 + queryType = fmt.Sprintf("TYPE%d", r.Question[0].Qtype) + } + qType = r.Question[0].Qtype + // 更新查询类型统计 + s.updateStats(func(stats *Stats) { + stats.QueryTypes[queryType]++ + }) + } + + logger.Debug("接收到DNS查询", "domain", domain, "type", queryType, "client", w.RemoteAddr()) + + return &requestInfo{ + sourceIP: sourceIP, + domain: domain, + queryType: queryType, + qType: qType, + queryAttempts: []string{domain}, + } +} + +// checkRequestConditions 检查请求条件,返回是否需要提前响应 +func (s *Server) checkRequestConditions(w dns.ResponseWriter, r *dns.Msg, startTime time.Time, reqInfo *requestInfo) bool { + // 检查是否是AAAA记录查询且IPv6解析已禁用 + if reqInfo.qType == dns.TypeAAAA && !s.config.EnableIPv6 { + // 返回空的成功响应,而不是NXDOMAIN + response := new(dns.Msg) + response.SetReply(r) + response.SetRcode(r, dns.RcodeSuccess) + w.WriteMsg(response) + + // 更新统计信息 - 视为正常解析 + responseTime := time.Since(startTime).Milliseconds() + s.updateStats(func(stats *Stats) { + stats.Allowed++ + stats.TotalResponseTime += responseTime + stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries) + }) + + // 添加查询日志 + s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, responseTime, "allowed", "", "", false, false, true, "", "", nil, dns.RcodeSuccess) + logger.Debug("IPv6解析已禁用,返回空的成功响应", "domain", reqInfo.domain) + return true + } + + // 只处理递归查询 + if r.RecursionDesired == false { + response := new(dns.Msg) + response.SetReply(r) + // 不再硬编码RecursionAvailable,使用默认值或上游返回的值 + response.SetRcode(r, dns.RcodeRefused) + w.WriteMsg(response) + + // 计算实际响应时间 + responseTime := time.Since(startTime).Milliseconds() + // 更新统计信息 - 视为错误 + s.updateStats(func(stats *Stats) { + stats.Errors++ + stats.TotalResponseTime += responseTime + stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries) + }) + + // 添加查询日志 + s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, responseTime, "error", "", "", false, false, true, "", "", nil, dns.RcodeRefused) + return true + } + + return false +} + +// handleLocalRules 处理本地规则(hosts文件、GFWList、屏蔽规则),返回是否已处理 +func (s *Server) handleLocalRules(w dns.ResponseWriter, r *dns.Msg, startTime time.Time, reqInfo *requestInfo) bool { + // 本地规则匹配的响应时间极短,使用固定值1ms + const localResponseTime int64 = 1 + + // 检查 hosts 文件是否有匹配 + if ip, exists := s.shieldManager.GetHostsIP(reqInfo.domain); exists { + s.handleHostsResponse(w, r, ip) + // 使用固定的短响应时间 + s.updateStats(func(stats *Stats) { + stats.TotalResponseTime += localResponseTime + stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries) + }) + + // 添加查询日志 - hosts 文件匹配 + hostsAnswers := []DNSAnswer{{ + Type: "A", + Value: ip, + TTL: 300, + }} + s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, localResponseTime, "allowed", "", "", false, false, true, "hosts", "无", hostsAnswers, dns.RcodeSuccess) + return true + } + + // 检查是否为GFWList域名(仅当GFWList功能启用时) + if s.gfwConfig.Enabled && s.gfwManager != nil && s.gfwManager.IsMatch(reqInfo.domain) { + s.handleGFWListResponse(w, r, reqInfo.domain) + // 使用固定的短响应时间 + s.updateStats(func(stats *Stats) { + stats.TotalResponseTime += localResponseTime + stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries) + }) + + // 添加查询日志 - GFWList域名 + gfwAnswers := []DNSAnswer{} + s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, localResponseTime, "gfwlist", "", "", false, false, true, "GFWList", "无", gfwAnswers, dns.RcodeSuccess) + return true + } + + // 检查是否被屏蔽 + if s.shieldManager.IsBlocked(reqInfo.domain) { + // 获取屏蔽详情 + blockDetails := s.shieldManager.CheckDomainBlockDetails(reqInfo.domain) + blockRule, _ := blockDetails["blockRule"].(string) + blockType, _ := blockDetails["blockRuleType"].(string) + + s.handleBlockedResponse(w, r, reqInfo.domain) + // 使用固定的短响应时间 + s.updateStats(func(stats *Stats) { + stats.TotalResponseTime += localResponseTime + stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries) + }) + + // 添加查询日志 - 被屏蔽域名 + blockedAnswers := []DNSAnswer{} + // 根据屏蔽方法确定响应代码 + blockedRcode := dns.RcodeNameError // 默认NXDOMAIN + if blockMethod := s.shieldConfig.BlockMethod; blockMethod == "refused" { + blockedRcode = dns.RcodeRefused + } else if blockMethod == "emptyIP" || blockMethod == "customIP" { + blockedRcode = dns.RcodeSuccess + } + s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, localResponseTime, "blocked", blockRule, blockType, false, false, true, "无", "无", blockedAnswers, blockedRcode) + return true + } + + return false +} + +// handleCacheResponse 尝试从缓存获取响应,返回是否已处理 +func (s *Server) handleCacheResponse(w dns.ResponseWriter, r *dns.Msg, startTime time.Time, reqInfo *requestInfo) bool { + // 检查缓存中是否有响应(优先查找带DNSSEC的缓存项) + var cachedResponse *dns.Msg + var found bool + var cachedDNSSEC bool + + // 1. 首先检查是否有普通缓存项 + if tempResponse, tempFound := s.DnsCache.Get(r.Question[0].Name, reqInfo.qType); tempFound { + cachedResponse = tempResponse + found = tempFound + cachedDNSSEC = s.hasDNSSECRecords(tempResponse) + } + + // 2. 如果启用了DNSSEC且没有找到带DNSSEC的缓存项, + // 尝试从所有缓存中查找是否有其他响应包含DNSSEC记录 + // (这里可以进一步优化,比如在缓存中标记DNSSEC状态,快速查找) + if s.config.EnableDNSSEC && !cachedDNSSEC { + // 目前的缓存实现不支持按DNSSEC状态查找,所以这里暂时跳过 + // 后续可以考虑改进缓存实现,添加DNSSEC状态标记 + } + + if !found { + return false + } + + // 缓存命中,直接返回缓存的响应 + cachedResponseCopy := cachedResponse.Copy() // 创建响应副本避免并发修改问题 + cachedResponseCopy.Id = r.Id // 更新ID以匹配请求 + cachedResponseCopy.Compress = true + + // 如果客户端请求包含EDNS记录,确保响应也包含EDNS + if opt := r.IsEdns0(); opt != nil { + // 检查响应是否已经包含EDNS记录 + if respOpt := cachedResponseCopy.IsEdns0(); respOpt == nil { + // 添加EDNS记录,使用客户端的UDP缓冲区大小 + cachedResponseCopy.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC) + } else { + // 确保响应的UDP缓冲区大小不超过客户端请求的大小 + if respOpt.UDPSize() > opt.UDPSize() { + // 移除现有的EDNS记录 + for i := range cachedResponseCopy.Extra { + if cachedResponseCopy.Extra[i] == respOpt { + cachedResponseCopy.Extra = append(cachedResponseCopy.Extra[:i], cachedResponseCopy.Extra[i+1:]...) + break + } + } + // 添加新的EDNS记录,使用客户端的UDP缓冲区大小 + cachedResponseCopy.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC) + } + } + } + + // 确保响应的Question部分与客户端请求的Question部分匹配 + cachedResponseCopy.Question = r.Question + + // 修复:如果响应包含记录,确保Rcode为成功 + hasValidRecords := false + + // 检查Answer部分 + if len(cachedResponseCopy.Answer) > 0 { + hasValidRecords = true + } else if len(cachedResponseCopy.Ns) > 0 { + // 检查Ns部分 + hasValidRecords = true + } else if len(cachedResponseCopy.Extra) > 0 { + // 检查Extra部分,排除OPT记录 + for _, rr := range cachedResponseCopy.Extra { + if rr.Header().Rrtype != dns.TypeOPT { + hasValidRecords = true + break + } + } + } + + if hasValidRecords { + cachedResponseCopy.Rcode = dns.RcodeSuccess + } + + w.WriteMsg(cachedResponseCopy) + + // 缓存命中的响应时间应该是极短的,使用固定值1ms而非实际处理时间 + const cacheResponseTime int64 = 1 + + // 缓存命中的响应视为正常解析 + s.updateStats(func(stats *Stats) { + stats.Allowed++ + stats.TotalResponseTime += cacheResponseTime + stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries) + }) + + // 如果缓存响应包含DNSSEC记录,更新DNSSEC查询计数 + if cachedDNSSEC { + s.updateStats(func(stats *Stats) { + stats.DNSSECQueries++ + // 缓存响应视为DNSSEC成功 + stats.DNSSECSuccess++ + }) + } + + // 从缓存响应中提取解析记录 + cachedAnswers := []DNSAnswer{} + if cachedResponse != nil { + for _, rr := range cachedResponse.Answer { + cachedAnswers = append(cachedAnswers, DNSAnswer{ + Type: dns.TypeToString[rr.Header().Rrtype], + Value: rr.String(), + TTL: rr.Header().Ttl, + }) + } + } + + // 添加查询日志 - 标记为缓存 + // 从缓存响应中获取响应代码 + cacheRcode := dns.RcodeSuccess // 默认成功 + if cachedResponse != nil { + cacheRcode = cachedResponse.Rcode + } + s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, cacheResponseTime, "allowed", "", "", true, cachedDNSSEC, true, "缓存", "无", cachedAnswers, cacheRcode) + logger.Debug("从缓存返回DNS响应", "domain", reqInfo.domain, "type", reqInfo.queryType, "dnssec", cachedDNSSEC) + return true +} + +// handleUpstreamRequest 处理上游请求 +func (s *Server) handleUpstreamRequest(w dns.ResponseWriter, r *dns.Msg, startTime time.Time, reqInfo *requestInfo) { + logger.Debug("开始处理上游请求", "domain", reqInfo.domain, "type", reqInfo.queryType) + + // 缓存未命中,处理 DNS 请求 + var response *dns.Msg + var rtt time.Duration + var dnsServer string + var dnssecServer string + + // 直接查询原始域名 + response, rtt, dnsServer, dnssecServer = s.forwardDNSRequestWithCache(r, reqInfo.domain) + + logger.Debug("上游请求返回", "domain", reqInfo.domain, "response", response != nil, "rtt", rtt) + + if response != nil { + // 如果客户端请求包含EDNS记录,确保响应也包含EDNS + if opt := r.IsEdns0(); opt != nil { + // 检查响应是否已经包含EDNS记录 + if respOpt := response.IsEdns0(); respOpt == nil { + // 添加EDNS记录,使用客户端的UDP缓冲区大小 + response.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC) + } else { + // 确保响应的UDP缓冲区大小不超过客户端请求的大小 + if respOpt.UDPSize() > opt.UDPSize() { + // 移除现有的EDNS记录 + for i := range response.Extra { + if response.Extra[i] == respOpt { + response.Extra = append(response.Extra[:i], response.Extra[i+1:]...) + break + } + } + // 添加新的EDNS记录,使用客户端的UDP缓冲区大小 + response.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC) + } + } + } + + // 确保响应的Question部分与客户端请求的Question部分匹配 + response.Question = r.Question + + // 设置递归可用标志(因为我们的 DNS 服务器支持递归查询) + response.RecursionAvailable = true + response.RecursionDesired = r.RecursionDesired + + // 修复:如果响应包含记录,确保Rcode为成功 + hasValidRecords := false + + // 检查Answer部分 + if len(response.Answer) > 0 { + hasValidRecords = true + } else if len(response.Ns) > 0 { + // 检查Ns部分 + hasValidRecords = true + } else if len(response.Extra) > 0 { + // 检查Extra部分,排除OPT记录 + for _, rr := range response.Extra { + if rr.Header().Rrtype != dns.TypeOPT { + hasValidRecords = true + break + } + } + } + + if hasValidRecords { + response.Rcode = dns.RcodeSuccess + } + + // 写入响应给客户端 + w.WriteMsg(response) + } + + // 使用上游服务器的实际响应时间(转换为毫秒) + responseTime := int64(rtt.Milliseconds()) + // 如果rtt为0(查询失败),则使用本地计算的时间 + if responseTime == 0 { + responseTime = time.Since(startTime).Milliseconds() + } + + // 添加合理性检查,避免异常大的响应时间影响统计 + if responseTime > 60000 { // 超过60秒的响应时间视为异常 + responseTime = 60000 + } + + // 更新基本统计 + s.updateStats(func(stats *Stats) { + stats.TotalResponseTime += responseTime + // 添加防御性编程,确保Queries大于0 + if stats.Queries > 0 { + // 平均响应时间 = 总响应时间 / 总解析数量,四舍五入取整 + avg := float64(stats.TotalResponseTime) / float64(stats.Queries) + stats.AvgResponseTime = float64(math.Round(avg)) + // 限制平均响应时间的范围,避免显示异常大的值 + if stats.AvgResponseTime > 60000 { + stats.AvgResponseTime = 60000 + } + } + }) + + // 判断请求结果类型并更新相应统计 + resultType := "allowed" + if response == nil { + // 响应为nil,视为错误 + resultType = "error" + s.updateStats(func(stats *Stats) { + stats.Errors++ + }) + } else if response.Rcode != dns.RcodeSuccess { + // 响应代码不是成功,视为错误 + resultType = "error" + s.updateStats(func(stats *Stats) { + stats.Errors++ + }) + } else { + // 成功响应,视为正常解析 + resultType = "allowed" + s.updateStats(func(stats *Stats) { + stats.Allowed++ + }) + } + + // 检查响应是否包含DNSSEC记录并验证结果 + responseDNSSEC := false + if response != nil { + // 使用hasDNSSECRecords函数检查是否包含DNSSEC记录 + responseDNSSEC = s.hasDNSSECRecords(response) + + // 检查AD标志,确认DNSSEC验证是否成功 + if response.AuthenticatedData { + responseDNSSEC = true + } + + // 更新域名的DNSSEC状态 + if responseDNSSEC { + s.updateDomainDNSSECStatus(reqInfo.domain, true) + } + } + + // 如果响应成功,缓存结果(增强版缓存存储) + if response != nil && response.Rcode == dns.RcodeSuccess { + // 创建响应副本以避免后续修改影响缓存 + responseCopy := response.Copy() + // 设置合理的TTL,不超过默认的30分钟 + defaultCacheTTL := 30 * time.Minute + + // 1. 缓存原始域名的查询结果 + s.DnsCache.Set(r.Question[0].Name, reqInfo.qType, responseCopy, defaultCacheTTL) + logger.Debug("DNS响应已缓存", "domain", reqInfo.domain, "type", reqInfo.queryType, "ttl", defaultCacheTTL, "dnssec", responseDNSSEC) + + // 2. 如果响应包含CNAME记录,同时缓存CNAME指向的域名的查询结果 + for _, rr := range response.Answer { + if cname, ok := rr.(*dns.CNAME); ok { + // 为CNAME指向的域名创建缓存 + cnameQuery := r.Copy() + cnameQuery.Question[0].Name = cname.Target + s.DnsCache.Set(cname.Target, reqInfo.qType, responseCopy, defaultCacheTTL) + logger.Debug("CNAME响应已缓存", "domain", cname.Target, "type", reqInfo.queryType, "ttl", defaultCacheTTL, "dnssec", responseDNSSEC) + break + } + } + } + + // 从响应中提取解析记录 + responseAnswers := []DNSAnswer{} + if response != nil { + for _, rr := range response.Answer { + responseAnswers = append(responseAnswers, DNSAnswer{ + Type: dns.TypeToString[rr.Header().Rrtype], + Value: rr.String(), + TTL: rr.Header().Ttl, + }) + } + } + + // 从响应中获取响应代码 + realRcode := dns.RcodeSuccess // 默认成功 + if response != nil { + realRcode = response.Rcode + } + + logger.Debug("准备添加查询日志", "domain", reqInfo.domain, "result", resultType, "responseCode", realRcode) + + // 添加查询日志 + s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, responseTime, resultType, "", "", false, responseDNSSEC, true, dnsServer, dnssecServer, responseAnswers, realRcode) + + logger.Debug("查询日志已添加", "domain", reqInfo.domain) +} + +// handleHostsResponse 处理 hosts 文件匹配的响应 +func (s *Server) handleHostsResponse(w dns.ResponseWriter, r *dns.Msg, ip string) { + response := new(dns.Msg) + response.SetReply(r) + // 不再硬编码 RecursionAvailable,使用默认值或上游返回的值 + + if len(r.Question) > 0 { + q := r.Question[0] + answer := new(dns.A) + answer.Hdr = dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + } + answer.A = net.ParseIP(ip) + response.Answer = append(response.Answer, answer) + } + + // 记录解析域名统计 + domain := "" + if len(r.Question) > 0 { + domain = r.Question[0].Name + if len(domain) > 0 && domain[len(domain)-1] == '.' { + domain = domain[:len(domain)-1] + } + s.updateResolvedDomainStats(domain) + } + + w.WriteMsg(response) + // 本地 hosts 匹配响应时间极短,使用固定值 1ms + const localResponseTime int64 = 1 + s.updateStats(func(stats *Stats) { + stats.Allowed++ + }) +} + +// handleGFWListResponse 处理GFWList域名响应 +func (s *Server) handleGFWListResponse(w dns.ResponseWriter, r *dns.Msg, domain string) { + logger.Info("GFWList域名解析", "domain", domain, "client", w.RemoteAddr(), "ip", s.gfwConfig.IP) + + // 更新解析域名统计 + s.updateResolvedDomainStats(domain) + + response := new(dns.Msg) + response.SetReply(r) + + if len(r.Question) > 0 { + q := r.Question[0] + answer := new(dns.A) + answer.Hdr = dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + } + answer.A = net.ParseIP(s.gfwConfig.IP) + response.Answer = append(response.Answer, answer) + } + + w.WriteMsg(response) + // GFWList域名匹配响应时间极短,使用固定值1ms + const localResponseTime int64 = 1 + s.updateStats(func(stats *Stats) { + stats.Allowed++ + }) +} + +// handleBlockedResponse 处理被屏蔽的域名响应 +func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain string) { + logger.Info("域名被屏蔽", "domain", domain, "client", w.RemoteAddr()) + + // 更新被屏蔽域名统计 + s.updateBlockedDomainStats(domain) + + response := new(dns.Msg) + response.SetReply(r) + // 不再硬编码RecursionAvailable,使用默认值或上游返回的值 + + // 获取屏蔽方法配置 + blockMethod := "NXDOMAIN" // 默认值 + customBlockIP := "" // 默认值 + + // 从Server结构体的shieldConfig字段获取配置 + if s.shieldConfig != nil { + blockMethod = s.shieldConfig.BlockMethod + customBlockIP = s.shieldConfig.CustomBlockIP + } + + // 根据屏蔽方法返回不同的响应 + switch blockMethod { + case "refused": + // 返回拒绝查询响应 + response.SetRcode(r, dns.RcodeRefused) + case "emptyIP": + // 返回空IP响应 + if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA { + answer := new(dns.A) + answer.Hdr = dns.RR_Header{ + Name: r.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + } + answer.A = net.ParseIP("0.0.0.0") // 空IP + response.Answer = append(response.Answer, answer) + } + case "customIP": + // 返回自定义IP响应 + if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA { + answer := new(dns.A) + answer.Hdr = dns.RR_Header{ + Name: r.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + } + // 使用自定义屏蔽IP + if customBlockIP != "" { + answer.A = net.ParseIP(customBlockIP) + } else { + // 如果没有配置,使用0.0.0.0 + answer.A = net.ParseIP("0.0.0.0") + } + response.Answer = append(response.Answer, answer) + } + case "NXDOMAIN", "": + fallthrough // 默认使用NXDOMAIN + default: + // 返回NXDOMAIN响应(域名不存在) + response.SetRcode(r, dns.RcodeNameError) + } + + w.WriteMsg(response) + // 屏蔽规则匹配响应时间极短,使用固定值1ms + const localResponseTime int64 = 1 + s.updateStats(func(stats *Stats) { + stats.Blocked++ + }) +} + +// forwardDNSRequest 转发DNS请求到上游服务器 +// serverResponse 用于存储服务器响应的结构体 +type serverResponse struct { + response *dns.Msg + rtt time.Duration + server string + error error +} + +// updateDNSSECServerMap 更新DNSSEC专用服务器映射,用于快速查找 +func (s *Server) updateDNSSECServerMap() { + // 清空现有映射 + for k := range s.dnssecServerMap { + delete(s.dnssecServerMap, k) + } + + // 添加所有DNSSEC专用服务器到映射 + for _, server := range s.config.DNSSECUpstreamDNS { + s.dnssecServerMap[server] = true + } +} + +// forwardDNSRequestWithCache 转发DNS请求到上游服务器并返回响应 +func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg, time.Duration, string, string) { + // 始终支持EDNS + var udpSize uint16 = 4096 + var doFlag bool = s.config.EnableDNSSEC + + // 检查域名是否匹配不验证DNSSEC的模式 + noDNSSEC := false + for _, pattern := range s.config.NoDNSSECDomains { + if strings.Contains(domain, pattern) { + noDNSSEC = true + doFlag = false + logger.Debug("域名匹配到不验证DNSSEC的模式", "domain", domain, "pattern", pattern) + break + } + } + + // 检查客户端请求是否包含EDNS记录 + if opt := r.IsEdns0(); opt != nil { + // 保留客户端的UDP缓冲区大小 + udpSize = opt.UDPSize() + // 移除现有的EDNS记录,以便重新添加 + for i := range r.Extra { + if r.Extra[i] == opt { + r.Extra = append(r.Extra[:i], r.Extra[i+1:]...) + break + } + } + } + + // 添加EDNS记录,设置适当的UDPSize和DO标志 + r.SetEdns0(udpSize, doFlag) + + // DNSSEC专用服务器列表,从配置中获取 + dnssecServers := s.config.DNSSECUpstreamDNS + + // 选择合适的上游DNS服务器列表 + // 1. 首先检查是否有域名特定的DNS服务器配置 + var selectedUpstreamDNS []string + var domainMatched bool + + for matchStr, dnsServers := range s.config.DomainSpecificDNS { + if strings.Contains(domain, matchStr) { + selectedUpstreamDNS = dnsServers + domainMatched = true + logger.Debug("域名匹配到特定DNS服务器配置", "domain", domain, "matchStr", matchStr, "dnsServers", dnsServers) + break + } + } + + // 2. 如果没有匹配的域名特定配置 + if !domainMatched { + // 创建一个新的切片来存储最终的上游服务器列表 + var finalUpstreamDNS []string + + // 首先添加用户配置的上游DNS服务器 + finalUpstreamDNS = append(finalUpstreamDNS, s.config.UpstreamDNS...) + logger.Debug("使用用户配置的上游DNS服务器", "servers", finalUpstreamDNS) + + // 如果启用了DNSSEC且有配置DNSSEC专用服务器,并且域名不匹配NoDNSSECDomains,则将DNSSEC专用服务器添加到列表中 + if s.config.EnableDNSSEC && len(s.config.DNSSECUpstreamDNS) > 0 && !noDNSSEC { + // 合并DNSSEC专用服务器到上游服务器列表,避免重复,并确保包含端口号 + for _, dnssecServer := range s.config.DNSSECUpstreamDNS { + hasDuplicate := false + // 确保DNSSEC服务器地址包含端口号 + normalizedDnssecServer := normalizeDNSServerAddress(dnssecServer) + for _, upstream := range finalUpstreamDNS { + if upstream == normalizedDnssecServer { + hasDuplicate = true + break + } + } + if !hasDuplicate { + finalUpstreamDNS = append(finalUpstreamDNS, normalizedDnssecServer) + } + } + logger.Debug("合并DNSSEC专用服务器到上游服务器列表", "servers", finalUpstreamDNS) + } + + // 使用最终合并后的服务器列表 + selectedUpstreamDNS = finalUpstreamDNS + } + + // 1. 首先尝试所有配置的上游DNS服务器 + var bestResponse *dns.Msg + var bestRtt time.Duration + var hasBestResponse bool + var hasDNSSECResponse bool + var backupResponse *dns.Msg + var backupRtt time.Duration + var hasBackup bool + var usedDNSServer string + var usedDNSSECServer string + + // 使用配置中的超时时间 + defaultTimeout := time.Duration(s.config.QueryTimeout) * time.Millisecond + + // 根据查询模式处理请求 + switch s.config.QueryMode { + case "parallel": + // 并行请求模式 - 返回第一个成功响应 + responses := make(chan serverResponse, len(selectedUpstreamDNS)) + var wg sync.WaitGroup + + // 向所有上游服务器并行发送请求 + for _, upstream := range selectedUpstreamDNS { + wg.Add(1) + go func(server string) { + defer wg.Done() + + // 从池中获取客户端实例 + client := s.clientPool.Get().(*dns.Client) + // 设置客户端参数(确保在 Exchange 之前设置,避免竞态条件) + client.Net = s.resolver.Net + client.UDPSize = s.resolver.UDPSize + client.Timeout = defaultTimeout // 使用配置的超时时间 + + // 发送请求并获取响应,确保服务器地址包含端口号 + response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(server)) + responses <- serverResponse{response, rtt, server, err} + + // 将客户端实例放回池中(不重置 Timeout,因为下次使用时会重新设置) + s.clientPool.Put(client) + }(upstream) + } + + // 等待所有请求完成 + go func() { + wg.Wait() + close(responses) + }() + + // 处理响应,只返回第一个成功响应 + var lastErrorResponse *dns.Msg + var lastErrorRtt time.Duration + var lastErrorServer string + + for i := 0; i < len(selectedUpstreamDNS); i++ { + resp := <-responses + if resp.error == nil && resp.response != nil { + // 更新服务器统计信息 + s.updateServerStats(resp.server, true, resp.rtt) + + // 检查是否包含DNSSEC记录 + containsDNSSEC := s.hasDNSSECRecords(resp.response) + + // 对于不验证DNSSEC的域名,始终设置AD标志为false + if noDNSSEC { + resp.response.AuthenticatedData = false + } + + // 检查当前服务器是否是DNSSEC专用服务器 + if _, isDNSSECServer := s.dnssecServerMap[resp.server]; isDNSSECServer { + usedDNSSECServer = resp.server + } + + // 如果是成功响应,立即返回 + if resp.response.Rcode == dns.RcodeSuccess { + // 验证DNSSEC记录(如果需要) + if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC { + // 验证DNSSEC记录 + signatureValid := s.verifyDNSSEC(resp.response) + resp.response.AuthenticatedData = signatureValid + + if signatureValid { + // 更新DNSSEC验证成功计数 + s.updateStats(func(stats *Stats) { + stats.DNSSECQueries++ + stats.DNSSECSuccess++ + }) + } else { + // 更新DNSSEC验证失败计数 + s.updateStats(func(stats *Stats) { + stats.DNSSECQueries++ + stats.DNSSECFailed++ + }) + } + } + + bestResponse = resp.response + bestRtt = resp.rtt + usedDNSServer = resp.server + hasBestResponse = true + hasDNSSECResponse = containsDNSSEC + logger.Debug("返回第一个成功响应", "domain", domain, "server", resp.server, "rtt", resp.rtt) + return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer + } else { + // 保存最后一个错误响应 + lastErrorResponse = resp.response + lastErrorRtt = resp.rtt + lastErrorServer = resp.server + } + } else { + // 更新服务器统计信息(失败) + s.updateServerStats(resp.server, false, 0) + } + } + + // 如果所有服务器都失败,返回最后一个错误 + if lastErrorResponse != nil { + bestResponse = lastErrorResponse + bestRtt = lastErrorRtt + usedDNSServer = lastErrorServer + hasBestResponse = true + logger.Debug("所有服务器都失败,返回最后一个错误响应", "domain", domain, "server", lastErrorServer) + } + + return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer + + case "fastest-ip": + // 最快的IP地址模式 - 通过ping测试选择最快服务器,只向一个服务器发送请求 + // 1. 选择最快的服务器 + fastestServer := s.selectFastestServer(selectedUpstreamDNS) + if fastestServer != "" { + // 从池中获取客户端实例 + client := s.clientPool.Get().(*dns.Client) + // 设置客户端参数 + client.Net = s.resolver.Net + client.UDPSize = s.resolver.UDPSize + client.Timeout = defaultTimeout + + // 只向一个服务器发送请求 + response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(fastestServer)) + + // 将客户端实例放回池中 + s.clientPool.Put(client) + + if err == nil && response != nil { + // 更新服务器统计信息 + s.updateServerStats(fastestServer, true, rtt) + + // 检查是否包含DNSSEC记录 + containsDNSSEC := s.hasDNSSECRecords(response) + + // 对于不验证DNSSEC的域名,始终设置AD标志为false + if noDNSSEC { + response.AuthenticatedData = false + } + + // 验证DNSSEC记录(如果需要) + if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC { + // 验证DNSSEC记录 + signatureValid := s.verifyDNSSEC(response) + response.AuthenticatedData = signatureValid + + if signatureValid { + // 更新DNSSEC验证成功计数 + s.updateStats(func(stats *Stats) { + stats.DNSSECQueries++ + stats.DNSSECSuccess++ + }) + } else { + // 更新DNSSEC验证失败计数 + s.updateStats(func(stats *Stats) { + stats.DNSSECQueries++ + stats.DNSSECFailed++ + }) + } + } + + // 检查响应是否包含有效的记录,如果包含,将Rcode设置为成功 + hasValidRecords := false + if len(response.Answer) > 0 { + hasValidRecords = true + } else if len(response.Ns) > 0 { + hasValidRecords = true + } else if len(response.Extra) > 0 { + for _, rr := range response.Extra { + if rr.Header().Rrtype != dns.TypeOPT { + hasValidRecords = true + break + } + } + } + + if hasValidRecords { + response.Rcode = dns.RcodeSuccess + } + + // 设置最佳响应 + bestResponse = response + bestRtt = rtt + hasBestResponse = true + usedDNSServer = fastestServer + if containsDNSSEC { + hasDNSSECResponse = true + } + if _, isDNSSECServer := s.dnssecServerMap[normalizeDNSServerAddress(fastestServer)]; isDNSSECServer { + usedDNSSECServer = fastestServer + } + logger.Debug("使用最快服务器返回响应", "domain", domain, "server", fastestServer, "rtt", rtt) + } else { + // 更新服务器统计信息(失败) + s.updateServerStats(fastestServer, false, 0) + logger.Debug("最快服务器请求失败", "domain", domain, "server", fastestServer, "error", err) + } + } + + default: + // 默认使用并行请求模式 - 实现快速返回和超时机制 + responses := make(chan serverResponse, len(selectedUpstreamDNS)) + resultChan := make(chan struct { + response *dns.Msg + rtt time.Duration + usedServer string + usedDnssecServer string + }, 1) + var wg sync.WaitGroup + + // 向所有上游服务器并行发送请求 + for _, upstream := range selectedUpstreamDNS { + wg.Add(1) + go func(server string) { + defer wg.Done() + + // 创建带有超时的resolver + client := &dns.Client{ + Net: s.resolver.Net, + UDPSize: s.resolver.UDPSize, + Timeout: defaultTimeout, + } + + // 发送请求并获取响应,确保服务器地址包含端口号 + response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(server)) + responses <- serverResponse{response, rtt, server, err} + }(upstream) + } + + // 处理响应的协程 + go func() { + var fastestResponse *dns.Msg + var fastestRtt time.Duration = defaultTimeout + var fastestServer string + var fastestDnssecServer string + var fastestHasDnssec bool + var successResponses []*dns.Msg + var nxdomainResponses []*dns.Msg + + // 等待所有请求完成或超时 + timer := time.NewTimer(defaultTimeout) + defer timer.Stop() + + // 处理所有响应 + for { + select { + case resp, ok := <-responses: + if !ok { + // 所有响应都已处理 + goto doneProcessing + } + + if resp.error == nil && resp.response != nil { + // 更新服务器统计信息 + s.updateServerStats(resp.server, true, resp.rtt) + + // 检查是否包含DNSSEC记录 + containsDNSSEC := s.hasDNSSECRecords(resp.response) + + // 对于不验证DNSSEC的域名,始终设置AD标志为false + if noDNSSEC { + resp.response.AuthenticatedData = false + } + + dnssecServerForResponse := "" + if _, isDNSSECServer := s.dnssecServerMap[normalizeDNSServerAddress(resp.server)]; isDNSSECServer { + dnssecServerForResponse = resp.server + } + + // 如果响应成功或为NXDOMAIN + if resp.response.Rcode == dns.RcodeSuccess || resp.response.Rcode == dns.RcodeNameError { + // 按Rcode分类添加到不同列表 + if resp.response.Rcode == dns.RcodeSuccess { + successResponses = append(successResponses, resp.response) + } else { + nxdomainResponses = append(nxdomainResponses, resp.response) + } + + // 快速返回逻辑:找到第一个有效响应或更快的响应 + if resp.response.Rcode == dns.RcodeSuccess { + // 优先选择带有DNSSEC的响应 + if containsDNSSEC { + // 如果这是第一个DNSSEC响应,或者比当前最快的DNSSEC响应更快 + if !fastestHasDnssec || resp.rtt < fastestRtt { + fastestResponse = resp.response + fastestRtt = resp.rtt + fastestServer = resp.server + fastestDnssecServer = dnssecServerForResponse + fastestHasDnssec = true + + // 只对将要返回的响应进行DNSSEC验证 + if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC { + // 验证DNSSEC记录 + signatureValid := s.verifyDNSSEC(fastestResponse) + + // 设置AD标志(Authenticated Data) + fastestResponse.AuthenticatedData = signatureValid + + if signatureValid { + // 更新DNSSEC验证成功计数 + s.updateStats(func(stats *Stats) { + stats.DNSSECQueries++ + stats.DNSSECSuccess++ + }) + } else { + // 更新DNSSEC验证失败计数 + s.updateStats(func(stats *Stats) { + stats.DNSSECQueries++ + stats.DNSSECFailed++ + }) + } + } + + // 发送结果,快速返回 + resultChan <- struct { + response *dns.Msg + rtt time.Duration + usedServer string + usedDnssecServer string + }{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer} + } + } else { + // 非DNSSEC响应,只有在还没有找到DNSSEC响应且当前响应更快时才更新 + if !fastestHasDnssec && resp.rtt < fastestRtt { + fastestResponse = resp.response + fastestRtt = resp.rtt + fastestServer = resp.server + fastestDnssecServer = dnssecServerForResponse + + // 检查是否包含DNSSEC记录 + respContainsDNSSEC := s.hasDNSSECRecords(fastestResponse) + + // 只对将要返回的响应进行DNSSEC验证 + if s.config.EnableDNSSEC && respContainsDNSSEC && !noDNSSEC { + // 验证DNSSEC记录 + signatureValid := s.verifyDNSSEC(fastestResponse) + + // 设置AD标志(Authenticated Data) + fastestResponse.AuthenticatedData = signatureValid + + if signatureValid { + // 更新DNSSEC验证成功计数 + s.updateStats(func(stats *Stats) { + stats.DNSSECQueries++ + stats.DNSSECSuccess++ + }) + } else { + // 更新DNSSEC验证失败计数 + s.updateStats(func(stats *Stats) { + stats.DNSSECQueries++ + stats.DNSSECFailed++ + }) + } + } + + // 发送结果,快速返回 + resultChan <- struct { + response *dns.Msg + rtt time.Duration + usedServer string + usedDnssecServer string + }{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer} + } + } + } else if resp.response.Rcode == dns.RcodeNameError { + // NXDOMAIN响应,只有在还没有找到响应或当前响应更快时才更新 + if !fastestHasDnssec && resp.rtt < fastestRtt { + fastestResponse = resp.response + fastestRtt = resp.rtt + fastestServer = resp.server + fastestDnssecServer = dnssecServerForResponse + + // 检查是否包含DNSSEC记录 + respContainsDNSSEC := s.hasDNSSECRecords(fastestResponse) + + // 只对将要返回的响应进行DNSSEC验证 + if s.config.EnableDNSSEC && respContainsDNSSEC && !noDNSSEC { + // 验证DNSSEC记录 + signatureValid := s.verifyDNSSEC(fastestResponse) + + // 设置AD标志(Authenticated Data) + fastestResponse.AuthenticatedData = signatureValid + + if signatureValid { + // 更新DNSSEC验证成功计数 + s.updateStats(func(stats *Stats) { + stats.DNSSECQueries++ + stats.DNSSECSuccess++ + }) + } else { + // 更新DNSSEC验证失败计数 + s.updateStats(func(stats *Stats) { + stats.DNSSECQueries++ + stats.DNSSECFailed++ + }) + } + } + + // 发送结果,快速返回 + resultChan <- struct { + response *dns.Msg + rtt time.Duration + usedServer string + usedDnssecServer string + }{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer} + } + } + } else { + // 更新备选响应,确保总有一个可用的响应 + if resp.response != nil { + if !hasBackup { + // 第一次保存备选响应 + backupResponse = resp.response + backupRtt = resp.rtt + hasBackup = true + } + } + } + } else { + // 更新服务器统计信息(失败) + s.updateServerStats(resp.server, false, 0) + } + case <-timer.C: + // 超时,停止等待更多响应 + goto doneProcessing + } + } + + doneProcessing: + // 如果还没有发送结果,发送最快的响应 + if fastestResponse != nil { + resultChan <- struct { + response *dns.Msg + rtt time.Duration + usedServer string + usedDnssecServer string + }{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer} + } + close(resultChan) + }() + + // 等待所有请求完成(不阻塞主流程) + go func() { + wg.Wait() + close(responses) + }() + + // 等待结果或超时 + select { + case result := <-resultChan: + // 快速返回结果 + bestResponse = result.response + bestRtt = result.rtt + usedDNSServer = result.usedServer + usedDNSSECServer = result.usedDnssecServer + hasBestResponse = true + hasDNSSECResponse = s.hasDNSSECRecords(result.response) + logger.Debug("快速返回DNS响应", "domain", domain, "server", result.usedServer, "rtt", result.rtt, "dnssec", hasDNSSECResponse) + case <-time.After(defaultTimeout): + // 超时,使用备选响应 + logger.Debug("并行请求超时", "domain", domain, "timeout", defaultTimeout) + } + } + + // 2. 当启用DNSSEC且没有找到带DNSSEC的响应时,向DNSSEC专用服务器发送请求 + // 但如果域名匹配了domainSpecificDNS配置或NoDNSSECDomains,则不使用DNSSEC专用服务器,只使用指定的DNS服务器 + if s.config.EnableDNSSEC && !hasDNSSECResponse && !domainMatched && !noDNSSEC { + logger.Debug("向DNSSEC专用服务器发送请求", "domain", domain) + + // 增加DNSSEC查询计数 + s.updateStats(func(stats *Stats) { + stats.DNSSECQueries++ + }) + + // 无论查询模式是什么,DNSSEC验证都只使用加权随机选择一个服务器 + selectedDnssecServer := s.selectWeightedRandomServer(dnssecServers) + if selectedDnssecServer != "" { + // 使用带超时的方式执行Exchange + resultChan := make(chan struct { + response *dns.Msg + rtt time.Duration + err error + }, 1) + + go func() { + // 创建带有超时的resolver + client := &dns.Client{ + Net: s.resolver.Net, + UDPSize: s.resolver.UDPSize, + Timeout: defaultTimeout, + } + response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(selectedDnssecServer)) + resultChan <- struct { + response *dns.Msg + rtt time.Duration + err error + }{response, rtt, err} + }() + + var response *dns.Msg + var rtt time.Duration + var err error + + // 使用超时获取结果 + select { + case result := <-resultChan: + response, rtt, err = result.response, result.rtt, result.err + case <-time.After(defaultTimeout): + // 超时,不再等待 + logger.Debug("DNSSEC专用服务器请求超时", "domain", domain, "server", selectedDnssecServer, "timeout", defaultTimeout) + return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer + } + + if err == nil && response != nil { + // 更新服务器统计信息 + s.updateServerStats(selectedDnssecServer, true, rtt) + + // 检查是否包含DNSSEC记录 + containsDNSSEC := s.hasDNSSECRecords(response) + + if response.Rcode == dns.RcodeSuccess { + // 无论响应是否包含DNSSEC记录,只要使用了DNSSEC专用服务器,就设置usedDNSSECServer + usedDNSSECServer = selectedDnssecServer + + // 验证DNSSEC记录 + signatureValid := s.verifyDNSSEC(response) + + // 设置AD标志(Authenticated Data) + response.AuthenticatedData = signatureValid + + if signatureValid { + // 更新DNSSEC验证成功计数 + s.updateStats(func(stats *Stats) { + stats.DNSSECSuccess++ + }) + } else { + // 更新DNSSEC验证失败计数 + s.updateStats(func(stats *Stats) { + stats.DNSSECFailed++ + }) + } + + // 优先使用DNSSEC专用服务器的响应,尤其是带有DNSSEC记录的 + if containsDNSSEC { + bestResponse = response + bestRtt = rtt + hasBestResponse = true + hasDNSSECResponse = true + logger.Debug("DNSSEC专用服务器返回带DNSSEC的响应,优先使用", "domain", domain, "server", selectedDnssecServer, "rtt", rtt) + } + // 注意:如果DNSSEC专用服务器返回的响应不包含DNSSEC记录, + // 我们不会覆盖之前从upstreamDNS获取的响应, + // 这符合"本地解析指的是直接使用上游服务器upstreamDNS进行解析, 而不是dnssecUpstreamDNS"的要求 + + // 更新备选响应 + if !hasBackup { + backupResponse = response + backupRtt = rtt + hasBackup = true + } + } + } else { + // 更新服务器统计信息(失败) + s.updateServerStats(selectedDnssecServer, false, 0) + } + } + } + + // 3. 返回最佳响应 + if hasBestResponse { + // 检查最佳响应是否包含DNSSEC记录 + bestHasDNSSEC := s.hasDNSSECRecords(bestResponse) + + // 如果启用了DNSSEC且最佳响应不包含DNSSEC记录,尝试使用本地解析(使用upstreamDNS服务器) + // 但如果域名匹配了domainSpecificDNS配置,则不执行此逻辑,只使用指定的DNS服务器 + if s.config.EnableDNSSEC && !bestHasDNSSEC && !domainMatched { + logger.Debug("最佳响应不包含DNSSEC记录,尝试使用本地解析(upstreamDNS)", "domain", domain) + // 选择一个upstreamDNS服务器进行解析(使用加权随机算法) + localServer := s.selectWeightedRandomServer(s.config.UpstreamDNS) + if localServer != "" { + // 使用带超时的方式执行Exchange + resultChan := make(chan struct { + response *dns.Msg + rtt time.Duration + err error + }, 1) + + go func() { + // 创建临时的 resolver,设置超时时间 + tempResolver := &dns.Client{ + Net: s.resolver.Net, + UDPSize: s.resolver.UDPSize, + Timeout: defaultTimeout, // 使用配置的超时时间 + } + resp, rtt, e := tempResolver.Exchange(r, normalizeDNSServerAddress(localServer)) + resultChan <- struct { + response *dns.Msg + rtt time.Duration + err error + }{resp, rtt, e} + }() + + var localResponse *dns.Msg + var rtt time.Duration + var err error + + // 使用超时获取结果 + select { + case result := <-resultChan: + localResponse, rtt, err = result.response, result.rtt, result.err + case <-time.After(defaultTimeout): + // 超时 + logger.Debug("本地解析超时", "domain", domain, "server", localServer, "timeout", defaultTimeout) + // 超时后跳过本地解析 + localResponse = nil + err = fmt.Errorf("timeout") + } + + if err == nil && localResponse != nil { + // 更新服务器统计信息 + s.updateServerStats(localServer, true, rtt) + + // 检查是否包含DNSSEC记录 + localHasDNSSEC := s.hasDNSSECRecords(localResponse) + + // 验证DNSSEC记录(如果存在),但不影响最终响应 + if localHasDNSSEC { + signatureValid := s.verifyDNSSEC(localResponse) + localResponse.AuthenticatedData = signatureValid + + if signatureValid { + s.updateStats(func(stats *Stats) { + stats.DNSSECQueries++ + stats.DNSSECSuccess++ + }) + } else { + s.updateStats(func(stats *Stats) { + stats.DNSSECQueries++ + stats.DNSSECFailed++ + }) + } + } + + // 记录解析域名统计 + s.updateResolvedDomainStats(domain) + + // 更新域名的DNSSEC状态 + s.updateDomainDNSSECStatus(domain, localHasDNSSEC) + + s.updateStats(func(stats *Stats) { + stats.Allowed++ + }) + + logger.Debug("使用本地解析结果(upstreamDNS)", "domain", domain, "server", localServer, "rtt", rtt) + return localResponse, rtt, localServer, "" + } else { + // 更新服务器统计信息(失败) + s.updateServerStats(localServer, false, 0) + } + } + } + + // 记录解析域名统计 + s.updateResolvedDomainStats(domain) + + // 更新域名的DNSSEC状态 + if bestHasDNSSEC { + s.updateDomainDNSSECStatus(domain, true) + } else { + s.updateDomainDNSSECStatus(domain, false) + } + + // 检查响应是否包含CNAME记录,需要确保返回完整的解析链 + if bestResponse != nil && bestResponse.Rcode == dns.RcodeSuccess { + // 处理多级CNAME,直到获取到最终的A/AAAA记录 + maxCNAMELevels := 5 // 限制最大CNAME解析级数,防止循环解析 + currentLevel := 0 + + // 循环处理CNAME记录 + for currentLevel < maxCNAMELevels { + // 检查是否包含CNAME记录 + var hasCNAME bool + var cnameTarget string + + // 检查Answer部分,查找CNAME记录 + for _, rr := range bestResponse.Answer { + if cname, ok := rr.(*dns.CNAME); ok { + hasCNAME = true + cnameTarget = cname.Target + } + } + + // 如果不包含CNAME记录,或者已经包含最终的A/AAAA记录,退出循环 + var hasFinalRecord bool + for _, rr := range bestResponse.Answer { + switch rr.Header().Rrtype { + case dns.TypeA, dns.TypeAAAA: + hasFinalRecord = true + break + } + } + + if !hasCNAME || hasFinalRecord { + break // 没有CNAME记录,或者已经有最终记录,退出循环 + } + + // 如果包含CNAME记录但没有最终IP,继续查询 + logger.Debug("响应包含CNAME但没有最终IP,继续查询", "domain", domain, "cname", cnameTarget, "level", currentLevel) + + // 创建新的查询请求,查询CNAME指向的域名 + cnameQuery := r.Copy() + cnameQuery.Question[0].Name = cnameTarget + + // 继续查询CNAME指向的域名 + cnameResponse, _, cnameDnsServer, cnameDnssecServer := s.forwardDNSRequestWithCache(cnameQuery, cnameTarget) + if cnameResponse != nil && cnameResponse.Rcode == dns.RcodeSuccess { + // 合并CNAME响应的Answer部分到主响应 + bestResponse.Answer = append(bestResponse.Answer, cnameResponse.Answer...) + // 合并CNAME响应的Ns部分到主响应 + bestResponse.Ns = append(bestResponse.Ns, cnameResponse.Ns...) + // 合并CNAME响应的Extra部分到主响应,排除OPT记录 + for _, rr := range cnameResponse.Extra { + if rr.Header().Rrtype != dns.TypeOPT { + bestResponse.Extra = append(bestResponse.Extra, rr) + } + } + // 更新使用的DNS服务器信息 + if cnameDnsServer != "" { + usedDNSServer = cnameDnsServer + } + if cnameDnssecServer != "" { + usedDNSSECServer = cnameDnssecServer + } + } else { + // 查询失败,退出循环 + break + } + + // 增加CNAME解析级数 + currentLevel++ + } + + if currentLevel >= maxCNAMELevels { + logger.Warn("CNAME解析级数超过限制,可能存在循环解析", "domain", domain, "maxLevels", maxCNAMELevels) + } + } + + s.updateStats(func(stats *Stats) { + stats.Allowed++ + }) + return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer + } + + // 如果有备选响应,返回该响应 + if hasBackup { + logger.Debug("使用备选响应,没有找到更好的结果", "domain", domain) + // 记录解析域名统计 + s.updateResolvedDomainStats(domain) + // 更新统计信息 + s.updateStats(func(stats *Stats) { + stats.Allowed++ + }) + return backupResponse, backupRtt, "", "" + } + + // 所有上游服务器都失败,返回服务器失败错误 + response := new(dns.Msg) + response.SetReply(r) + + response.SetRcode(r, dns.RcodeServerFailure) + + logger.Error("DNS查询失败", "domain", domain) + s.updateStats(func(stats *Stats) { + stats.Errors++ + }) + return response, 0, "", "" +} + +// forwardDNSRequest 转发DNS请求到上游服务器 +func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) { + response, _, _, _ := s.forwardDNSRequestWithCache(r, domain) + w.WriteMsg(response) +} + +// updateBlockedDomainStats 更新被屏蔽域名统计 +func (s *Server) updateBlockedDomainStats(domain string) { + // 先尝试读锁,检查条目是否存在 + s.blockedDomainsMutex.RLock() + entry, exists := s.blockedDomains[domain] + s.blockedDomainsMutex.RUnlock() + + if exists { + // 使用原子操作更新计数和时间戳 + atomic.AddInt64(&entry.Count, 1) + atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano()) + } else { + // 获取写锁,创建新条目 + s.blockedDomainsMutex.Lock() + // 再次检查,避免竞态条件 + if entry, exists := s.blockedDomains[domain]; exists { + atomic.AddInt64(&entry.Count, 1) + atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano()) + } else { + s.blockedDomains[domain] = &BlockedDomain{ + Domain: domain, + Count: 1, + LastSeen: time.Now().UnixNano(), + DNSSEC: false, + } + } + s.blockedDomainsMutex.Unlock() + } + + // 更新统计数据 + now := time.Now() + + // 更新小时统计 + hourKey := now.Format("2006-01-02-15") + s.hourlyStatsMutex.Lock() + s.hourlyStats[hourKey]++ + s.hourlyStatsMutex.Unlock() + + // 更新每日统计 + dayKey := now.Format("2006-01-02") + s.dailyStatsMutex.Lock() + s.dailyStats[dayKey]++ + s.dailyStatsMutex.Unlock() + + // 更新每月统计 + monthKey := now.Format("2006-01") + s.monthlyStatsMutex.Lock() + s.monthlyStats[monthKey]++ + s.monthlyStatsMutex.Unlock() +} + +// updateClientStats 更新客户端统计 +func (s *Server) updateClientStats(ip string) { + // 先尝试读锁,检查条目是否存在 + s.clientStatsMutex.RLock() + entry, exists := s.clientStats[ip] + s.clientStatsMutex.RUnlock() + + if exists { + // 使用原子操作更新计数和时间戳 + atomic.AddInt64(&entry.Count, 1) + atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano()) + } else { + // 获取写锁,创建新条目 + s.clientStatsMutex.Lock() + // 再次检查,避免竞态条件 + if entry, exists := s.clientStats[ip]; exists { + atomic.AddInt64(&entry.Count, 1) + atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano()) + } else { + s.clientStats[ip] = &ClientStats{ + IP: ip, + Count: 1, + LastSeen: time.Now().UnixNano(), + } + } + s.clientStatsMutex.Unlock() + } +} + +// hasDNSSECRecords 检查响应是否包含DNSSEC记录 +func (s *Server) hasDNSSECRecords(response *dns.Msg) bool { + // 直接调用包内的hasDNSSECRecords函数,避免重复代码 + return hasDNSSECRecords(response) +} + +// verifyDNSSEC 验证DNSSEC签名 +func (s *Server) verifyDNSSEC(response *dns.Msg) bool { + // 提取DNSKEY和RRSIG记录,并按类型和名称组织记录 + dnskeys := make(map[uint16]*dns.DNSKEY) // KeyTag -> DNSKEY + rrsigs := make([]*dns.RRSIG, 0) + // 按 (名称, 类型) 组织记录集,用于快速查找 + rrSets := make(map[string]map[uint16][]dns.RR) // name -> type -> records + + // 定义处理单个记录的函数 + processRecord := func(rr dns.RR) { + num := rr.Header().Rrtype + name := rr.Header().Name + + // 组织记录集 + if _, exists := rrSets[name]; !exists { + rrSets[name] = make(map[uint16][]dns.RR) + } + if _, exists := rrSets[name][num]; !exists { + rrSets[name][num] = make([]dns.RR, 0) + } + rrSets[name][num] = append(rrSets[name][num], rr) + + // 特别处理DNSKEY和RRSIG + if dnskey, ok := rr.(*dns.DNSKEY); ok { + tag := dnskey.KeyTag() + dnskeys[tag] = dnskey + } else if rrsig, ok := rr.(*dns.RRSIG); ok { + rrsigs = append(rrsigs, rrsig) + } + } + + // 一次遍历所有响应部分,同时完成记录收集和组织 + for _, rr := range response.Answer { + processRecord(rr) + } + for _, rr := range response.Ns { + processRecord(rr) + } + for _, rr := range response.Extra { + processRecord(rr) + } + + // 如果没有RRSIG记录,验证失败 + if len(rrsigs) == 0 { + return false + } + + // 验证所有RRSIG记录 + signatureValid := true + // 用于记录已经警告过的DNSKEY tag,避免重复警告 + warnedKeyTags := make(map[uint16]bool) + for _, rrsig := range rrsigs { + // 查找对应的DNSKEY + dnskey, exists := dnskeys[rrsig.KeyTag] + if !exists { + // 仅当该key_tag尚未警告过时才记录警告 + if !warnedKeyTags[rrsig.KeyTag] { + logger.Warn("DNSSEC验证失败:找不到对应的DNSKEY", "key_tag", rrsig.KeyTag) + warnedKeyTags[rrsig.KeyTag] = true + } + signatureValid = false + continue + } + + // 快速查找需要验证的记录集 + name := rrsig.Header().Name + typeCovered := rrsig.TypeCovered + rrset := rrSets[name][typeCovered] + + // 验证签名 + if len(rrset) > 0 { + err := rrsig.Verify(dnskey, rrset) + if err != nil { + logger.Warn("DNSSEC签名验证失败", "error", err, "key_tag", rrsig.KeyTag) + signatureValid = false + } else { + logger.Debug("DNSSEC签名验证成功", "key_tag", rrsig.KeyTag) + } + } + } + + return signatureValid +} + +// updateDomainDNSSECStatus 更新域名的DNSSEC状态 +func (s *Server) updateDomainDNSSECStatus(domain string, dnssec bool) { + // 确保域名是小写 + domain = strings.ToLower(domain) + + // 更新域名的DNSSEC状态 + s.resolvedDomainsMutex.Lock() + defer s.resolvedDomainsMutex.Unlock() + + // 更新resolvedDomains中的DNSSEC状态 + if entry, exists := s.resolvedDomains[domain]; exists { + entry.DNSSEC = dnssec + } else { + s.resolvedDomains[domain] = &BlockedDomain{ + Domain: domain, + Count: 1, + LastSeen: time.Now().UnixNano(), + DNSSEC: dnssec, + } + } + + // 更新domainDNSSECStatus映射(使用单独的锁) + s.domainDNSSECStatusMutex.Lock() + s.domainDNSSECStatus[domain] = dnssec + s.domainDNSSECStatusMutex.Unlock() +} + +// updateResolvedDomainStats 更新解析域名统计 +func (s *Server) updateResolvedDomainStats(domain string) { + // 先尝试读锁,检查条目是否存在 + s.resolvedDomainsMutex.RLock() + entry, exists := s.resolvedDomains[domain] + s.resolvedDomainsMutex.RUnlock() + + if exists { + // 使用原子操作更新计数和时间戳 + atomic.AddInt64(&entry.Count, 1) + atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano()) + } else { + // 获取写锁,创建新条目 + s.resolvedDomainsMutex.Lock() + // 再次检查,避免竞态条件 + if entry, exists := s.resolvedDomains[domain]; exists { + atomic.AddInt64(&entry.Count, 1) + atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano()) + } else { + s.resolvedDomains[domain] = &BlockedDomain{ + Domain: domain, + Count: 1, + LastSeen: time.Now().UnixNano(), + DNSSEC: false, + } + } + s.resolvedDomainsMutex.Unlock() + } +} + +// getServerStats 获取服务器统计信息,如果不存在则创建 +func (s *Server) getServerStats(server string) *ServerStats { + s.serverStatsMutex.RLock() + stats, exists := s.serverStats[server] + s.serverStatsMutex.RUnlock() + + if exists { + return stats + } + + s.serverStatsMutex.Lock() + defer s.serverStatsMutex.Unlock() + + if stats, exists := s.serverStats[server]; exists { + return stats + } + + stats = &ServerStats{ + SuccessCount: 0, + FailureCount: 0, + LastResponse: time.Now(), + ResponseTime: 0, + ConnectionSpeed: 0, + } + + s.serverStats[server] = stats + return stats +} + +// updateServerStats 更新服务器统计信息 +func (s *Server) updateServerStats(server string, success bool, rtt time.Duration) { + stats := s.getServerStats(server) + + // 使用原子操作更新成功和失败计数 + if success { + successCount := atomic.AddInt64(&stats.SuccessCount, 1) + + // 只在需要更新平均响应时间时获取锁 + s.serverStatsMutex.Lock() + stats.LastResponse = time.Now() + + // 更新平均响应时间(简单移动平均) + if successCount == 1 { + // 第一次成功,直接使用当前响应时间 + stats.ResponseTime = rtt + } else { + // 使用纳秒进行计算以避免类型不匹配 + prevTotal := stats.ResponseTime.Nanoseconds() * (successCount - 1) + newTotal := prevTotal + rtt.Nanoseconds() + stats.ResponseTime = time.Duration(newTotal / successCount) + } + s.serverStatsMutex.Unlock() + } else { + atomic.AddInt64(&stats.FailureCount, 1) + + // 只更新LastResponse时获取锁 + s.serverStatsMutex.Lock() + stats.LastResponse = time.Now() + s.serverStatsMutex.Unlock() + } +} + +// selectWeightedRandomServer 加权随机选择服务器 +func (s *Server) selectWeightedRandomServer(servers []string) string { + if len(servers) == 0 { + return "" + } + + if len(servers) == 1 { + return servers[0] + } + + type serverWeight struct { + server string + weight int64 + responseTime time.Duration + successCount int64 + failureCount int64 + } + + var totalWeight int64 + var totalResponseTime time.Duration + var validServers int + var currentWeight int64 + + serversInfo := make([]serverWeight, len(servers)) + + for i, server := range servers { + stats := s.getServerStats(server) + + serversInfo[i] = serverWeight{ + server: server, + responseTime: stats.ResponseTime, + successCount: atomic.LoadInt64(&stats.SuccessCount), + failureCount: atomic.LoadInt64(&stats.FailureCount), + } + + if stats.ResponseTime > 0 { + totalResponseTime += stats.ResponseTime + validServers++ + } + } + + var avgResponseTime time.Duration + if validServers > 0 { + avgResponseTime = totalResponseTime / time.Duration(validServers) + } else { + avgResponseTime = 1 * time.Second + } + + var randomGen = rand.New(rand.NewSource(time.Now().UnixNano())) + + for i := range serversInfo { + baseWeight := serversInfo[i].successCount - serversInfo[i].failureCount*2 + if baseWeight < 1 { + baseWeight = 1 + } + + var responseFactor int64 = 100 + if serversInfo[i].responseTime > 0 { + if serversInfo[i].responseTime < avgResponseTime { + factor := (avgResponseTime.Nanoseconds() * 200) / serversInfo[i].responseTime.Nanoseconds() + if factor > 200 { + factor = 200 + } + responseFactor = factor + } else { + factor := (avgResponseTime.Nanoseconds() * 200) / serversInfo[i].responseTime.Nanoseconds() + if factor < 50 { + factor = 50 + } + responseFactor = factor + } + } + + finalWeight := (baseWeight * responseFactor) / 100 + if finalWeight < 1 { + finalWeight = 1 + } + + serversInfo[i].weight = finalWeight + totalWeight += finalWeight + } + + random := randomGen.Int63n(totalWeight) + + for _, sw := range serversInfo { + currentWeight += sw.weight + if random < currentWeight { + return sw.server + } + } + + // 兜底返回第一个服务器 + return servers[0] +} + +// measureServerSpeed 测量服务器TCP连接速度 +func (s *Server) measureServerSpeed(server string) time.Duration { + addr := server + if !strings.Contains(server, ":") { + addr = server + ":53" + } + + startTime := time.Now() + conn, err := net.DialTimeout("tcp", addr, 2*time.Second) + if err != nil { + return 2 * time.Second + } + defer conn.Close() + + connTime := time.Since(startTime) + + stats := s.getServerStats(server) + s.serverStatsMutex.Lock() + stats.ConnectionSpeed = (stats.ConnectionSpeed*3 + connTime) / 4 + s.serverStatsMutex.Unlock() + + return connTime +} + +// selectFastestServer 选择连接速度最快的服务器 +func (s *Server) selectFastestServer(servers []string) string { + if len(servers) == 0 { + return "" + } + + if len(servers) == 1 { + return servers[0] + } + + // 并行测量所有服务器的速度 + type speedResult struct { + server string + speed time.Duration + } + + results := make(chan speedResult, len(servers)) + var wg sync.WaitGroup + + for _, server := range servers { + wg.Add(1) + go func(srv string) { + defer wg.Done() + speed := s.measureServerSpeed(srv) + results <- speedResult{srv, speed} + }(server) + } + + // 等待所有测量完成 + go func() { + wg.Wait() + close(results) + }() + + // 找出最快的服务器 + var fastestServer string + var fastestSpeed time.Duration = 2 * time.Second + + for result := range results { + if result.speed < fastestSpeed { + fastestSpeed = result.speed + fastestServer = result.server + } + } + + // 如果没有找到最快服务器(理论上不会发生),返回第一个服务器 + if fastestServer == "" { + fastestServer = servers[0] + } + + return fastestServer +} + +// calculateAvgResponseTime 计算平均响应时间 +func calculateAvgResponseTime(totalResponseTime int64, queries int64) float64 { + if queries <= 0 { + return 0 + } + + avg := float64(totalResponseTime) / float64(queries) + avg = float64(math.Round(avg)) + + // 限制平均响应时间的范围 + if avg > 60000 { + avg = 60000 + } + + return avg +} + +// updateStats 更新统计信息 +func (s *Server) updateStats(update func(*Stats)) { + s.statsMutex.Lock() + defer s.statsMutex.Unlock() + update(s.stats) +} + +// addQueryLog 添加查询日志 +func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string, fromCache, dnssec, edns bool, dnsServer, dnssecServer string, answers []DNSAnswer, responseCode int) { + // 创建日志记录 + queryLog := QueryLog{ + Timestamp: time.Now(), + ClientIP: clientIP, + Domain: domain, + QueryType: queryType, + ResponseTime: responseTime, + Result: result, + BlockRule: blockRule, + BlockType: blockType, + FromCache: fromCache, + DNSSEC: dnssec, + EDNS: edns, + DNSServer: dnsServer, + DNSSECServer: dnssecServer, + Answers: answers, + ResponseCode: responseCode, + } + + // 使用新日志系统记录(如果已初始化) + if s.logManager != nil { + // 将 Answers 转换为 JSON 字符串 + answersJSON := "" + if len(queryLog.Answers) > 0 { + data, err := json.Marshal(queryLog.Answers) + if err == nil { + answersJSON = string(data) + } + } + + newLog := log.QueryLog{ + Timestamp: queryLog.Timestamp, + ClientIP: queryLog.ClientIP, + Domain: queryLog.Domain, + QueryType: queryLog.QueryType, + ResponseTime: queryLog.ResponseTime, + Result: queryLog.Result, + BlockRule: queryLog.BlockRule, + BlockType: queryLog.BlockType, + FromCache: queryLog.FromCache, + DNSSEC: queryLog.DNSSEC, + EDNS: queryLog.EDNS, + DNSServer: queryLog.DNSServer, + DNSSECServer: queryLog.DNSSECServer, + Answers: answersJSON, + ResponseCode: queryLog.ResponseCode, + } + err := s.logManager.Log(newLog) + if err != nil { + logger.Error("新日志系统记录失败", "domain", queryLog.Domain, "error", err) + } + } + + // 同时使用旧日志系统记录(兼容性) + // 发送到日志处理通道(阻塞式,确保日志不会丢失) +} + +// GetStartTime 获取服务器启动时间 + +// GetStats 获取DNS服务器统计信息 +func (s *Server) GetStats() *Stats { + s.statsMutex.Lock() + defer s.statsMutex.Unlock() + + // 复制查询类型统计 + queryTypesCopy := make(map[string]int64) + for k, v := range s.stats.QueryTypes { + queryTypesCopy[k] = v + } + + // 复制来源IP统计 + sourceIPsCopy := make(map[string]bool) + for ip := range s.stats.SourceIPs { + sourceIPsCopy[ip] = true + } + + // 返回统计信息的副本 + return &Stats{ + Queries: s.stats.Queries, + Blocked: s.stats.Blocked, + Allowed: s.stats.Allowed, + Errors: s.stats.Errors, + LastQuery: s.stats.LastQuery, + AvgResponseTime: s.stats.AvgResponseTime, + TotalResponseTime: s.stats.TotalResponseTime, + QueryTypes: queryTypesCopy, + SourceIPs: sourceIPsCopy, + CpuUsage: s.stats.CpuUsage, + DNSSECQueries: s.stats.DNSSECQueries, + DNSSECSuccess: s.stats.DNSSECSuccess, + DNSSECFailed: s.stats.DNSSECFailed, + DNSSECEnabled: s.stats.DNSSECEnabled, + } +} + +// GetQueryLogs 获取查询日志 +func (s *Server) GetQueryLogs(limit, offset int, sortField, sortDirection, resultFilter, searchTerm, queryType string) []QueryLog { + // 优先使用新日志系统查询(如果已初始化) + if s.logManager != nil { + logger.Debug("使用新日志系统查询", "filter", resultFilter, "search", searchTerm, "queryType", queryType) + + // 转换排序字段名称以匹配数据库字段 + dbSortField := sortField + if sortField == "time" { + dbSortField = "timestamp" + } else if sortField == "clientIp" { + dbSortField = "client_ip" + } else if sortField == "responseTime" { + dbSortField = "response_time" + } else if sortField == "blockRule" { + dbSortField = "block_rule" + } + + filter := log.LogFilter{ + Result: resultFilter, + SearchTerm: searchTerm, + QueryType: queryType, + } + + page := log.PageParams{ + Limit: limit, + Offset: offset, + SortField: dbSortField, + SortDirection: sortDirection, + } + + logs, total, err := s.logManager.QueryLogs(filter, page) + logger.Debug("新日志系统查询结果", "logs", len(logs), "total", total, "error", err) + + if err == nil && len(logs) > 0 { + // 将新日志格式转换为旧格式 + result := make([]QueryLog, 0, len(logs)) + for _, newLog := range logs { + // 将 JSON 字符串解析为 DNSAnswer 数组 + var answers []DNSAnswer + if newLog.Answers != "" { + json.Unmarshal([]byte(newLog.Answers), &answers) + } + + oldLog := QueryLog{ + Timestamp: newLog.Timestamp, + ClientIP: newLog.ClientIP, + Domain: newLog.Domain, + QueryType: newLog.QueryType, + ResponseTime: newLog.ResponseTime, + Result: newLog.Result, + BlockRule: newLog.BlockRule, + BlockType: newLog.BlockType, + FromCache: newLog.FromCache, + DNSSEC: newLog.DNSSEC, + EDNS: newLog.EDNS, + DNSServer: newLog.DNSServer, + DNSSECServer: newLog.DNSSECServer, + Answers: answers, + ResponseCode: newLog.ResponseCode, + } + result = append(result, oldLog) + } + return result + } + // 如果新系统查询失败或没有数据,返回空列表 + logger.Debug("新日志系统查询失败或无数据", "error", err, "logs", len(logs)) + } + + // 返回空列表 + return []QueryLog{} +} + +// GetQueryLogsCount 获取查询日志总数 +func (s *Server) GetQueryLogsCount() int { + // 使用新日志系统获取总数 + if s.logManager != nil { + stats, err := s.logManager.GetStats(log.TimeRange{}) + if err == nil { + return int(stats.TotalQueries) + } + } + return 0 +} + +// GetQueryLogsCountWithFilter 获取带过滤条件的查询日志总数 +func (s *Server) GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType string) int { + // 优先使用新日志系统查询(如果已初始化) + if s.logManager != nil { + filter := log.LogFilter{ + Result: resultFilter, + SearchTerm: searchTerm, + QueryType: queryType, + } + + page := log.PageParams{ + Limit: 1, // 只需要总数,获取 1 条数据即可 + Offset: 0, + SortField: "timestamp", + SortDirection: "desc", + } + + _, total, err := s.logManager.QueryLogs(filter, page) + if err == nil { + return int(total) + } + // 如果新系统查询失败,返回 0 + logger.Debug("新日志系统查询失败", "error", err) + } + return 0 +} +func (s *Server) GetQueryStats() map[string]interface{} { + s.statsMutex.Lock() + defer s.statsMutex.Unlock() + + // 计算统计数据 + return map[string]interface{}{ + "totalQueries": s.stats.Queries, + "blockedQueries": s.stats.Blocked, + "allowedQueries": s.stats.Allowed, + "errorQueries": s.stats.Errors, + "avgResponseTime": s.stats.AvgResponseTime, + "activeIPs": len(s.stats.SourceIPs), + } +} + +// GetTopBlockedDomains 获取TOP屏蔽域名列表 +func (s *Server) GetTopBlockedDomains(limit int) []BlockedDomain { + s.blockedDomainsMutex.RLock() + defer s.blockedDomainsMutex.RUnlock() + + // 计算30天前的时间戳 + thirtyDaysAgo := time.Now().Add(-30 * 24 * time.Hour).Unix() + + // 转换为切片并过滤最近30天的数据 + domains := make([]BlockedDomain, 0, len(s.blockedDomains)) + for _, entry := range s.blockedDomains { + // 只包含最近30天的数据 + if entry.LastSeen >= thirtyDaysAgo { + domains = append(domains, *entry) + } + } + + // 按计数排序 + sort.Slice(domains, func(i, j int) bool { + return domains[i].Count > domains[j].Count + }) + + // 返回限制数量 + if len(domains) > limit { + return domains[:limit] + } + return domains +} + +// GetTopResolvedDomains 获取TOP解析域名 +func (s *Server) GetTopResolvedDomains(limit int) []BlockedDomain { + s.resolvedDomainsMutex.RLock() + defer s.resolvedDomainsMutex.RUnlock() + + // 计算30天前的时间戳 + thirtyDaysAgo := time.Now().Add(-30 * 24 * time.Hour).Unix() + + // 转换为切片并过滤最近30天的数据 + domains := make([]BlockedDomain, 0, len(s.resolvedDomains)) + for _, entry := range s.resolvedDomains { + // 只包含最近30天的数据 + if entry.LastSeen >= thirtyDaysAgo { + domains = append(domains, *entry) + } + } + + // 按数量排序 + sort.Slice(domains, func(i, j int) bool { + return domains[i].Count > domains[j].Count + }) + + // 返回限制数量 + if len(domains) > limit { + return domains[:limit] + } + return domains +} + +// GetRecentBlockedDomains 获取最近屏蔽的域名列表 +func (s *Server) GetRecentBlockedDomains(limit int) []BlockedDomain { + s.blockedDomainsMutex.RLock() + defer s.blockedDomainsMutex.RUnlock() + + // 转换为切片 + domains := make([]BlockedDomain, 0, len(s.blockedDomains)) + for _, entry := range s.blockedDomains { + domains = append(domains, *entry) + } + + // 按时间排序 + sort.Slice(domains, func(i, j int) bool { + return domains[i].LastSeen > domains[j].LastSeen + }) + + // 返回限制数量 + if len(domains) > limit { + return domains[:limit] + } + return domains +} + +// GetTopClients 获取TOP客户端列表 +func (s *Server) GetTopClients(limit int) []ClientStats { + s.clientStatsMutex.RLock() + defer s.clientStatsMutex.RUnlock() + + // 转换为切片 + clients := make([]ClientStats, 0, len(s.clientStats)) + for _, entry := range s.clientStats { + clients = append(clients, *entry) + } + + // 按请求次数排序 + sort.Slice(clients, func(i, j int) bool { + return clients[i].Count > clients[j].Count + }) + + // 返回限制数量 + if len(clients) > limit { + return clients[:limit] + } + return clients +} + +// GetHourlyStats 获取每小时统计数据 +func (s *Server) GetHourlyStats() map[string]int64 { + s.hourlyStatsMutex.RLock() + defer s.hourlyStatsMutex.RUnlock() + + // 返回副本 + result := make(map[string]int64) + for k, v := range s.hourlyStats { + result[k] = v + } + return result +} + +// GetDailyStats 获取每日统计数据 +func (s *Server) GetDailyStats() map[string]int64 { + s.dailyStatsMutex.RLock() + defer s.dailyStatsMutex.RUnlock() + + // 返回副本 + result := make(map[string]int64) + for k, v := range s.dailyStats { + result[k] = v + } + return result +} + +// GetMonthlyStats 获取每月统计数据 +func (s *Server) GetMonthlyStats() map[string]int64 { + s.monthlyStatsMutex.RLock() + defer s.monthlyStatsMutex.RUnlock() + + // 返回副本 + result := make(map[string]int64) + for k, v := range s.monthlyStats { + result[k] = v + } + return result +} + +// checkThreatDetection 检查DNS查询是否存在威胁 +func (s *Server) checkThreatDetection(sourceIP, domain, queryType string) { + if s.threatEngine == nil { + return + } + + // 调用威胁检测引擎检查查询 + alerts := s.threatEngine.CheckQuery(sourceIP, domain, queryType) + + // 处理检测到的威胁 + for _, alert := range alerts { + // 添加告警到告警管理器 + s.alertManager.AddAlert(alert) + } +} + +// GetAlerts 获取告警列表 +func (s *Server) GetAlerts(limit, offset int, level string) []*threat.ThreatAlert { + if s.alertManager == nil { + return []*threat.ThreatAlert{} + } + + return s.alertManager.GetAlerts(limit, offset, level) +} + +// GetAlertCount 获取告警数量 +func (s *Server) GetAlertCount(level string) int { + if s.alertManager == nil { + return 0 + } + + return s.alertManager.GetAlertCount(level) +} + +// ResolveAlert 解决告警 +func (s *Server) ResolveAlert(alertID, action string) bool { + if s.alertManager == nil { + return false + } + + return s.alertManager.ResolveAlert(alertID, action) +} + +// GetThreatDomains 获取所有威胁域名信息 +func (s *Server) GetThreatDomains() []*threat.ThreatInfo { + if s.dbManager == nil { + return []*threat.ThreatInfo{} + } + + return s.dbManager.GetAllThreatDomains() +} + +// AddThreatDomain 添加威胁域名 +func (s *Server) AddThreatDomain(threatType, name string, riskLevel int, domain string) error { + if s.dbManager == nil { + return nil + } + + return s.dbManager.AddThreatDomain(threatType, name, riskLevel, domain) +} + +// RemoveThreatDomain 删除威胁域名 +func (s *Server) RemoveThreatDomain(domain string) error { + if s.dbManager == nil { + return nil + } + + return s.dbManager.RemoveThreatDomain(domain) +} + +// isPrivateIP 检测IP地址是否为内网IP +func isPrivateIP(ip string) bool { + // 解析IP地址 + parsedIP := net.ParseIP(ip) + if parsedIP == nil { + return false + } + + // 检查IPv4内网地址 + if ipv4 := parsedIP.To4(); ipv4 != nil { + // 10.0.0.0/8 + if ipv4[0] == 10 { + return true + } + // 172.16.0.0/12 + if ipv4[0] == 172 && (ipv4[1] >= 16 && ipv4[1] <= 31) { + return true + } + // 192.168.0.0/16 + if ipv4[0] == 192 && ipv4[1] == 168 { + return true + } + // 127.0.0.0/8 (localhost) + if ipv4[0] == 127 { + return true + } + // 169.254.0.0/16 (链路本地地址) + if ipv4[0] == 169 && ipv4[1] == 254 { + return true + } + return false + } + + // 检查IPv6内网地址 + // ::1/128 (localhost) + if parsedIP.IsLoopback() { + return true + } + // fc00::/7 (唯一本地地址) + if parsedIP[0]&0xfc == 0xfc { + return true + } + // fe80::/10 (链路本地地址) + if parsedIP[0]&0xfe == 0xfe && parsedIP[1]&0xc0 == 0x80 { + return true + } + return false +} + +// loadStatsData 从文件加载统计数据 +func (s *Server) loadStatsData() { + // 检查文件是否存在 + data, err := ioutil.ReadFile("data/stats.json") + if err != nil { + if !os.IsNotExist(err) { + logger.Error("读取统计数据文件失败", "error", err) + } + return + } + + var statsData StatsData + err = json.Unmarshal(data, &statsData) + if err != nil { + logger.Error("解析统计数据失败", "error", err) + return + } + + // 恢复统计数据 + s.statsMutex.Lock() + if statsData.Stats != nil { + // 只恢复有效数据,避免破坏统计关系 + s.stats.Queries += statsData.Stats.Queries + s.stats.Blocked += statsData.Stats.Blocked + s.stats.Allowed += statsData.Stats.Allowed + s.stats.Errors += statsData.Stats.Errors + s.stats.TotalResponseTime += statsData.Stats.TotalResponseTime + s.stats.DNSSECQueries += statsData.Stats.DNSSECQueries + s.stats.DNSSECSuccess += statsData.Stats.DNSSECSuccess + s.stats.DNSSECFailed += statsData.Stats.DNSSECFailed + + // 重新计算平均响应时间,确保一致性 + if s.stats.Queries > 0 { + s.stats.AvgResponseTime = float64(s.stats.TotalResponseTime) / float64(s.stats.Queries) + // 限制平均响应时间的范围,避免显示异常大的值 + if s.stats.AvgResponseTime > 60000 { + s.stats.AvgResponseTime = 60000 + } + } + + // 合并查询类型统计 + for k, v := range statsData.Stats.QueryTypes { + s.stats.QueryTypes[k] += v + } + + // 合并来源IP统计 + for ip := range statsData.Stats.SourceIPs { + s.stats.SourceIPs[ip] = true + } + + // 确保使用当前配置中的EnableDNSSEC值 + s.stats.DNSSECEnabled = s.config.EnableDNSSEC + } + s.statsMutex.Unlock() + + s.blockedDomainsMutex.Lock() + if statsData.BlockedDomains != nil { + s.blockedDomains = statsData.BlockedDomains + } + s.blockedDomainsMutex.Unlock() + + s.resolvedDomainsMutex.Lock() + if statsData.ResolvedDomains != nil { + s.resolvedDomains = statsData.ResolvedDomains + } + s.resolvedDomainsMutex.Unlock() + + s.hourlyStatsMutex.Lock() + if statsData.HourlyStats != nil { + s.hourlyStats = statsData.HourlyStats + } + s.hourlyStatsMutex.Unlock() + + s.dailyStatsMutex.Lock() + if statsData.DailyStats != nil { + s.dailyStats = statsData.DailyStats + } + s.dailyStatsMutex.Unlock() + + s.monthlyStatsMutex.Lock() + if statsData.MonthlyStats != nil { + s.monthlyStats = statsData.MonthlyStats + } + s.monthlyStatsMutex.Unlock() + + // 加载客户端统计数据 + s.clientStatsMutex.Lock() + if statsData.ClientStats != nil { + s.clientStats = statsData.ClientStats + } + s.clientStatsMutex.Unlock() + + logger.Info("统计数据加载成功") +} + +// processLogs 异步处理日志记录 +// saveStatsData 保存统计数据到文件 +func (s *Server) saveStatsData() { + // 获取绝对路径以避免工作目录问题 + statsFilePath, err := filepath.Abs("data/stats.json") + if err != nil { + logger.Error("获取统计文件绝对路径失败", "path", "data/stats.json", "error", err) + return + } + + // 创建数据目录 + statsDir := filepath.Dir(statsFilePath) + err = os.MkdirAll(statsDir, 0755) + if err != nil { + logger.Error("创建统计数据目录失败", "dir", statsDir, "error", err) + return + } + + // 收集所有统计数据 + statsData := &StatsData{ + Stats: s.GetStats(), + LastSaved: time.Now(), + } + + // 复制域名数据 + s.blockedDomainsMutex.RLock() + statsData.BlockedDomains = make(map[string]*BlockedDomain) + for k, v := range s.blockedDomains { + statsData.BlockedDomains[k] = v + } + s.blockedDomainsMutex.RUnlock() + + s.resolvedDomainsMutex.RLock() + statsData.ResolvedDomains = make(map[string]*BlockedDomain) + for k, v := range s.resolvedDomains { + statsData.ResolvedDomains[k] = v + } + s.resolvedDomainsMutex.RUnlock() + + s.hourlyStatsMutex.RLock() + statsData.HourlyStats = make(map[string]int64) + for k, v := range s.hourlyStats { + statsData.HourlyStats[k] = v + } + s.hourlyStatsMutex.RUnlock() + + s.dailyStatsMutex.RLock() + statsData.DailyStats = make(map[string]int64) + for k, v := range s.dailyStats { + statsData.DailyStats[k] = v + } + s.dailyStatsMutex.RUnlock() + + s.monthlyStatsMutex.RLock() + statsData.MonthlyStats = make(map[string]int64) + for k, v := range s.monthlyStats { + statsData.MonthlyStats[k] = v + } + s.monthlyStatsMutex.RUnlock() + + // 复制客户端统计数据 + s.clientStatsMutex.RLock() + statsData.ClientStats = make(map[string]*ClientStats) + for k, v := range s.clientStats { + statsData.ClientStats[k] = v + } + s.clientStatsMutex.RUnlock() + + // 序列化数据 + jsonData, err := json.MarshalIndent(statsData, "", " ") + if err != nil { + logger.Error("序列化统计数据失败", "error", err) + return + } + + // 写入文件 + err = os.WriteFile(statsFilePath, jsonData, 0644) + if err != nil { + logger.Error("保存统计数据到文件失败", "file", statsFilePath, "error", err) + return + } + + logger.Info("统计数据保存成功", "file", statsFilePath) +} + +// startCpuUsageMonitor 启动 CPU 使用率监控 +func (s *Server) startCpuUsageMonitor() { + ticker := time.NewTicker(time.Second * 5) // 每5秒更新一次CPU使用率 + defer ticker.Stop() + + // 初始化 + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + + // 存储上一次的CPU时间统计 + var prevIdle, prevTotal uint64 + + for { + select { + case <-ticker.C: + // 获取真实的系统级CPU使用率 + cpuUsage, err := getSystemCpuUsage(&prevIdle, &prevTotal) + if err != nil { + // 如果获取失败,使用默认值 + cpuUsage = 0.0 + logger.Error("获取系统CPU使用率失败", "error", err) + } + + s.updateStats(func(stats *Stats) { + stats.CpuUsage = cpuUsage + }) + case <-s.ctx.Done(): + return + } + } +} + +// getSystemCpuUsage 获取系统CPU使用率 +func getSystemCpuUsage(prevIdle, prevTotal *uint64) (float64, error) { + // 读取/proc/stat文件获取CPU统计信息 + file, err := os.Open("/proc/stat") + if err != nil { + return 0, err + } + defer file.Close() + + var cpuUser, cpuNice, cpuSystem, cpuIdle, cpuIowait, cpuIrq, cpuSoftirq, cpuSteal uint64 + _, err = fmt.Fscanf(file, "cpu %d %d %d %d %d %d %d %d", + &cpuUser, &cpuNice, &cpuSystem, &cpuIdle, &cpuIowait, &cpuIrq, &cpuSoftirq, &cpuSteal) + if err != nil { + return 0, err + } + + // 计算总的CPU时间 + total := cpuUser + cpuNice + cpuSystem + cpuIdle + cpuIowait + cpuIrq + cpuSoftirq + cpuSteal + idle := cpuIdle + cpuIowait + + // 第一次调用时,只初始化值,不计算使用率 + if *prevTotal == 0 || *prevIdle == 0 { + *prevIdle = idle + *prevTotal = total + return 0, nil + } + + // 计算CPU使用率 + idleDelta := idle - *prevIdle + totalDelta := total - *prevTotal + utilization := float64(totalDelta-idleDelta) / float64(totalDelta) * 100 + + // 更新上一次的值 + *prevIdle = idle + *prevTotal = total + + return utilization, nil +} + +// startAutoSave 启动自动保存功能 +func (s *Server) startAutoSave() { + if s.config.SaveInterval <= 0 { + return + } + + // 初始化定时器 +} + +// GetArchiveQueryEngine 获取归档查询引擎 +func (s *Server) GetArchiveQueryEngine() *log.ArchiveQueryEngine { + return s.archiveQueryEngine +} + +// GetArchiveManager 获取归档管理器 +func (s *Server) GetArchiveManager() *log.ArchiveManager { + return s.archiveManager +} diff --git a/dns/server.go.rej b/dns/server.go.rej new file mode 100644 index 0000000..cfafb6c --- /dev/null +++ b/dns/server.go.rej @@ -0,0 +1,38 @@ +--- dns/server.go ++++ dns/server.go +@@ -605,6 +605,8 @@ func (s *Server) checkRequestConditions(w dns.ResponseWriter, r *dns.Msg, startT + if r.RecursionDesired == false { + se := new(dns.Msg) + se.SetReply(r) ++// 设置递归可用标志 ++response.RecursionAvailable = true + 不再硬编码 RecursionAvailable,使用默认值或上游返回的值 + se.SetRcode(r, dns.RcodeRefused) + se) +@@ -1010,6 +1012,8 @@ func (s *Server) handleHostsResponse(w dns.ResponseWriter, r *dns.Msg, ip string + func handleHostsResponse(w dns.ResponseWriter, r *dns.Msg, ip string) { + response := new(dns.Msg) + response.SetReply(r) ++// 设置递归可用标志(因为我们的 DNS 服务器支持递归查询) ++response.RecursionAvailable = true + // 不再硬编码 RecursionAvailable,使用默认值或上游返回的值 + + if len(r.Question) > 0 { +@@ -1051,6 +1055,8 @@ func (s *Server) handleGFWListResponse(w dns.ResponseWriter, r *dns.Msg, domain + + response := new(dns.Msg) + response.SetReply(r) ++// 设置递归可用标志 ++response.RecursionAvailable = true + + if len(r.Question) > 0 { + := r.Question[0] +@@ -1082,6 +1088,8 @@ func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain + + response := new(dns.Msg) + response.SetReply(r) ++// 设置递归可用标志 ++response.RecursionAvailable = true + // 不再硬编码 RecursionAvailable,使用默认值或上游返回的值 + + // 获取屏蔽方法配置