package dns import ( "context" "encoding/json" "fmt" "io/ioutil" "math" "math/rand" "net" "os" "path/filepath" "runtime" "sort" "strings" "sync" "sync/atomic" "time" "dns-server/config" "dns-server/gfw" "dns-server/logger" "dns-server/shield" "github.com/miekg/dns" ) // 确保DNS服务器地址包含端口号,默认添加53端口 func normalizeDNSServerAddress(address string) string { // 检查地址是否已经包含端口号 if _, _, err := net.SplitHostPort(address); err != nil { // 如果没有端口号,添加默认的53端口 return net.JoinHostPort(address, "53") } // 已经有端口号,直接返回 return address } // BlockedDomain 屏蔽域名统计 type BlockedDomain struct { Domain string Count int64 LastSeen int64 DNSSEC bool // 是否使用了DNSSEC } // ClientStats 客户端统计 type ClientStats struct { IP string Count int64 LastSeen int64 } // DNSAnswer DNS解析记录 type DNSAnswer struct { Type string `json:"type"` // 记录类型 Value string `json:"value"` // 记录值 TTL uint32 `json:"ttl"` // 生存时间 } // QueryLog 查询日志记录 type QueryLog struct { Timestamp time.Time `json:"timestamp"` // 查询时间 ClientIP string `json:"clientIP"` // 客户端IP Location string `json:"location"` // IP地理位置(国家 城市) Domain string `json:"domain"` // 查询域名 QueryType string `json:"queryType"` // 查询类型 ResponseTime int64 `json:"responseTime"` // 响应时间(ms) Result string `json:"result"` // 查询结果(allowed, blocked, error) BlockRule string `json:"blockRule"` // 屏蔽规则(如果被屏蔽) BlockType string `json:"blockType"` // 屏蔽类型(如果被屏蔽) FromCache bool `json:"fromCache"` // 是否来自缓存 DNSSEC bool `json:"dnssec"` // 是否使用了DNSSEC EDNS bool `json:"edns"` // 是否使用了EDNS DNSServer string `json:"dnsServer"` // 使用的DNS服务器 DNSSECServer string `json:"dnssecServer"` // 使用的DNSSEC专用服务器 Answers []DNSAnswer `json:"answers"` // 解析记录 ResponseCode int `json:"responseCode"` // DNS响应代码 } // StatsData 用于持久化的统计数据结构 type StatsData struct { Stats *Stats `json:"stats"` BlockedDomains map[string]*BlockedDomain `json:"blockedDomains"` ResolvedDomains map[string]*BlockedDomain `json:"resolvedDomains"` ClientStats map[string]*ClientStats `json:"clientStats"` HourlyStats map[string]int64 `json:"hourlyStats"` DailyStats map[string]int64 `json:"dailyStats"` MonthlyStats map[string]int64 `json:"monthlyStats"` LastSaved time.Time `json:"lastSaved"` } // ServerStats 服务器统计信息 type ServerStats struct { SuccessCount int64 // 成功查询次数 FailureCount int64 // 失败查询次数 LastResponse time.Time // 最后响应时间 ResponseTime time.Duration // 平均响应时间 ConnectionSpeed time.Duration // TCP连接速度 } // Server DNS服务器 type Server struct { config *config.DNSConfig shieldConfig *config.ShieldConfig shieldManager *shield.ShieldManager gfwConfig *config.GFWListConfig gfwManager *gfw.GFWListManager server *dns.Server tcpServer *dns.Server resolver *dns.Client ctx context.Context cancel context.CancelFunc statsMutex sync.Mutex stats *Stats blockedDomainsMutex sync.RWMutex blockedDomains map[string]*BlockedDomain resolvedDomainsMutex sync.RWMutex resolvedDomains map[string]*BlockedDomain // 用于记录解析的域名 clientStatsMutex sync.RWMutex clientStats map[string]*ClientStats // 用于记录客户端统计 hourlyStatsMutex sync.RWMutex hourlyStats map[string]int64 // 按小时统计屏蔽数量 dailyStatsMutex sync.RWMutex dailyStats map[string]int64 // 按天统计屏蔽数量 monthlyStatsMutex sync.RWMutex monthlyStats map[string]int64 // 按月统计屏蔽数量 queryLogsMutex sync.RWMutex queryLogs []QueryLog // 查询日志列表 maxQueryLogs int // 最大保存日志数量 logChannel chan QueryLog // 日志处理通道 saveTicker *time.Ticker // 用于定时保存数据 startTime time.Time // 服务器启动时间 saveDone chan struct{} // 用于通知保存协程停止 stopped bool // 服务器是否已经停止 stoppedMutex sync.Mutex // 保护stopped标志的互斥锁 // DNS查询缓存 DnsCache *DNSCache // DNS响应缓存 // 域名DNSSEC状态映射表 domainDNSSECStatus map[string]bool // 域名到DNSSEC状态的映射 domainDNSSECStatusMutex sync.RWMutex // 保护域名DNSSEC状态映射的互斥锁 // 上游服务器状态跟踪 serverStats map[string]*ServerStats // 服务器地址到状态的映射 serverStatsMutex sync.RWMutex // 保护服务器状态的互斥锁 // DNSSEC专用服务器映射,用于快速查找 dnssecServerMap map[string]bool // DNSSEC专用服务器地址到布尔值的映射 // DNS客户端实例池,用于并行查询 clientPool sync.Pool // 存储*dns.Client实例 } // Stats DNS服务器统计信息 type Stats struct { Queries int64 Blocked int64 Allowed int64 Errors int64 LastQuery time.Time AvgResponseTime float64 // 平均响应时间(ms) TotalResponseTime int64 // 总响应时间 QueryTypes map[string]int64 // 查询类型统计 SourceIPs map[string]bool // 活跃来源IP CpuUsage float64 // CPU使用率(%) DNSSECQueries int64 // DNSSEC查询总数 DNSSECSuccess int64 // DNSSEC验证成功数 DNSSECFailed int64 // DNSSEC验证失败数 DNSSECEnabled bool // 是否启用了DNSSEC } // NewServer 创建DNS服务器实例 func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager, gfwConfig *config.GFWListConfig, gfwManager *gfw.GFWListManager) *Server { ctx, cancel := context.WithCancel(context.Background()) // 从配置中读取DNS缓存TTL值(分钟) cacheTTL := time.Duration(config.CacheTTL) * time.Minute server := &Server{ config: config, shieldConfig: shieldConfig, shieldManager: shieldManager, gfwConfig: gfwConfig, gfwManager: gfwManager, resolver: &dns.Client{ Net: "udp", UDPSize: 4096, // 增加UDP缓冲区大小,支持更大的DNSSEC响应 }, ctx: ctx, cancel: cancel, startTime: time.Now(), // 记录服务器启动时间 stats: &Stats{ Queries: 0, Blocked: 0, Allowed: 0, Errors: 0, AvgResponseTime: 0, TotalResponseTime: 0, QueryTypes: make(map[string]int64), SourceIPs: make(map[string]bool), CpuUsage: 0, DNSSECQueries: 0, DNSSECSuccess: 0, DNSSECFailed: 0, DNSSECEnabled: config.EnableDNSSEC, }, blockedDomains: make(map[string]*BlockedDomain), resolvedDomains: make(map[string]*BlockedDomain), clientStats: make(map[string]*ClientStats), hourlyStats: make(map[string]int64), dailyStats: make(map[string]int64), monthlyStats: make(map[string]int64), queryLogs: make([]QueryLog, 0, 1000), // 初始化查询日志切片,容量1000 maxQueryLogs: 10000, // 最大保存10000条日志 logChannel: make(chan QueryLog, 1000), // 日志处理通道,缓冲区大小1000 saveDone: make(chan struct{}), stopped: false, // 初始化为未停止状态 // DNS查询缓存初始化 DnsCache: NewDNSCache(cacheTTL), // 初始化域名DNSSEC状态映射表 domainDNSSECStatus: make(map[string]bool), // 初始化服务器状态跟踪 serverStats: make(map[string]*ServerStats), // 初始化DNSSEC专用服务器映射 dnssecServerMap: make(map[string]bool), // 初始化DNS客户端实例池 clientPool: sync.Pool{ New: func() interface{} { return &dns.Client{ Net: "udp", UDPSize: 4096, Timeout: 5 * time.Second, // 默认超时时间,会在使用时覆盖 } }, }, } // 加载已保存的统计数据 server.loadStatsData() return server } // Start 启动DNS服务器 func (s *Server) Start() error { // 重新初始化上下文和取消函数 ctx, cancel := context.WithCancel(context.Background()) s.ctx = ctx s.cancel = cancel // 重新初始化saveDone通道 s.saveDone = make(chan struct{}) // 重置stopped标志 s.stoppedMutex.Lock() s.stopped = false s.stoppedMutex.Unlock() // 更新服务器启动时间 s.startTime = time.Now() s.server = &dns.Server{ Addr: fmt.Sprintf("0.0.0.0:%d", s.config.Port), Net: "udp", Handler: dns.HandlerFunc(s.handleDNSRequest), } // 保存TCP服务器实例,以便在Stop方法中关闭 s.tcpServer = &dns.Server{ Addr: fmt.Sprintf("0.0.0.0:%d", s.config.Port), Net: "tcp", Handler: dns.HandlerFunc(s.handleDNSRequest), } // 启动CPU使用率监控 go s.startCpuUsageMonitor() // 启动自动保存功能 go s.startAutoSave() // 更新DNSSEC专用服务器映射 s.updateDNSSECServerMap() // 启动日志处理协程 go s.processLogs() // 启动统计数据定期重置功能(每24小时) go func() { ticker := time.NewTicker(24 * time.Hour) defer ticker.Stop() for { select { case <-ticker.C: s.resetStats() case <-s.ctx.Done(): return } } }() // 启动UDP服务 go func() { logger.Info(fmt.Sprintf("DNS UDP服务器启动,监听端口: %d", s.config.Port)) if err := s.server.ListenAndServe(); err != nil { logger.Error("DNS UDP服务器启动失败", "error", err) s.cancel() } }() // 启动TCP服务 go func() { logger.Info(fmt.Sprintf("DNS TCP服务器启动,监听端口: %d", s.config.Port)) if err := s.tcpServer.ListenAndServe(); err != nil { logger.Error("DNS TCP服务器启动失败", "error", err) s.cancel() } }() // 等待停止信号 <-s.ctx.Done() return nil } // resetStats 重置统计数据 func (s *Server) resetStats() { s.statsMutex.Lock() defer s.statsMutex.Unlock() // 只重置累计值,保留配置相关值 s.stats.TotalResponseTime = 0 s.stats.AvgResponseTime = 0 s.stats.Queries = 0 s.stats.Blocked = 0 s.stats.Allowed = 0 s.stats.Errors = 0 s.stats.DNSSECQueries = 0 s.stats.DNSSECSuccess = 0 s.stats.DNSSECFailed = 0 s.stats.QueryTypes = make(map[string]int64) s.stats.SourceIPs = make(map[string]bool) logger.Info("统计数据已重置") } // Stop 停止DNS服务器 func (s *Server) Stop() { // 检查服务器是否已经停止 s.stoppedMutex.Lock() if s.stopped { s.stoppedMutex.Unlock() return // 服务器已经停止,直接返回 } // 标记服务器为已停止状态 s.stopped = true s.stoppedMutex.Unlock() // 发送停止信号给保存协程 close(s.saveDone) // 最后保存一次数据 s.saveStatsData() // 停止服务器 s.cancel() if s.server != nil { s.server.Shutdown() } if s.tcpServer != nil { s.tcpServer.Shutdown() } logger.Info("DNS服务器已停止") } // handleDNSRequest 处理DNS请求 func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { startTime := time.Now() // 获取来源IP sourceIP := w.RemoteAddr().String() // 提取IP地址部分,去掉端口 if strings.HasPrefix(sourceIP, "[") { // IPv6地址格式: [::1]:53 if idx := strings.Index(sourceIP, "]"); idx >= 0 { sourceIP = sourceIP[1:idx] // 去掉方括号 } } else { // IPv4地址格式: 127.0.0.1:53 if idx := strings.LastIndex(sourceIP, ":"); idx >= 0 { sourceIP = sourceIP[:idx] } } // 更新来源IP统计和Queries计数器 s.updateStats(func(stats *Stats) { stats.Queries++ stats.LastQuery = time.Now() stats.SourceIPs[sourceIP] = true }) // 更新客户端统计 s.updateClientStats(sourceIP) // 获取查询域名和类型 var domain string var queryType string var qType uint16 if len(r.Question) > 0 { domain = r.Question[0].Name // 移除末尾的点 if len(domain) > 0 && domain[len(domain)-1] == '.' { domain = domain[:len(domain)-1] } // 获取查询类型 queryType = dns.TypeToString[r.Question[0].Qtype] qType = r.Question[0].Qtype // 更新查询类型统计 s.updateStats(func(stats *Stats) { stats.QueryTypes[queryType]++ }) // 检查是否是AAAA记录查询且IPv6解析已禁用 if qType == dns.TypeAAAA && !s.config.EnableIPv6 { // 返回NXDOMAIN响应(域名不存在) response := new(dns.Msg) response.SetReply(r) response.SetRcode(r, dns.RcodeNameError) w.WriteMsg(response) // 更新统计信息 responseTime := time.Since(startTime).Milliseconds() s.updateStats(func(stats *Stats) { stats.TotalResponseTime += responseTime // 添加防御性编程,确保Queries大于0 if stats.Queries > 0 { // 平均响应时间 = 总响应时间 / 总解析数量,四舍五入取整 avg := float64(stats.TotalResponseTime) / float64(stats.Queries) stats.AvgResponseTime = float64(math.Round(avg)) // 限制平均响应时间的范围,避免显示异常大的值 if stats.AvgResponseTime > 60000 { stats.AvgResponseTime = 60000 } } }) // 添加查询日志 s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "", false, false, true, "", "", nil, dns.RcodeNameError) logger.Debug("IPv6解析已禁用,拒绝AAAA记录查询", "domain", domain) return } } logger.Debug("接收到DNS查询", "domain", domain, "type", queryType, "client", w.RemoteAddr()) // 只处理递归查询 if r.RecursionDesired == false { response := new(dns.Msg) response.SetReply(r) // 不再硬编码RecursionAvailable,使用默认值或上游返回的值 response.SetRcode(r, dns.RcodeRefused) w.WriteMsg(response) // 计算实际响应时间 responseTime := time.Since(startTime).Milliseconds() s.updateStats(func(stats *Stats) { stats.TotalResponseTime += responseTime if stats.Queries > 0 { // 平均响应时间 = 总响应时间 / 总解析数量,四舍五入取整 avg := float64(stats.TotalResponseTime) / float64(stats.Queries) stats.AvgResponseTime = float64(math.Round(avg)) } }) // 添加查询日志 s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "", false, false, true, "", "", nil, dns.RcodeRefused) return } // 检查hosts文件是否有匹配 if ip, exists := s.shieldManager.GetHostsIP(domain); exists { s.handleHostsResponse(w, r, ip) // 计算实际响应时间 responseTime := time.Since(startTime).Milliseconds() s.updateStats(func(stats *Stats) { stats.TotalResponseTime += responseTime if stats.Queries > 0 { // 平均响应时间 = 总响应时间 / 总解析数量,四舍五入取整 avg := float64(stats.TotalResponseTime) / float64(stats.Queries) stats.AvgResponseTime = float64(math.Round(avg)) } }) // 该方法内部未直接调用addQueryLog,而是在handleDNSRequest中处理 return } // 检查是否为GFWList域名(仅当GFWList功能启用时) if s.gfwConfig.Enabled && s.gfwManager != nil && s.gfwManager.IsMatch(domain) { s.handleGFWListResponse(w, r, domain) // 计算响应时间 responseTime := time.Since(startTime).Milliseconds() s.updateStats(func(stats *Stats) { stats.TotalResponseTime += responseTime if stats.Queries > 0 { // 平均响应时间 = 总响应时间 / 总解析数量,四舍五入取整 avg := float64(stats.TotalResponseTime) / float64(stats.Queries) stats.AvgResponseTime = float64(math.Round(avg)) } }) // 添加查询日志 - GFWList域名 gfwAnswers := []DNSAnswer{} s.addQueryLog(sourceIP, domain, queryType, responseTime, "gfwlist", "", "", false, false, true, "GFWList", "无", gfwAnswers, dns.RcodeSuccess) return } // 检查是否被屏蔽 if s.shieldManager.IsBlocked(domain) { // 获取屏蔽详情 blockDetails := s.shieldManager.CheckDomainBlockDetails(domain) blockRule, _ := blockDetails["blockRule"].(string) blockType, _ := blockDetails["blockRuleType"].(string) s.handleBlockedResponse(w, r, domain) // 计算响应时间 responseTime := time.Since(startTime).Milliseconds() s.updateStats(func(stats *Stats) { stats.TotalResponseTime += responseTime if stats.Queries > 0 { // 平均响应时间 = 总响应时间 / 总解析数量,四舍五入取整 avg := float64(stats.TotalResponseTime) / float64(stats.Queries) stats.AvgResponseTime = float64(math.Round(avg)) } }) // 添加查询日志 - 被屏蔽域名 blockedAnswers := []DNSAnswer{} // 根据屏蔽方法确定响应代码 blockedRcode := dns.RcodeNameError // 默认NXDOMAIN if blockMethod := s.shieldConfig.BlockMethod; blockMethod == "refused" { blockedRcode = dns.RcodeRefused } else if blockMethod == "emptyIP" || blockMethod == "customIP" { blockedRcode = dns.RcodeSuccess } s.addQueryLog(sourceIP, domain, queryType, responseTime, "blocked", blockRule, blockType, false, false, true, "无", "无", blockedAnswers, blockedRcode) return } // 检查缓存中是否有响应(优先查找带DNSSEC的缓存项) var cachedResponse *dns.Msg var found bool var cachedDNSSEC bool // 1. 首先检查是否有普通缓存项 if tempResponse, tempFound := s.DnsCache.Get(r.Question[0].Name, qType); tempFound { cachedResponse = tempResponse found = tempFound cachedDNSSEC = s.hasDNSSECRecords(tempResponse) } // 2. 如果启用了DNSSEC且没有找到带DNSSEC的缓存项, // 尝试从所有缓存中查找是否有其他响应包含DNSSEC记录 // (这里可以进一步优化,比如在缓存中标记DNSSEC状态,快速查找) if s.config.EnableDNSSEC && !cachedDNSSEC { // 目前的缓存实现不支持按DNSSEC状态查找,所以这里暂时跳过 // 后续可以考虑改进缓存实现,添加DNSSEC状态标记 } if found { // 缓存命中,直接返回缓存的响应 cachedResponseCopy := cachedResponse.Copy() // 创建响应副本避免并发修改问题 cachedResponseCopy.Id = r.Id // 更新ID以匹配请求 cachedResponseCopy.Compress = true // 如果客户端请求包含EDNS记录,确保响应也包含EDNS if opt := r.IsEdns0(); opt != nil { // 检查响应是否已经包含EDNS记录 if respOpt := cachedResponseCopy.IsEdns0(); respOpt == nil { // 添加EDNS记录,使用客户端的UDP缓冲区大小 cachedResponseCopy.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC) } else { // 确保响应的UDP缓冲区大小不超过客户端请求的大小 if respOpt.UDPSize() > opt.UDPSize() { // 移除现有的EDNS记录 for i := range cachedResponseCopy.Extra { if cachedResponseCopy.Extra[i] == respOpt { cachedResponseCopy.Extra = append(cachedResponseCopy.Extra[:i], cachedResponseCopy.Extra[i+1:]...) break } } // 添加新的EDNS记录,使用客户端的UDP缓冲区大小 cachedResponseCopy.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC) } } } w.WriteMsg(cachedResponseCopy) // 计算响应时间 responseTime := time.Since(startTime).Milliseconds() s.updateStats(func(stats *Stats) { stats.TotalResponseTime += responseTime if stats.Queries > 0 { // 平均响应时间 = 总响应时间 / 总解析数量,四舍五入取整 avg := float64(stats.TotalResponseTime) / float64(stats.Queries) stats.AvgResponseTime = float64(math.Round(avg)) } }) // 如果缓存响应包含DNSSEC记录,更新DNSSEC查询计数 if cachedDNSSEC { s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ // 缓存响应视为DNSSEC成功 stats.DNSSECSuccess++ }) } // 从缓存响应中提取解析记录 cachedAnswers := []DNSAnswer{} if cachedResponse != nil { for _, rr := range cachedResponse.Answer { cachedAnswers = append(cachedAnswers, DNSAnswer{ Type: dns.TypeToString[rr.Header().Rrtype], Value: rr.String(), TTL: rr.Header().Ttl, }) } } // 添加查询日志 - 标记为缓存 // 从缓存响应中获取响应代码 cacheRcode := dns.RcodeSuccess // 默认成功 if cachedResponse != nil { cacheRcode = cachedResponse.Rcode } s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", true, cachedDNSSEC, true, "缓存", "无", cachedAnswers, cacheRcode) logger.Debug("从缓存返回DNS响应", "domain", domain, "type", queryType, "dnssec", cachedDNSSEC) return } // 缓存未命中,处理DNS请求 var response *dns.Msg var rtt time.Duration var queryAttempts []string var dnsServer string var dnssecServer string // 直接查询原始域名 queryAttempts = append(queryAttempts, domain) response, rtt, dnsServer, dnssecServer = s.forwardDNSRequestWithCache(r, domain) if response != nil { // 如果客户端请求包含EDNS记录,确保响应也包含EDNS if opt := r.IsEdns0(); opt != nil { // 检查响应是否已经包含EDNS记录 if respOpt := response.IsEdns0(); respOpt == nil { // 添加EDNS记录,使用客户端的UDP缓冲区大小 response.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC) } else { // 确保响应的UDP缓冲区大小不超过客户端请求的大小 if respOpt.UDPSize() > opt.UDPSize() { // 移除现有的EDNS记录 for i := range response.Extra { if response.Extra[i] == respOpt { response.Extra = append(response.Extra[:i], response.Extra[i+1:]...) break } } // 添加新的EDNS记录,使用客户端的UDP缓冲区大小 response.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC) } } } // 写入响应给客户端 w.WriteMsg(response) } // 使用上游服务器的实际响应时间(转换为毫秒) responseTime := int64(rtt.Milliseconds()) // 如果rtt为0(查询失败),则使用本地计算的时间 if responseTime == 0 { responseTime = time.Since(startTime).Milliseconds() } // 添加合理性检查,避免异常大的响应时间影响统计 if responseTime > 60000 { // 超过60秒的响应时间视为异常 responseTime = 60000 } s.updateStats(func(stats *Stats) { stats.TotalResponseTime += responseTime // 添加防御性编程,确保Queries大于0 if stats.Queries > 0 { // 平均响应时间 = 总响应时间 / 总解析数量,四舍五入取整 avg := float64(stats.TotalResponseTime) / float64(stats.Queries) stats.AvgResponseTime = float64(math.Round(avg)) // 限制平均响应时间的范围,避免显示异常大的值 if stats.AvgResponseTime > 60000 { stats.AvgResponseTime = 60000 } } }) // 检查响应是否包含DNSSEC记录并验证结果 responseDNSSEC := false if response != nil { // 使用hasDNSSECRecords函数检查是否包含DNSSEC记录 responseDNSSEC = s.hasDNSSECRecords(response) // 检查AD标志,确认DNSSEC验证是否成功 if response.AuthenticatedData { responseDNSSEC = true } // 更新域名的DNSSEC状态 if responseDNSSEC { s.updateDomainDNSSECStatus(domain, true) } } // 如果响应成功,缓存结果(增强版缓存存储) if response != nil && response.Rcode == dns.RcodeSuccess { // 创建响应副本以避免后续修改影响缓存 responseCopy := response.Copy() // 设置合理的TTL,不超过默认的30分钟 defaultCacheTTL := 30 * time.Minute s.DnsCache.Set(r.Question[0].Name, qType, responseCopy, defaultCacheTTL) logger.Debug("DNS响应已缓存", "domain", domain, "type", queryType, "ttl", defaultCacheTTL, "dnssec", responseDNSSEC) } // 从响应中提取解析记录 responseAnswers := []DNSAnswer{} if response != nil { for _, rr := range response.Answer { responseAnswers = append(responseAnswers, DNSAnswer{ Type: dns.TypeToString[rr.Header().Rrtype], Value: rr.String(), TTL: rr.Header().Ttl, }) } } // 添加查询日志 - 标记为实时 // 从响应中获取响应代码 realRcode := dns.RcodeSuccess // 默认成功 if response != nil { realRcode = response.Rcode } s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", false, responseDNSSEC, true, dnsServer, dnssecServer, responseAnswers, realRcode) } // handleHostsResponse 处理hosts文件匹配的响应 func (s *Server) handleHostsResponse(w dns.ResponseWriter, r *dns.Msg, ip string) { response := new(dns.Msg) response.SetReply(r) // 不再硬编码RecursionAvailable,使用默认值或上游返回的值 if len(r.Question) > 0 { q := r.Question[0] answer := new(dns.A) answer.Hdr = dns.RR_Header{ Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300, } answer.A = net.ParseIP(ip) response.Answer = append(response.Answer, answer) } // 记录解析域名统计 domain := "" if len(r.Question) > 0 { domain = r.Question[0].Name if len(domain) > 0 && domain[len(domain)-1] == '.' { domain = domain[:len(domain)-1] } s.updateResolvedDomainStats(domain) } w.WriteMsg(response) s.updateStats(func(stats *Stats) { stats.Allowed++ }) } // handleGFWListResponse 处理GFWList域名响应 func (s *Server) handleGFWListResponse(w dns.ResponseWriter, r *dns.Msg, domain string) { logger.Info("GFWList域名解析", "domain", domain, "client", w.RemoteAddr(), "ip", s.gfwConfig.IP) // 更新解析域名统计 s.updateResolvedDomainStats(domain) response := new(dns.Msg) response.SetReply(r) if len(r.Question) > 0 { q := r.Question[0] answer := new(dns.A) answer.Hdr = dns.RR_Header{ Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300, } answer.A = net.ParseIP(s.gfwConfig.IP) response.Answer = append(response.Answer, answer) } w.WriteMsg(response) s.updateStats(func(stats *Stats) { stats.Allowed++ }) } // handleBlockedResponse 处理被屏蔽的域名响应 func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain string) { logger.Info("域名被屏蔽", "domain", domain, "client", w.RemoteAddr()) // 更新被屏蔽域名统计 s.updateBlockedDomainStats(domain) response := new(dns.Msg) response.SetReply(r) // 不再硬编码RecursionAvailable,使用默认值或上游返回的值 // 获取屏蔽方法配置 blockMethod := "NXDOMAIN" // 默认值 customBlockIP := "" // 默认值 // 从Server结构体的shieldConfig字段获取配置 if s.shieldConfig != nil { blockMethod = s.shieldConfig.BlockMethod customBlockIP = s.shieldConfig.CustomBlockIP } // 根据屏蔽方法返回不同的响应 switch blockMethod { case "refused": // 返回拒绝查询响应 response.SetRcode(r, dns.RcodeRefused) case "emptyIP": // 返回空IP响应 if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA { answer := new(dns.A) answer.Hdr = dns.RR_Header{ Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300, } answer.A = net.ParseIP("0.0.0.0") // 空IP response.Answer = append(response.Answer, answer) } case "customIP": // 返回自定义IP响应 if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA { answer := new(dns.A) answer.Hdr = dns.RR_Header{ Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300, } // 使用自定义屏蔽IP if customBlockIP != "" { answer.A = net.ParseIP(customBlockIP) } else { // 如果没有配置,使用0.0.0.0 answer.A = net.ParseIP("0.0.0.0") } response.Answer = append(response.Answer, answer) } case "NXDOMAIN", "": fallthrough // 默认使用NXDOMAIN default: // 返回NXDOMAIN响应(域名不存在) response.SetRcode(r, dns.RcodeNameError) } w.WriteMsg(response) s.updateStats(func(stats *Stats) { stats.Blocked++ }) } // forwardDNSRequest 转发DNS请求到上游服务器 // serverResponse 用于存储服务器响应的结构体 type serverResponse struct { response *dns.Msg rtt time.Duration server string error error } // recordKey 用于唯一标识DNS记录的结构体 type recordKey struct { name string rtype uint16 class uint16 data string } // getRecordKey 获取DNS记录的唯一标识 func getRecordKey(rr dns.RR) recordKey { // 对于同一域名的同一类型记录,只保留一个,选择最长TTL // 所以对于A、AAAA、CNAME等记录,只使用name、rtype、class作为键 // 对于MX记录,还需要考虑Preference字段 // 对于TXT记录,需要考虑实际文本内容 // 对于NS记录,需要考虑目标服务器 switch rr.Header().Rrtype { case dns.TypeA, dns.TypeAAAA, dns.TypeCNAME, dns.TypePTR: // 对于A、AAAA、CNAME、PTR记录,同一域名只保留一个 return recordKey{ name: rr.Header().Name, rtype: rr.Header().Rrtype, class: rr.Header().Class, data: "", } case dns.TypeMX: // 对于MX记录,同一域名的同一Preference只保留一个 if mx, ok := rr.(*dns.MX); ok { return recordKey{ name: rr.Header().Name, rtype: rr.Header().Rrtype, class: rr.Header().Class, data: fmt.Sprintf("%d", mx.Preference), } } case dns.TypeTXT: // 对于TXT记录,需要考虑实际文本内容 if txt, ok := rr.(*dns.TXT); ok { return recordKey{ name: rr.Header().Name, rtype: rr.Header().Rrtype, class: rr.Header().Class, data: strings.Join(txt.Txt, " "), } } case dns.TypeNS: // 对于NS记录,需要考虑目标服务器 if ns, ok := rr.(*dns.NS); ok { return recordKey{ name: rr.Header().Name, rtype: rr.Header().Rrtype, class: rr.Header().Class, data: ns.Ns, } } case dns.TypeSOA: // 对于SOA记录,同一域名只保留一个 return recordKey{ name: rr.Header().Name, rtype: rr.Header().Rrtype, class: rr.Header().Class, data: "", } } // 对于其他类型,使用原始rr.String(),但移除TTL部分 parts := strings.Split(rr.String(), " ") if len(parts) >= 5 { // 跳过TTL字段(第3个字段) data := strings.Join(append(parts[:2], parts[3:]...), " ") return recordKey{ name: rr.Header().Name, rtype: rr.Header().Rrtype, class: rr.Header().Class, data: data, } } return recordKey{ name: rr.Header().Name, rtype: rr.Header().Rrtype, class: rr.Header().Class, data: rr.String(), } } // mergeResponses 合并多个DNS响应 func mergeResponses(responses []*dns.Msg) *dns.Msg { if len(responses) == 0 { return nil } // 如果只有一个响应,直接返回,避免不必要的合并操作 if len(responses) == 1 { return responses[0].Copy() } // 使用第一个响应作为基础 mergedResponse := responses[0].Copy() mergedResponse.Answer = []dns.RR{} mergedResponse.Ns = []dns.RR{} mergedResponse.Extra = []dns.RR{} // 重置Rcode为成功,除非所有响应都是NXDOMAIN mergedResponse.Rcode = dns.RcodeSuccess // 检查是否所有响应都是NXDOMAIN allNXDOMAIN := true // 收集所有成功响应的记录 for _, resp := range responses { if resp == nil { continue } // 如果有任何响应是成功的,就不是allNXDOMAIN if resp.Rcode == dns.RcodeSuccess { allNXDOMAIN = false } } // 如果所有响应都是NXDOMAIN,设置合并响应为NXDOMAIN if allNXDOMAIN { mergedResponse.Rcode = dns.RcodeNameError } // 使用map存储唯一记录,选择最长TTL // 预分配map容量,减少扩容开销 answerMap := make(map[recordKey]dns.RR, len(responses[0].Answer)*len(responses)) nsMap := make(map[recordKey]dns.RR, len(responses[0].Ns)*len(responses)) extraMap := make(map[recordKey]dns.RR, len(responses[0].Extra)*len(responses)) for _, resp := range responses { if resp == nil { continue } // 只合并与最终Rcode匹配的响应记录 if (mergedResponse.Rcode == dns.RcodeSuccess && resp.Rcode == dns.RcodeSuccess) || (mergedResponse.Rcode == dns.RcodeNameError && resp.Rcode == dns.RcodeNameError) { // 合并Answer部分 for _, rr := range resp.Answer { key := getRecordKey(rr) if existing, exists := answerMap[key]; exists { // 如果存在相同记录,选择TTL更长的 if rr.Header().Ttl > existing.Header().Ttl { answerMap[key] = rr } } else { answerMap[key] = rr } } // 合并Ns部分 for _, rr := range resp.Ns { key := getRecordKey(rr) if existing, exists := nsMap[key]; exists { // 如果存在相同记录,选择TTL更长的 if rr.Header().Ttl > existing.Header().Ttl { nsMap[key] = rr } } else { nsMap[key] = rr } } // 合并Extra部分 for _, rr := range resp.Extra { // 跳过OPT记录,避免重复 if rr.Header().Rrtype == dns.TypeOPT { continue } key := getRecordKey(rr) if existing, exists := extraMap[key]; exists { // 如果存在相同记录,选择TTL更长的 if rr.Header().Ttl > existing.Header().Ttl { extraMap[key] = rr } } else { extraMap[key] = rr } } } } // 预分配切片容量,减少扩容开销 mergedResponse.Answer = make([]dns.RR, 0, len(answerMap)) mergedResponse.Ns = make([]dns.RR, 0, len(nsMap)) mergedResponse.Extra = make([]dns.RR, 0, len(extraMap)) // 将map转换回切片 for _, rr := range answerMap { mergedResponse.Answer = append(mergedResponse.Answer, rr) } for _, rr := range nsMap { mergedResponse.Ns = append(mergedResponse.Ns, rr) } for _, rr := range extraMap { mergedResponse.Extra = append(mergedResponse.Extra, rr) } return mergedResponse } // updateDNSSECServerMap 更新DNSSEC专用服务器映射,用于快速查找 func (s *Server) updateDNSSECServerMap() { // 清空现有映射 for k := range s.dnssecServerMap { delete(s.dnssecServerMap, k) } // 添加所有DNSSEC专用服务器到映射 for _, server := range s.config.DNSSECUpstreamDNS { s.dnssecServerMap[server] = true } } // forwardDNSRequestWithCache 转发DNS请求到上游服务器并返回响应 func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg, time.Duration, string, string) { // 始终支持EDNS var udpSize uint16 = 4096 var doFlag bool = s.config.EnableDNSSEC // 检查域名是否匹配不验证DNSSEC的模式 noDNSSEC := false for _, pattern := range s.config.NoDNSSECDomains { if strings.Contains(domain, pattern) { noDNSSEC = true doFlag = false logger.Debug("域名匹配到不验证DNSSEC的模式", "domain", domain, "pattern", pattern) break } } // 检查客户端请求是否包含EDNS记录 if opt := r.IsEdns0(); opt != nil { // 保留客户端的UDP缓冲区大小 udpSize = opt.UDPSize() // 移除现有的EDNS记录,以便重新添加 for i := range r.Extra { if r.Extra[i] == opt { r.Extra = append(r.Extra[:i], r.Extra[i+1:]...) break } } } // 添加EDNS记录,设置适当的UDPSize和DO标志 r.SetEdns0(udpSize, doFlag) // DNSSEC专用服务器列表,从配置中获取 dnssecServers := s.config.DNSSECUpstreamDNS // 选择合适的上游DNS服务器列表 // 1. 首先检查是否有域名特定的DNS服务器配置 var selectedUpstreamDNS []string var domainMatched bool for matchStr, dnsServers := range s.config.DomainSpecificDNS { if strings.Contains(domain, matchStr) { selectedUpstreamDNS = dnsServers domainMatched = true logger.Debug("域名匹配到特定DNS服务器配置", "domain", domain, "matchStr", matchStr, "dnsServers", dnsServers) break } } // 2. 如果没有匹配的域名特定配置 if !domainMatched { // 创建一个新的切片来存储最终的上游服务器列表 var finalUpstreamDNS []string // 首先添加用户配置的上游DNS服务器 finalUpstreamDNS = append(finalUpstreamDNS, s.config.UpstreamDNS...) logger.Debug("使用用户配置的上游DNS服务器", "servers", finalUpstreamDNS) // 如果启用了DNSSEC且有配置DNSSEC专用服务器,并且域名不匹配NoDNSSECDomains,则将DNSSEC专用服务器添加到列表中 if s.config.EnableDNSSEC && len(s.config.DNSSECUpstreamDNS) > 0 && !noDNSSEC { // 合并DNSSEC专用服务器到上游服务器列表,避免重复,并确保包含端口号 for _, dnssecServer := range s.config.DNSSECUpstreamDNS { hasDuplicate := false // 确保DNSSEC服务器地址包含端口号 normalizedDnssecServer := normalizeDNSServerAddress(dnssecServer) for _, upstream := range finalUpstreamDNS { if upstream == normalizedDnssecServer { hasDuplicate = true break } } if !hasDuplicate { finalUpstreamDNS = append(finalUpstreamDNS, normalizedDnssecServer) } } logger.Debug("合并DNSSEC专用服务器到上游服务器列表", "servers", finalUpstreamDNS) } // 使用最终合并后的服务器列表 selectedUpstreamDNS = finalUpstreamDNS } // 1. 首先尝试所有配置的上游DNS服务器 var bestResponse *dns.Msg var bestRtt time.Duration var hasBestResponse bool var hasDNSSECResponse bool var backupResponse *dns.Msg var backupRtt time.Duration var hasBackup bool var usedDNSServer string var usedDNSSECServer string // 使用配置中的超时时间 defaultTimeout := time.Duration(s.config.QueryTimeout) * time.Millisecond // 根据查询模式处理请求 switch s.config.QueryMode { case "parallel": // 并行请求模式 - 收集所有响应并合并 responses := make(chan serverResponse, len(selectedUpstreamDNS)) var wg sync.WaitGroup // 向所有上游服务器并行发送请求,每个请求带有超时 for _, upstream := range selectedUpstreamDNS { wg.Add(1) go func(server string) { defer wg.Done() // 从池中获取客户端实例 client := s.clientPool.Get().(*dns.Client) // 设置客户端参数 client.Net = s.resolver.Net client.UDPSize = s.resolver.UDPSize client.Timeout = defaultTimeout // 发送请求并获取响应,确保服务器地址包含端口号 response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(server)) responses <- serverResponse{response, rtt, server, err} // 将客户端实例放回池中 s.clientPool.Put(client) }(upstream) } // 等待所有请求完成或超时 go func() { wg.Wait() close(responses) }() // 收集成功响应和NXDOMAIN响应分开 var successResponses []*dns.Msg var nxdomainResponses []*dns.Msg var totalRtt time.Duration var responseCount int // 处理所有响应 for resp := range responses { if resp.error == nil && resp.response != nil { // 更新服务器统计信息 s.updateServerStats(resp.server, true, resp.rtt) // 检查是否包含DNSSEC记录 containsDNSSEC := s.hasDNSSECRecords(resp.response) // 对于不验证DNSSEC的域名,始终设置AD标志为false if noDNSSEC { resp.response.AuthenticatedData = false } // 只对将要返回的响应进行DNSSEC验证,减少开销 // 这里只设置containsDNSSEC标志,实际验证在确定返回响应后进行 if containsDNSSEC && s.config.EnableDNSSEC && !noDNSSEC { // 暂时不验证,只标记 } // 检查当前服务器是否是DNSSEC专用服务器(O(1)查找) if _, isDNSSECServer := s.dnssecServerMap[resp.server]; isDNSSECServer { usedDNSSECServer = resp.server } // 收集响应,按Rcode分类 if resp.response.Rcode == dns.RcodeSuccess { successResponses = append(successResponses, resp.response) totalRtt += resp.rtt responseCount++ // 记录使用的服务器 if usedDNSServer == "" { usedDNSServer = resp.server } } else if resp.response.Rcode == dns.RcodeNameError { nxdomainResponses = append(nxdomainResponses, resp.response) } else { // 更新备选响应,确保总有一个可用的响应 if resp.response != nil { if !hasBackup { // 第一次保存备选响应 backupResponse = resp.response backupRtt = resp.rtt hasBackup = true } } } } else { // 更新服务器统计信息(失败) s.updateServerStats(resp.server, false, 0) } } // 合并响应:优先使用成功响应,只有当没有成功响应时才使用NXDOMAIN响应 var validResponses []*dns.Msg if len(successResponses) > 0 { validResponses = successResponses } else { validResponses = nxdomainResponses } // 合并所有有效响应 if len(validResponses) > 0 { bestResponse = mergeResponses(validResponses) if responseCount > 0 { bestRtt = totalRtt / time.Duration(responseCount) } hasBestResponse = true // 设置日志的type字段 logType := "success" if len(successResponses) == 0 { logType = "nxdomain" } logger.Debug("合并所有响应返回", "domain", domain, "responseCount", len(validResponses), "type", logType) } case "fastest-ip": // 最快的IP地址模式 - 使用TCP连接速度测量选择最快服务器 // 1. 选择最快的服务器 fastestServer := s.selectFastestServer(selectedUpstreamDNS) if fastestServer != "" { // 使用带超时的方式执行Exchange resultChan := make(chan struct { response *dns.Msg rtt time.Duration err error }, 1) go func() { resp, r, e := s.resolver.Exchange(r, normalizeDNSServerAddress(fastestServer)) resultChan <- struct { response *dns.Msg rtt time.Duration err error }{resp, r, e} }() var response *dns.Msg var rtt time.Duration var err error // 直接获取结果,不使用上下文超时 result := <-resultChan response, rtt, err = result.response, result.rtt, result.err if err == nil && response != nil { // 更新服务器统计信息 s.updateServerStats(fastestServer, true, rtt) // 检查是否包含DNSSEC记录 containsDNSSEC := s.hasDNSSECRecords(response) // 如果启用了DNSSEC且响应包含DNSSEC记录,验证DNSSEC签名 // 但如果域名匹配不验证DNSSEC的模式,则跳过验证 if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC { // 验证DNSSEC记录 signatureValid := s.verifyDNSSEC(response) // 设置AD标志(Authenticated Data) response.AuthenticatedData = signatureValid if signatureValid { // 更新DNSSEC验证成功计数 s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECSuccess++ }) } else { // 更新DNSSEC验证失败计数 s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECFailed++ }) } } else if noDNSSEC { // 对于不验证DNSSEC的域名,始终设置AD标志为false response.AuthenticatedData = false } // 如果响应成功或为NXDOMAIN,根据DNSSEC状态选择最佳响应 if response.Rcode == dns.RcodeSuccess || response.Rcode == dns.RcodeNameError { if response.Rcode == dns.RcodeSuccess { // 优先选择带有DNSSEC记录的响应 if containsDNSSEC { bestResponse = response bestRtt = rtt hasBestResponse = true hasDNSSECResponse = true usedDNSServer = fastestServer if _, isDNSSECServer := s.dnssecServerMap[normalizeDNSServerAddress(fastestServer)]; isDNSSECServer { usedDNSSECServer = fastestServer } logger.Debug("找到带DNSSEC的最佳响应", "domain", domain, "server", fastestServer, "rtt", rtt) } else { // 没有带DNSSEC的响应时,保存成功响应 bestResponse = response bestRtt = rtt hasBestResponse = true usedDNSServer = fastestServer if _, isDNSSECServer := s.dnssecServerMap[normalizeDNSServerAddress(fastestServer)]; isDNSSECServer { usedDNSSECServer = fastestServer } logger.Debug("找到最佳响应", "domain", domain, "server", fastestServer, "rtt", rtt) } } else if response.Rcode == dns.RcodeNameError { // 处理NXDOMAIN响应 bestResponse = response bestRtt = rtt hasBestResponse = true usedDNSServer = fastestServer logger.Debug("找到NXDOMAIN响应", "domain", domain, "server", fastestServer, "rtt", rtt) } // 保存为备选响应 if !hasBackup { backupResponse = response backupRtt = rtt hasBackup = true } } } else { // 更新服务器统计信息(失败) s.updateServerStats(fastestServer, false, 0) } } default: // 默认使用并行请求模式 - 实现快速返回和超时机制 responses := make(chan serverResponse, len(selectedUpstreamDNS)) resultChan := make(chan struct { response *dns.Msg rtt time.Duration usedServer string usedDnssecServer string }, 1) var wg sync.WaitGroup // 向所有上游服务器并行发送请求 for _, upstream := range selectedUpstreamDNS { wg.Add(1) go func(server string) { defer wg.Done() // 创建带有超时的resolver client := &dns.Client{ Net: s.resolver.Net, UDPSize: s.resolver.UDPSize, Timeout: defaultTimeout, } // 发送请求并获取响应,确保服务器地址包含端口号 response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(server)) responses <- serverResponse{response, rtt, server, err} }(upstream) } // 处理响应的协程 go func() { var fastestResponse *dns.Msg var fastestRtt time.Duration = defaultTimeout var fastestServer string var fastestDnssecServer string var fastestHasDnssec bool var successResponses []*dns.Msg var nxdomainResponses []*dns.Msg // 等待所有请求完成或超时 timer := time.NewTimer(defaultTimeout) defer timer.Stop() // 处理所有响应 for { select { case resp, ok := <-responses: if !ok { // 所有响应都已处理 goto doneProcessing } if resp.error == nil && resp.response != nil { // 更新服务器统计信息 s.updateServerStats(resp.server, true, resp.rtt) // 检查是否包含DNSSEC记录 containsDNSSEC := s.hasDNSSECRecords(resp.response) // 对于不验证DNSSEC的域名,始终设置AD标志为false if noDNSSEC { resp.response.AuthenticatedData = false } dnssecServerForResponse := "" if _, isDNSSECServer := s.dnssecServerMap[normalizeDNSServerAddress(resp.server)]; isDNSSECServer { dnssecServerForResponse = resp.server } // 如果响应成功或为NXDOMAIN if resp.response.Rcode == dns.RcodeSuccess || resp.response.Rcode == dns.RcodeNameError { // 按Rcode分类添加到不同列表 if resp.response.Rcode == dns.RcodeSuccess { successResponses = append(successResponses, resp.response) } else { nxdomainResponses = append(nxdomainResponses, resp.response) } // 快速返回逻辑:找到第一个有效响应或更快的响应 if resp.response.Rcode == dns.RcodeSuccess { // 优先选择带有DNSSEC的响应 if containsDNSSEC { // 如果这是第一个DNSSEC响应,或者比当前最快的DNSSEC响应更快 if !fastestHasDnssec || resp.rtt < fastestRtt { fastestResponse = resp.response fastestRtt = resp.rtt fastestServer = resp.server fastestDnssecServer = dnssecServerForResponse fastestHasDnssec = true // 只对将要返回的响应进行DNSSEC验证 if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC { // 验证DNSSEC记录 signatureValid := s.verifyDNSSEC(fastestResponse) // 设置AD标志(Authenticated Data) fastestResponse.AuthenticatedData = signatureValid if signatureValid { // 更新DNSSEC验证成功计数 s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECSuccess++ }) } else { // 更新DNSSEC验证失败计数 s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECFailed++ }) } } // 发送结果,快速返回 resultChan <- struct { response *dns.Msg rtt time.Duration usedServer string usedDnssecServer string }{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer} } } else { // 非DNSSEC响应,只有在还没有找到DNSSEC响应且当前响应更快时才更新 if !fastestHasDnssec && resp.rtt < fastestRtt { fastestResponse = resp.response fastestRtt = resp.rtt fastestServer = resp.server fastestDnssecServer = dnssecServerForResponse // 检查是否包含DNSSEC记录 respContainsDNSSEC := s.hasDNSSECRecords(fastestResponse) // 只对将要返回的响应进行DNSSEC验证 if s.config.EnableDNSSEC && respContainsDNSSEC && !noDNSSEC { // 验证DNSSEC记录 signatureValid := s.verifyDNSSEC(fastestResponse) // 设置AD标志(Authenticated Data) fastestResponse.AuthenticatedData = signatureValid if signatureValid { // 更新DNSSEC验证成功计数 s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECSuccess++ }) } else { // 更新DNSSEC验证失败计数 s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECFailed++ }) } } // 发送结果,快速返回 resultChan <- struct { response *dns.Msg rtt time.Duration usedServer string usedDnssecServer string }{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer} } } } else if resp.response.Rcode == dns.RcodeNameError { // NXDOMAIN响应,只有在还没有找到响应或当前响应更快时才更新 if !fastestHasDnssec && resp.rtt < fastestRtt { fastestResponse = resp.response fastestRtt = resp.rtt fastestServer = resp.server fastestDnssecServer = dnssecServerForResponse // 检查是否包含DNSSEC记录 respContainsDNSSEC := s.hasDNSSECRecords(fastestResponse) // 只对将要返回的响应进行DNSSEC验证 if s.config.EnableDNSSEC && respContainsDNSSEC && !noDNSSEC { // 验证DNSSEC记录 signatureValid := s.verifyDNSSEC(fastestResponse) // 设置AD标志(Authenticated Data) fastestResponse.AuthenticatedData = signatureValid if signatureValid { // 更新DNSSEC验证成功计数 s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECSuccess++ }) } else { // 更新DNSSEC验证失败计数 s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECFailed++ }) } } // 发送结果,快速返回 resultChan <- struct { response *dns.Msg rtt time.Duration usedServer string usedDnssecServer string }{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer} } } } else { // 更新备选响应,确保总有一个可用的响应 if resp.response != nil { if !hasBackup { // 第一次保存备选响应 backupResponse = resp.response backupRtt = resp.rtt hasBackup = true } } } } else { // 更新服务器统计信息(失败) s.updateServerStats(resp.server, false, 0) } case <-timer.C: // 超时,停止等待更多响应 goto doneProcessing } } doneProcessing: // 合并响应,优先使用成功响应,只有当没有成功响应时才使用NXDOMAIN响应 var validResponses []*dns.Msg if len(successResponses) > 0 { validResponses = successResponses } else { validResponses = nxdomainResponses } // 合并所有有效响应,用于缓存 if len(validResponses) > 1 { mergedResponse := mergeResponses(validResponses) if mergedResponse != nil { // 只在合并后的响应比最快响应更好时才使用 mergedHasDnssec := s.hasDNSSECRecords(mergedResponse) if mergedHasDnssec && !fastestHasDnssec { // 合并后的响应有DNSSEC,而最快响应没有,使用合并后的响应 fastestResponse = mergedResponse // 使用最快的Rtt作为合并响应的Rtt fastestHasDnssec = true } } } // 如果还没有发送结果,发送最快的响应 if fastestResponse != nil { resultChan <- struct { response *dns.Msg rtt time.Duration usedServer string usedDnssecServer string }{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer} } close(resultChan) }() // 等待所有请求完成(不阻塞主流程) go func() { wg.Wait() close(responses) }() // 等待结果或超时 select { case result := <-resultChan: // 快速返回结果 bestResponse = result.response bestRtt = result.rtt usedDNSServer = result.usedServer usedDNSSECServer = result.usedDnssecServer hasBestResponse = true hasDNSSECResponse = s.hasDNSSECRecords(result.response) logger.Debug("快速返回DNS响应", "domain", domain, "server", result.usedServer, "rtt", result.rtt, "dnssec", hasDNSSECResponse) case <-time.After(defaultTimeout): // 超时,使用备选响应 logger.Debug("并行请求超时", "domain", domain, "timeout", defaultTimeout) } } // 2. 当启用DNSSEC且没有找到带DNSSEC的响应时,向DNSSEC专用服务器发送请求 // 但如果域名匹配了domainSpecificDNS配置或NoDNSSECDomains,则不使用DNSSEC专用服务器,只使用指定的DNS服务器 if s.config.EnableDNSSEC && !hasDNSSECResponse && !domainMatched && !noDNSSEC { logger.Debug("向DNSSEC专用服务器发送请求", "domain", domain) // 增加DNSSEC查询计数 s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ }) // 无论查询模式是什么,DNSSEC验证都只使用加权随机选择一个服务器 selectedDnssecServer := s.selectWeightedRandomServer(dnssecServers) if selectedDnssecServer != "" { // 使用带超时的方式执行Exchange resultChan := make(chan struct { response *dns.Msg rtt time.Duration err error }, 1) go func() { // 创建带有超时的resolver client := &dns.Client{ Net: s.resolver.Net, UDPSize: s.resolver.UDPSize, Timeout: defaultTimeout, } response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(selectedDnssecServer)) resultChan <- struct { response *dns.Msg rtt time.Duration err error }{response, rtt, err} }() var response *dns.Msg var rtt time.Duration var err error // 使用超时获取结果 select { case result := <-resultChan: response, rtt, err = result.response, result.rtt, result.err case <-time.After(defaultTimeout): // 超时,不再等待 logger.Debug("DNSSEC专用服务器请求超时", "domain", domain, "server", selectedDnssecServer, "timeout", defaultTimeout) return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer } if err == nil && response != nil { // 更新服务器统计信息 s.updateServerStats(selectedDnssecServer, true, rtt) // 检查是否包含DNSSEC记录 containsDNSSEC := s.hasDNSSECRecords(response) if response.Rcode == dns.RcodeSuccess { // 无论响应是否包含DNSSEC记录,只要使用了DNSSEC专用服务器,就设置usedDNSSECServer usedDNSSECServer = selectedDnssecServer // 验证DNSSEC记录 signatureValid := s.verifyDNSSEC(response) // 设置AD标志(Authenticated Data) response.AuthenticatedData = signatureValid if signatureValid { // 更新DNSSEC验证成功计数 s.updateStats(func(stats *Stats) { stats.DNSSECSuccess++ }) } else { // 更新DNSSEC验证失败计数 s.updateStats(func(stats *Stats) { stats.DNSSECFailed++ }) } // 优先使用DNSSEC专用服务器的响应,尤其是带有DNSSEC记录的 if containsDNSSEC { bestResponse = response bestRtt = rtt hasBestResponse = true hasDNSSECResponse = true logger.Debug("DNSSEC专用服务器返回带DNSSEC的响应,优先使用", "domain", domain, "server", selectedDnssecServer, "rtt", rtt) } // 注意:如果DNSSEC专用服务器返回的响应不包含DNSSEC记录, // 我们不会覆盖之前从upstreamDNS获取的响应, // 这符合"本地解析指的是直接使用上游服务器upstreamDNS进行解析, 而不是dnssecUpstreamDNS"的要求 // 更新备选响应 if !hasBackup { backupResponse = response backupRtt = rtt hasBackup = true } } } else { // 更新服务器统计信息(失败) s.updateServerStats(selectedDnssecServer, false, 0) } } } // 3. 返回最佳响应 if hasBestResponse { // 检查最佳响应是否包含DNSSEC记录 bestHasDNSSEC := s.hasDNSSECRecords(bestResponse) // 如果启用了DNSSEC且最佳响应不包含DNSSEC记录,尝试使用本地解析(使用upstreamDNS服务器) // 但如果域名匹配了domainSpecificDNS配置,则不执行此逻辑,只使用指定的DNS服务器 if s.config.EnableDNSSEC && !bestHasDNSSEC && !domainMatched { logger.Debug("最佳响应不包含DNSSEC记录,尝试使用本地解析(upstreamDNS)", "domain", domain) // 选择一个upstreamDNS服务器进行解析(使用加权随机算法) localServer := s.selectWeightedRandomServer(s.config.UpstreamDNS) if localServer != "" { // 使用带超时的方式执行Exchange resultChan := make(chan struct { response *dns.Msg rtt time.Duration err error }, 1) go func() { resp, r, e := s.resolver.Exchange(r, normalizeDNSServerAddress(localServer)) resultChan <- struct { response *dns.Msg rtt time.Duration err error }{resp, r, e} }() var localResponse *dns.Msg var rtt time.Duration var err error // 直接获取结果,不使用上下文超时 result := <-resultChan localResponse, rtt, err = result.response, result.rtt, result.err if err == nil && localResponse != nil { // 更新服务器统计信息 s.updateServerStats(localServer, true, rtt) // 检查是否包含DNSSEC记录 localHasDNSSEC := s.hasDNSSECRecords(localResponse) // 验证DNSSEC记录(如果存在),但不影响最终响应 if localHasDNSSEC { signatureValid := s.verifyDNSSEC(localResponse) localResponse.AuthenticatedData = signatureValid if signatureValid { s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECSuccess++ }) } else { s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECFailed++ }) } } // 记录解析域名统计 s.updateResolvedDomainStats(domain) // 更新域名的DNSSEC状态 s.updateDomainDNSSECStatus(domain, localHasDNSSEC) s.updateStats(func(stats *Stats) { stats.Allowed++ }) logger.Debug("使用本地解析结果(upstreamDNS)", "domain", domain, "server", localServer, "rtt", rtt) return localResponse, rtt, localServer, "" } else { // 更新服务器统计信息(失败) s.updateServerStats(localServer, false, 0) } } } // 记录解析域名统计 s.updateResolvedDomainStats(domain) // 更新域名的DNSSEC状态 if bestHasDNSSEC { s.updateDomainDNSSECStatus(domain, true) } else { s.updateDomainDNSSECStatus(domain, false) } s.updateStats(func(stats *Stats) { stats.Allowed++ }) return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer } // 如果有备选响应,返回该响应 if hasBackup { logger.Debug("使用备选响应,没有找到更好的结果", "domain", domain) // 记录解析域名统计 s.updateResolvedDomainStats(domain) // 更新统计信息 s.updateStats(func(stats *Stats) { stats.Allowed++ }) return backupResponse, backupRtt, "", "" } // 所有上游服务器都失败,返回服务器失败错误 response := new(dns.Msg) response.SetReply(r) response.SetRcode(r, dns.RcodeServerFailure) logger.Error("DNS查询失败", "domain", domain) s.updateStats(func(stats *Stats) { stats.Errors++ }) return response, 0, "", "" } // forwardDNSRequest 转发DNS请求到上游服务器 func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) { response, _, _, _ := s.forwardDNSRequestWithCache(r, domain) w.WriteMsg(response) } // updateBlockedDomainStats 更新被屏蔽域名统计 func (s *Server) updateBlockedDomainStats(domain string) { // 先尝试读锁,检查条目是否存在 s.blockedDomainsMutex.RLock() entry, exists := s.blockedDomains[domain] s.blockedDomainsMutex.RUnlock() if exists { // 使用原子操作更新计数和时间戳 atomic.AddInt64(&entry.Count, 1) atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano()) } else { // 获取写锁,创建新条目 s.blockedDomainsMutex.Lock() // 再次检查,避免竞态条件 if entry, exists := s.blockedDomains[domain]; exists { atomic.AddInt64(&entry.Count, 1) atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano()) } else { s.blockedDomains[domain] = &BlockedDomain{ Domain: domain, Count: 1, LastSeen: time.Now().UnixNano(), DNSSEC: false, } } s.blockedDomainsMutex.Unlock() } // 更新统计数据 now := time.Now() // 更新小时统计 hourKey := now.Format("2006-01-02-15") s.hourlyStatsMutex.Lock() s.hourlyStats[hourKey]++ s.hourlyStatsMutex.Unlock() // 更新每日统计 dayKey := now.Format("2006-01-02") s.dailyStatsMutex.Lock() s.dailyStats[dayKey]++ s.dailyStatsMutex.Unlock() // 更新每月统计 monthKey := now.Format("2006-01") s.monthlyStatsMutex.Lock() s.monthlyStats[monthKey]++ s.monthlyStatsMutex.Unlock() } // updateClientStats 更新客户端统计 func (s *Server) updateClientStats(ip string) { // 先尝试读锁,检查条目是否存在 s.clientStatsMutex.RLock() entry, exists := s.clientStats[ip] s.clientStatsMutex.RUnlock() if exists { // 使用原子操作更新计数和时间戳 atomic.AddInt64(&entry.Count, 1) atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano()) } else { // 获取写锁,创建新条目 s.clientStatsMutex.Lock() // 再次检查,避免竞态条件 if entry, exists := s.clientStats[ip]; exists { atomic.AddInt64(&entry.Count, 1) atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano()) } else { s.clientStats[ip] = &ClientStats{ IP: ip, Count: 1, LastSeen: time.Now().UnixNano(), } } s.clientStatsMutex.Unlock() } } // hasDNSSECRecords 检查响应是否包含DNSSEC记录 func (s *Server) hasDNSSECRecords(response *dns.Msg) bool { // 直接调用包内的hasDNSSECRecords函数,避免重复代码 return hasDNSSECRecords(response) } // verifyDNSSEC 验证DNSSEC签名 func (s *Server) verifyDNSSEC(response *dns.Msg) bool { // 提取DNSKEY和RRSIG记录,并按类型和名称组织记录 dnskeys := make(map[uint16]*dns.DNSKEY) // KeyTag -> DNSKEY rrsigs := make([]*dns.RRSIG, 0) // 按 (名称, 类型) 组织记录集,用于快速查找 rrSets := make(map[string]map[uint16][]dns.RR) // name -> type -> records // 定义处理单个记录的函数 processRecord := func(rr dns.RR) { num := rr.Header().Rrtype name := rr.Header().Name // 组织记录集 if _, exists := rrSets[name]; !exists { rrSets[name] = make(map[uint16][]dns.RR) } if _, exists := rrSets[name][num]; !exists { rrSets[name][num] = make([]dns.RR, 0) } rrSets[name][num] = append(rrSets[name][num], rr) // 特别处理DNSKEY和RRSIG if dnskey, ok := rr.(*dns.DNSKEY); ok { tag := dnskey.KeyTag() dnskeys[tag] = dnskey } else if rrsig, ok := rr.(*dns.RRSIG); ok { rrsigs = append(rrsigs, rrsig) } } // 一次遍历所有响应部分,同时完成记录收集和组织 for _, rr := range response.Answer { processRecord(rr) } for _, rr := range response.Ns { processRecord(rr) } for _, rr := range response.Extra { processRecord(rr) } // 如果没有RRSIG记录,验证失败 if len(rrsigs) == 0 { return false } // 验证所有RRSIG记录 signatureValid := true // 用于记录已经警告过的DNSKEY tag,避免重复警告 warnedKeyTags := make(map[uint16]bool) for _, rrsig := range rrsigs { // 查找对应的DNSKEY dnskey, exists := dnskeys[rrsig.KeyTag] if !exists { // 仅当该key_tag尚未警告过时才记录警告 if !warnedKeyTags[rrsig.KeyTag] { logger.Warn("DNSSEC验证失败:找不到对应的DNSKEY", "key_tag", rrsig.KeyTag) warnedKeyTags[rrsig.KeyTag] = true } signatureValid = false continue } // 快速查找需要验证的记录集 name := rrsig.Header().Name typeCovered := rrsig.TypeCovered rrset := rrSets[name][typeCovered] // 验证签名 if len(rrset) > 0 { err := rrsig.Verify(dnskey, rrset) if err != nil { logger.Warn("DNSSEC签名验证失败", "error", err, "key_tag", rrsig.KeyTag) signatureValid = false } else { logger.Debug("DNSSEC签名验证成功", "key_tag", rrsig.KeyTag) } } } return signatureValid } // updateDomainDNSSECStatus 更新域名的DNSSEC状态 func (s *Server) updateDomainDNSSECStatus(domain string, dnssec bool) { // 确保域名是小写 domain = strings.ToLower(domain) // 更新域名的DNSSEC状态 s.resolvedDomainsMutex.Lock() defer s.resolvedDomainsMutex.Unlock() // 更新resolvedDomains中的DNSSEC状态 if entry, exists := s.resolvedDomains[domain]; exists { entry.DNSSEC = dnssec } else { s.resolvedDomains[domain] = &BlockedDomain{ Domain: domain, Count: 1, LastSeen: time.Now().UnixNano(), DNSSEC: dnssec, } } // 更新domainDNSSECStatus映射(使用单独的锁) s.domainDNSSECStatusMutex.Lock() s.domainDNSSECStatus[domain] = dnssec s.domainDNSSECStatusMutex.Unlock() } // updateResolvedDomainStats 更新解析域名统计 func (s *Server) updateResolvedDomainStats(domain string) { // 先尝试读锁,检查条目是否存在 s.resolvedDomainsMutex.RLock() entry, exists := s.resolvedDomains[domain] s.resolvedDomainsMutex.RUnlock() if exists { // 使用原子操作更新计数和时间戳 atomic.AddInt64(&entry.Count, 1) atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano()) } else { // 获取写锁,创建新条目 s.resolvedDomainsMutex.Lock() // 再次检查,避免竞态条件 if entry, exists := s.resolvedDomains[domain]; exists { atomic.AddInt64(&entry.Count, 1) atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano()) } else { s.resolvedDomains[domain] = &BlockedDomain{ Domain: domain, Count: 1, LastSeen: time.Now().UnixNano(), DNSSEC: false, } } s.resolvedDomainsMutex.Unlock() } } // getServerStats 获取服务器统计信息,如果不存在则创建 func (s *Server) getServerStats(server string) *ServerStats { s.serverStatsMutex.RLock() stats, exists := s.serverStats[server] s.serverStatsMutex.RUnlock() if exists { return stats } s.serverStatsMutex.Lock() defer s.serverStatsMutex.Unlock() if stats, exists := s.serverStats[server]; exists { return stats } stats = &ServerStats{ SuccessCount: 0, FailureCount: 0, LastResponse: time.Now(), ResponseTime: 0, ConnectionSpeed: 0, } s.serverStats[server] = stats return stats } // updateServerStats 更新服务器统计信息 func (s *Server) updateServerStats(server string, success bool, rtt time.Duration) { stats := s.getServerStats(server) // 使用原子操作更新成功和失败计数 if success { successCount := atomic.AddInt64(&stats.SuccessCount, 1) // 只在需要更新平均响应时间时获取锁 s.serverStatsMutex.Lock() stats.LastResponse = time.Now() // 更新平均响应时间(简单移动平均) if successCount == 1 { // 第一次成功,直接使用当前响应时间 stats.ResponseTime = rtt } else { // 使用纳秒进行计算以避免类型不匹配 prevTotal := stats.ResponseTime.Nanoseconds() * (successCount - 1) newTotal := prevTotal + rtt.Nanoseconds() stats.ResponseTime = time.Duration(newTotal / successCount) } s.serverStatsMutex.Unlock() } else { atomic.AddInt64(&stats.FailureCount, 1) // 只更新LastResponse时获取锁 s.serverStatsMutex.Lock() stats.LastResponse = time.Now() s.serverStatsMutex.Unlock() } } // selectWeightedRandomServer 加权随机选择服务器 func (s *Server) selectWeightedRandomServer(servers []string) string { if len(servers) == 0 { return "" } if len(servers) == 1 { return servers[0] } type serverWeight struct { server string weight int64 responseTime time.Duration successCount int64 failureCount int64 } var totalWeight int64 var totalResponseTime time.Duration var validServers int var currentWeight int64 serversInfo := make([]serverWeight, len(servers)) for i, server := range servers { stats := s.getServerStats(server) serversInfo[i] = serverWeight{ server: server, responseTime: stats.ResponseTime, successCount: atomic.LoadInt64(&stats.SuccessCount), failureCount: atomic.LoadInt64(&stats.FailureCount), } if stats.ResponseTime > 0 { totalResponseTime += stats.ResponseTime validServers++ } } var avgResponseTime time.Duration if validServers > 0 { avgResponseTime = totalResponseTime / time.Duration(validServers) } else { avgResponseTime = 1 * time.Second } var randomGen = rand.New(rand.NewSource(time.Now().UnixNano())) for i := range serversInfo { baseWeight := serversInfo[i].successCount - serversInfo[i].failureCount*2 if baseWeight < 1 { baseWeight = 1 } var responseFactor int64 = 100 if serversInfo[i].responseTime > 0 { if serversInfo[i].responseTime < avgResponseTime { factor := (avgResponseTime.Nanoseconds() * 200) / serversInfo[i].responseTime.Nanoseconds() if factor > 200 { factor = 200 } responseFactor = factor } else { factor := (avgResponseTime.Nanoseconds() * 200) / serversInfo[i].responseTime.Nanoseconds() if factor < 50 { factor = 50 } responseFactor = factor } } finalWeight := (baseWeight * responseFactor) / 100 if finalWeight < 1 { finalWeight = 1 } serversInfo[i].weight = finalWeight totalWeight += finalWeight } random := randomGen.Int63n(totalWeight) for _, sw := range serversInfo { currentWeight += sw.weight if random < currentWeight { return sw.server } } // 兜底返回第一个服务器 return servers[0] } // measureServerSpeed 测量服务器TCP连接速度 func (s *Server) measureServerSpeed(server string) time.Duration { addr := server if !strings.Contains(server, ":") { addr = server + ":53" } startTime := time.Now() conn, err := net.DialTimeout("tcp", addr, 2*time.Second) if err != nil { return 2 * time.Second } defer conn.Close() connTime := time.Since(startTime) stats := s.getServerStats(server) s.serverStatsMutex.Lock() stats.ConnectionSpeed = (stats.ConnectionSpeed*3 + connTime) / 4 s.serverStatsMutex.Unlock() return connTime } // selectFastestServer 选择连接速度最快的服务器 func (s *Server) selectFastestServer(servers []string) string { if len(servers) == 0 { return "" } if len(servers) == 1 { return servers[0] } // 并行测量所有服务器的速度 type speedResult struct { server string speed time.Duration } results := make(chan speedResult, len(servers)) var wg sync.WaitGroup for _, server := range servers { wg.Add(1) go func(srv string) { defer wg.Done() speed := s.measureServerSpeed(srv) results <- speedResult{srv, speed} }(server) } // 等待所有测量完成 go func() { wg.Wait() close(results) }() // 找出最快的服务器 var fastestServer string var fastestSpeed time.Duration = 2 * time.Second for result := range results { if result.speed < fastestSpeed { fastestSpeed = result.speed fastestServer = result.server } } // 如果没有找到最快服务器(理论上不会发生),返回第一个服务器 if fastestServer == "" { fastestServer = servers[0] } return fastestServer } // updateStats 更新统计信息 func (s *Server) updateStats(update func(*Stats)) { s.statsMutex.Lock() defer s.statsMutex.Unlock() update(s.stats) } // addQueryLog 添加查询日志 func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string, fromCache, dnssec, edns bool, dnsServer, dnssecServer string, answers []DNSAnswer, responseCode int) { // 创建日志记录 log := QueryLog{ Timestamp: time.Now(), ClientIP: clientIP, Domain: domain, QueryType: queryType, ResponseTime: responseTime, Result: result, BlockRule: blockRule, BlockType: blockType, FromCache: fromCache, DNSSEC: dnssec, EDNS: edns, DNSServer: dnsServer, DNSSECServer: dnssecServer, Answers: answers, ResponseCode: responseCode, } // 发送到日志处理通道(非阻塞) select { case s.logChannel <- log: // 日志发送成功 default: // 通道已满,丢弃日志以避免阻塞请求处理 logger.Warn("日志通道已满,丢弃一条日志记录") } } // GetStartTime 获取服务器启动时间 func (s *Server) GetStartTime() time.Time { return s.startTime } // GetStats 获取DNS服务器统计信息 func (s *Server) GetStats() *Stats { s.statsMutex.Lock() defer s.statsMutex.Unlock() // 复制查询类型统计 queryTypesCopy := make(map[string]int64) for k, v := range s.stats.QueryTypes { queryTypesCopy[k] = v } // 复制来源IP统计 sourceIPsCopy := make(map[string]bool) for ip := range s.stats.SourceIPs { sourceIPsCopy[ip] = true } // 返回统计信息的副本 return &Stats{ Queries: s.stats.Queries, Blocked: s.stats.Blocked, Allowed: s.stats.Allowed, Errors: s.stats.Errors, LastQuery: s.stats.LastQuery, AvgResponseTime: s.stats.AvgResponseTime, TotalResponseTime: s.stats.TotalResponseTime, QueryTypes: queryTypesCopy, SourceIPs: sourceIPsCopy, CpuUsage: s.stats.CpuUsage, DNSSECQueries: s.stats.DNSSECQueries, DNSSECSuccess: s.stats.DNSSECSuccess, DNSSECFailed: s.stats.DNSSECFailed, DNSSECEnabled: s.stats.DNSSECEnabled, } } // GetQueryLogs 获取查询日志 func (s *Server) GetQueryLogs(limit, offset int, sortField, sortDirection, resultFilter, searchTerm string) []QueryLog { s.queryLogsMutex.RLock() defer s.queryLogsMutex.RUnlock() // 确保偏移量和限制值合理 if offset < 0 { offset = 0 } if limit <= 0 { limit = 100 // 默认返回100条日志 } // 创建日志副本用于过滤和排序 var logsCopy []QueryLog // 先过滤日志 for _, log := range s.queryLogs { // 应用结果过滤 if resultFilter != "" && log.Result != resultFilter { continue } // 应用搜索过滤 if searchTerm != "" { // 搜索域名或客户端IP if !strings.Contains(log.Domain, searchTerm) && !strings.Contains(log.ClientIP, searchTerm) { continue } } logsCopy = append(logsCopy, log) } // 排序日志 if sortField != "" { sort.Slice(logsCopy, func(i, j int) bool { var a, b interface{} switch sortField { case "time": a = logsCopy[i].Timestamp b = logsCopy[j].Timestamp case "clientIp": a = logsCopy[i].ClientIP b = logsCopy[j].ClientIP case "domain": a = logsCopy[i].Domain b = logsCopy[j].Domain case "responseTime": a = logsCopy[i].ResponseTime b = logsCopy[j].ResponseTime case "blockRule": a = logsCopy[i].BlockRule b = logsCopy[j].BlockRule default: // 默认按时间排序 a = logsCopy[i].Timestamp b = logsCopy[j].Timestamp } // 根据排序方向比较 if sortDirection == "asc" { return compareValues(a, b) < 0 } return compareValues(a, b) > 0 }) } // 计算返回范围 start := offset end := offset + limit if end > len(logsCopy) { end = len(logsCopy) } if start >= len(logsCopy) { return []QueryLog{} // 没有数据,返回空切片 } return logsCopy[start:end] } // compareValues 比较两个值 func compareValues(a, b interface{}) int { switch v1 := a.(type) { case time.Time: v2 := b.(time.Time) if v1.Before(v2) { return -1 } if v1.After(v2) { return 1 } return 0 case string: v2 := b.(string) if v1 < v2 { return -1 } if v1 > v2 { return 1 } return 0 case int64: v2 := b.(int64) if v1 < v2 { return -1 } if v1 > v2 { return 1 } return 0 default: return 0 } } // GetQueryLogsCount 获取查询日志总数 func (s *Server) GetQueryLogsCount() int { s.queryLogsMutex.RLock() defer s.queryLogsMutex.RUnlock() return len(s.queryLogs) } // GetQueryStats 获取查询统计信息 func (s *Server) GetQueryStats() map[string]interface{} { s.statsMutex.Lock() defer s.statsMutex.Unlock() // 计算统计数据 return map[string]interface{}{ "totalQueries": s.stats.Queries, "blockedQueries": s.stats.Blocked, "allowedQueries": s.stats.Allowed, "errorQueries": s.stats.Errors, "avgResponseTime": s.stats.AvgResponseTime, "activeIPs": len(s.stats.SourceIPs), } } // GetTopBlockedDomains 获取TOP屏蔽域名列表 func (s *Server) GetTopBlockedDomains(limit int) []BlockedDomain { s.blockedDomainsMutex.RLock() defer s.blockedDomainsMutex.RUnlock() // 转换为切片 domains := make([]BlockedDomain, 0, len(s.blockedDomains)) for _, entry := range s.blockedDomains { domains = append(domains, *entry) } // 按计数排序 sort.Slice(domains, func(i, j int) bool { return domains[i].Count > domains[j].Count }) // 返回限制数量 if len(domains) > limit { return domains[:limit] } return domains } // GetTopResolvedDomains 获取TOP解析域名 func (s *Server) GetTopResolvedDomains(limit int) []BlockedDomain { s.resolvedDomainsMutex.RLock() defer s.resolvedDomainsMutex.RUnlock() // 转换为切片 domains := make([]BlockedDomain, 0, len(s.resolvedDomains)) for _, entry := range s.resolvedDomains { domains = append(domains, *entry) } // 按数量排序 sort.Slice(domains, func(i, j int) bool { return domains[i].Count > domains[j].Count }) // 返回限制数量 if len(domains) > limit { return domains[:limit] } return domains } // GetRecentBlockedDomains 获取最近屏蔽的域名列表 func (s *Server) GetRecentBlockedDomains(limit int) []BlockedDomain { s.blockedDomainsMutex.RLock() defer s.blockedDomainsMutex.RUnlock() // 转换为切片 domains := make([]BlockedDomain, 0, len(s.blockedDomains)) for _, entry := range s.blockedDomains { domains = append(domains, *entry) } // 按时间排序 sort.Slice(domains, func(i, j int) bool { return domains[i].LastSeen > domains[j].LastSeen }) // 返回限制数量 if len(domains) > limit { return domains[:limit] } return domains } // GetTopClients 获取TOP客户端列表 func (s *Server) GetTopClients(limit int) []ClientStats { s.clientStatsMutex.RLock() defer s.clientStatsMutex.RUnlock() // 转换为切片 clients := make([]ClientStats, 0, len(s.clientStats)) for _, entry := range s.clientStats { clients = append(clients, *entry) } // 按请求次数排序 sort.Slice(clients, func(i, j int) bool { return clients[i].Count > clients[j].Count }) // 返回限制数量 if len(clients) > limit { return clients[:limit] } return clients } // GetHourlyStats 获取每小时统计数据 func (s *Server) GetHourlyStats() map[string]int64 { s.hourlyStatsMutex.RLock() defer s.hourlyStatsMutex.RUnlock() // 返回副本 result := make(map[string]int64) for k, v := range s.hourlyStats { result[k] = v } return result } // GetDailyStats 获取每日统计数据 func (s *Server) GetDailyStats() map[string]int64 { s.dailyStatsMutex.RLock() defer s.dailyStatsMutex.RUnlock() // 返回副本 result := make(map[string]int64) for k, v := range s.dailyStats { result[k] = v } return result } // GetMonthlyStats 获取每月统计数据 func (s *Server) GetMonthlyStats() map[string]int64 { s.monthlyStatsMutex.RLock() defer s.monthlyStatsMutex.RUnlock() // 返回副本 result := make(map[string]int64) for k, v := range s.monthlyStats { result[k] = v } return result } // isPrivateIP 检测IP地址是否为内网IP func isPrivateIP(ip string) bool { // 解析IP地址 parsedIP := net.ParseIP(ip) if parsedIP == nil { return false } // 检查IPv4内网地址 if ipv4 := parsedIP.To4(); ipv4 != nil { // 10.0.0.0/8 if ipv4[0] == 10 { return true } // 172.16.0.0/12 if ipv4[0] == 172 && (ipv4[1] >= 16 && ipv4[1] <= 31) { return true } // 192.168.0.0/16 if ipv4[0] == 192 && ipv4[1] == 168 { return true } // 127.0.0.0/8 (localhost) if ipv4[0] == 127 { return true } // 169.254.0.0/16 (链路本地地址) if ipv4[0] == 169 && ipv4[1] == 254 { return true } return false } // 检查IPv6内网地址 // ::1/128 (localhost) if parsedIP.IsLoopback() { return true } // fc00::/7 (唯一本地地址) if parsedIP[0]&0xfc == 0xfc { return true } // fe80::/10 (链路本地地址) if parsedIP[0]&0xfe == 0xfe && parsedIP[1]&0xc0 == 0x80 { return true } return false } // loadStatsData 从文件加载统计数据 func (s *Server) loadStatsData() { // 检查文件是否存在 data, err := ioutil.ReadFile("data/stats.json") if err != nil { if !os.IsNotExist(err) { logger.Error("读取统计数据文件失败", "error", err) } return } var statsData StatsData err = json.Unmarshal(data, &statsData) if err != nil { logger.Error("解析统计数据失败", "error", err) return } // 恢复统计数据 s.statsMutex.Lock() if statsData.Stats != nil { // 只恢复有效数据,避免破坏统计关系 s.stats.Queries += statsData.Stats.Queries s.stats.Blocked += statsData.Stats.Blocked s.stats.Allowed += statsData.Stats.Allowed s.stats.Errors += statsData.Stats.Errors s.stats.TotalResponseTime += statsData.Stats.TotalResponseTime s.stats.DNSSECQueries += statsData.Stats.DNSSECQueries s.stats.DNSSECSuccess += statsData.Stats.DNSSECSuccess s.stats.DNSSECFailed += statsData.Stats.DNSSECFailed // 重新计算平均响应时间,确保一致性 if s.stats.Queries > 0 { s.stats.AvgResponseTime = float64(s.stats.TotalResponseTime) / float64(s.stats.Queries) // 限制平均响应时间的范围,避免显示异常大的值 if s.stats.AvgResponseTime > 60000 { s.stats.AvgResponseTime = 60000 } } // 合并查询类型统计 for k, v := range statsData.Stats.QueryTypes { s.stats.QueryTypes[k] += v } // 合并来源IP统计 for ip := range statsData.Stats.SourceIPs { s.stats.SourceIPs[ip] = true } // 确保使用当前配置中的EnableDNSSEC值 s.stats.DNSSECEnabled = s.config.EnableDNSSEC } s.statsMutex.Unlock() s.blockedDomainsMutex.Lock() if statsData.BlockedDomains != nil { s.blockedDomains = statsData.BlockedDomains } s.blockedDomainsMutex.Unlock() s.resolvedDomainsMutex.Lock() if statsData.ResolvedDomains != nil { s.resolvedDomains = statsData.ResolvedDomains } s.resolvedDomainsMutex.Unlock() s.hourlyStatsMutex.Lock() if statsData.HourlyStats != nil { s.hourlyStats = statsData.HourlyStats } s.hourlyStatsMutex.Unlock() s.dailyStatsMutex.Lock() if statsData.DailyStats != nil { s.dailyStats = statsData.DailyStats } s.dailyStatsMutex.Unlock() s.monthlyStatsMutex.Lock() if statsData.MonthlyStats != nil { s.monthlyStats = statsData.MonthlyStats } s.monthlyStatsMutex.Unlock() // 加载客户端统计数据 s.clientStatsMutex.Lock() if statsData.ClientStats != nil { s.clientStats = statsData.ClientStats } s.clientStatsMutex.Unlock() logger.Info("统计数据加载成功") // 加载查询日志 s.loadQueryLogs() } // loadQueryLogs 从文件加载查询日志 func (s *Server) loadQueryLogs() { // 获取绝对路径 statsFilePath, err := filepath.Abs("data/stats.json") if err != nil { logger.Error("获取统计文件绝对路径失败", "path", "data/stats.json", "error", err) return } // 构建查询日志文件路径 queryLogPath := filepath.Join(filepath.Dir(statsFilePath), "querylog.json") // 检查文件是否存在 if _, err := os.Stat(queryLogPath); os.IsNotExist(err) { logger.Info("查询日志文件不存在,将使用空列表", "file", queryLogPath) return } // 读取文件内容 data, err := ioutil.ReadFile(queryLogPath) if err != nil { logger.Error("读取查询日志文件失败", "error", err) return } // 解析数据 var logs []QueryLog err = json.Unmarshal(data, &logs) if err != nil { logger.Error("解析查询日志失败", "error", err) return } // 更新查询日志 s.queryLogsMutex.Lock() s.queryLogs = logs // 确保日志数量不超过限制 if len(s.queryLogs) > s.maxQueryLogs { s.queryLogs = s.queryLogs[:s.maxQueryLogs] } s.queryLogsMutex.Unlock() logger.Info("查询日志加载成功", "count", len(logs)) } // processLogs 异步处理日志记录 func (s *Server) processLogs() { for { select { case logEntry, ok := <-s.logChannel: if !ok { // 通道关闭,退出循环 return } // 加锁保护queryLogs s.queryLogsMutex.Lock() // 如果日志数量超过最大限制,删除最旧的日志 if len(s.queryLogs) >= s.maxQueryLogs { // 保留最新的s.maxQueryLogs条日志 newLogs := make([]QueryLog, 0, s.maxQueryLogs) // 复制最新的日志到新切片 for i := len(s.queryLogs) - s.maxQueryLogs + 1; i < len(s.queryLogs); i++ { newLogs = append(newLogs, s.queryLogs[i]) } // 添加新日志 newLogs = append(newLogs, logEntry) // 替换原有日志 s.queryLogs = newLogs } else { // 直接添加新日志 s.queryLogs = append(s.queryLogs, logEntry) } // 解锁 s.queryLogsMutex.Unlock() case <-s.ctx.Done(): // 上下文取消,退出循环 return } } } // saveStatsData 保存统计数据到文件 func (s *Server) saveStatsData() { // 获取绝对路径以避免工作目录问题 statsFilePath, err := filepath.Abs("data/stats.json") if err != nil { logger.Error("获取统计文件绝对路径失败", "path", "data/stats.json", "error", err) return } // 创建数据目录 statsDir := filepath.Dir(statsFilePath) err = os.MkdirAll(statsDir, 0755) if err != nil { logger.Error("创建统计数据目录失败", "dir", statsDir, "error", err) return } // 收集所有统计数据 statsData := &StatsData{ Stats: s.GetStats(), LastSaved: time.Now(), } // 复制域名数据 s.blockedDomainsMutex.RLock() statsData.BlockedDomains = make(map[string]*BlockedDomain) for k, v := range s.blockedDomains { statsData.BlockedDomains[k] = v } s.blockedDomainsMutex.RUnlock() s.resolvedDomainsMutex.RLock() statsData.ResolvedDomains = make(map[string]*BlockedDomain) for k, v := range s.resolvedDomains { statsData.ResolvedDomains[k] = v } s.resolvedDomainsMutex.RUnlock() s.hourlyStatsMutex.RLock() statsData.HourlyStats = make(map[string]int64) for k, v := range s.hourlyStats { statsData.HourlyStats[k] = v } s.hourlyStatsMutex.RUnlock() s.dailyStatsMutex.RLock() statsData.DailyStats = make(map[string]int64) for k, v := range s.dailyStats { statsData.DailyStats[k] = v } s.dailyStatsMutex.RUnlock() s.monthlyStatsMutex.RLock() statsData.MonthlyStats = make(map[string]int64) for k, v := range s.monthlyStats { statsData.MonthlyStats[k] = v } s.monthlyStatsMutex.RUnlock() // 复制客户端统计数据 s.clientStatsMutex.RLock() statsData.ClientStats = make(map[string]*ClientStats) for k, v := range s.clientStats { statsData.ClientStats[k] = v } s.clientStatsMutex.RUnlock() // 序列化数据 jsonData, err := json.MarshalIndent(statsData, "", " ") if err != nil { logger.Error("序列化统计数据失败", "error", err) return } // 写入文件 err = os.WriteFile(statsFilePath, jsonData, 0644) if err != nil { logger.Error("保存统计数据到文件失败", "file", statsFilePath, "error", err) return } logger.Info("统计数据保存成功", "file", statsFilePath) // 保存查询日志到文件 s.saveQueryLogs(statsDir) } // saveQueryLogs 保存查询日志到文件 func (s *Server) saveQueryLogs(dataDir string) { // 构建查询日志文件路径 queryLogPath := filepath.Join(dataDir, "querylog.json") // 获取查询日志数据 s.queryLogsMutex.RLock() logsCopy := make([]QueryLog, len(s.queryLogs)) copy(logsCopy, s.queryLogs) s.queryLogsMutex.RUnlock() // 序列化数据 jsonData, err := json.MarshalIndent(logsCopy, "", " ") if err != nil { logger.Error("序列化查询日志失败", "error", err) return } // 写入文件 err = os.WriteFile(queryLogPath, jsonData, 0644) if err != nil { logger.Error("保存查询日志到文件失败", "file", queryLogPath, "error", err) return } logger.Info("查询日志保存成功", "file", queryLogPath) } // startCpuUsageMonitor 启动CPU使用率监控 func (s *Server) startCpuUsageMonitor() { ticker := time.NewTicker(time.Second * 5) // 每5秒更新一次CPU使用率 defer ticker.Stop() // 初始化 var memStats runtime.MemStats runtime.ReadMemStats(&memStats) // 存储上一次的CPU时间统计 var prevIdle, prevTotal uint64 for { select { case <-ticker.C: // 获取真实的系统级CPU使用率 cpuUsage, err := getSystemCpuUsage(&prevIdle, &prevTotal) if err != nil { // 如果获取失败,使用默认值 cpuUsage = 0.0 logger.Error("获取系统CPU使用率失败", "error", err) } s.updateStats(func(stats *Stats) { stats.CpuUsage = cpuUsage }) case <-s.ctx.Done(): return } } } // getSystemCpuUsage 获取系统CPU使用率 func getSystemCpuUsage(prevIdle, prevTotal *uint64) (float64, error) { // 读取/proc/stat文件获取CPU统计信息 file, err := os.Open("/proc/stat") if err != nil { return 0, err } defer file.Close() var cpuUser, cpuNice, cpuSystem, cpuIdle, cpuIowait, cpuIrq, cpuSoftirq, cpuSteal uint64 _, err = fmt.Fscanf(file, "cpu %d %d %d %d %d %d %d %d", &cpuUser, &cpuNice, &cpuSystem, &cpuIdle, &cpuIowait, &cpuIrq, &cpuSoftirq, &cpuSteal) if err != nil { return 0, err } // 计算总的CPU时间 total := cpuUser + cpuNice + cpuSystem + cpuIdle + cpuIowait + cpuIrq + cpuSoftirq + cpuSteal idle := cpuIdle + cpuIowait // 第一次调用时,只初始化值,不计算使用率 if *prevTotal == 0 || *prevIdle == 0 { *prevIdle = idle *prevTotal = total return 0, nil } // 计算CPU使用率 idleDelta := idle - *prevIdle totalDelta := total - *prevTotal utilization := float64(totalDelta-idleDelta) / float64(totalDelta) * 100 // 更新上一次的值 *prevIdle = idle *prevTotal = total return utilization, nil } // startAutoSave 启动自动保存功能 func (s *Server) startAutoSave() { if s.config.SaveInterval <= 0 { return } // 初始化定时器 s.saveTicker = time.NewTicker(time.Duration(s.config.SaveInterval) * time.Second) defer s.saveTicker.Stop() logger.Info("启动统计数据自动保存功能", "interval", s.config.SaveInterval, "file", "data/stats.json") // 定期保存数据 for { select { case <-s.saveTicker.C: s.saveStatsData() case <-s.saveDone: return } } }