package dns import ( "context" "encoding/json" "fmt" "io/ioutil" "net" "net/http" "os" "path/filepath" "runtime" "sort" "strings" "sync" "time" "dns-server/config" "dns-server/logger" "dns-server/shield" "github.com/miekg/dns" ) // 确保DNS服务器地址包含端口号,默认添加53端口 func normalizeDNSServerAddress(address string) string { // 检查地址是否已经包含端口号 if _, _, err := net.SplitHostPort(address); err != nil { // 如果没有端口号,添加默认的53端口 return net.JoinHostPort(address, "53") } // 已经有端口号,直接返回 return address } // BlockedDomain 屏蔽域名统计 type BlockedDomain struct { Domain string Count int64 LastSeen time.Time DNSSEC bool // 是否使用了DNSSEC } // ClientStats 客户端统计 type ClientStats struct { IP string Count int64 LastSeen time.Time } // IPGeolocation IP地理位置信息 type IPGeolocation struct { Country string `json:"country"` // 国家 City string `json:"city"` // 城市 Expiry time.Time `json:"expiry"` // 缓存过期时间 } // DNSAnswer DNS解析记录 type DNSAnswer struct { Type string `json:"type"` // 记录类型 Value string `json:"value"` // 记录值 TTL uint32 `json:"ttl"` // 生存时间 } // QueryLog 查询日志记录 type QueryLog struct { Timestamp time.Time `json:"timestamp"` // 查询时间 ClientIP string `json:"clientIP"` // 客户端IP Location string `json:"location"` // IP地理位置(国家 城市) Domain string `json:"domain"` // 查询域名 QueryType string `json:"queryType"` // 查询类型 ResponseTime int64 `json:"responseTime"` // 响应时间(ms) Result string `json:"result"` // 查询结果(allowed, blocked, error) BlockRule string `json:"blockRule"` // 屏蔽规则(如果被屏蔽) BlockType string `json:"blockType"` // 屏蔽类型(如果被屏蔽) FromCache bool `json:"fromCache"` // 是否来自缓存 DNSSEC bool `json:"dnssec"` // 是否使用了DNSSEC EDNS bool `json:"edns"` // 是否使用了EDNS DNSServer string `json:"dnsServer"` // 使用的DNS服务器 DNSSECServer string `json:"dnssecServer"` // 使用的DNSSEC专用服务器 Answers []DNSAnswer `json:"answers"` // 解析记录 ResponseCode int `json:"responseCode"` // DNS响应代码 } // StatsData 用于持久化的统计数据结构 type StatsData struct { Stats *Stats `json:"stats"` BlockedDomains map[string]*BlockedDomain `json:"blockedDomains"` ResolvedDomains map[string]*BlockedDomain `json:"resolvedDomains"` ClientStats map[string]*ClientStats `json:"clientStats"` HourlyStats map[string]int64 `json:"hourlyStats"` DailyStats map[string]int64 `json:"dailyStats"` MonthlyStats map[string]int64 `json:"monthlyStats"` LastSaved time.Time `json:"lastSaved"` } // ServerStats 服务器统计信息 type ServerStats struct { SuccessCount int64 // 成功查询次数 FailureCount int64 // 失败查询次数 LastResponse time.Time // 最后响应时间 ResponseTime time.Duration // 平均响应时间 ConnectionSpeed time.Duration // TCP连接速度 } // Server DNS服务器 type Server struct { config *config.DNSConfig shieldConfig *config.ShieldConfig shieldManager *shield.ShieldManager server *dns.Server tcpServer *dns.Server resolver *dns.Client ctx context.Context cancel context.CancelFunc statsMutex sync.Mutex stats *Stats blockedDomainsMutex sync.RWMutex blockedDomains map[string]*BlockedDomain resolvedDomainsMutex sync.RWMutex resolvedDomains map[string]*BlockedDomain // 用于记录解析的域名 clientStatsMutex sync.RWMutex clientStats map[string]*ClientStats // 用于记录客户端统计 hourlyStatsMutex sync.RWMutex hourlyStats map[string]int64 // 按小时统计屏蔽数量 dailyStatsMutex sync.RWMutex dailyStats map[string]int64 // 按天统计屏蔽数量 monthlyStatsMutex sync.RWMutex monthlyStats map[string]int64 // 按月统计屏蔽数量 queryLogsMutex sync.RWMutex queryLogs []QueryLog // 查询日志列表 maxQueryLogs int // 最大保存日志数量 saveTicker *time.Ticker // 用于定时保存数据 startTime time.Time // 服务器启动时间 saveDone chan struct{} // 用于通知保存协程停止 stopped bool // 服务器是否已经停止 stoppedMutex sync.Mutex // 保护stopped标志的互斥锁 // IP地理位置缓存 ipGeolocationCache map[string]*IPGeolocation // IP地址到地理位置的映射 ipGeolocationCacheMutex sync.RWMutex // 保护IP地理位置缓存的互斥锁 ipGeolocationCacheTTL time.Duration // 缓存有效期 // DNS查询缓存 DnsCache *DNSCache // DNS响应缓存 // 域名DNSSEC状态映射表 domainDNSSECStatus map[string]bool // 域名到DNSSEC状态的映射 // 上游服务器状态跟踪 serverStats map[string]*ServerStats // 服务器地址到状态的映射 serverStatsMutex sync.RWMutex // 保护服务器状态的互斥锁 } // Stats DNS服务器统计信息 type Stats struct { Queries int64 Blocked int64 Allowed int64 Errors int64 LastQuery time.Time AvgResponseTime float64 // 平均响应时间(ms) TotalResponseTime int64 // 总响应时间 QueryTypes map[string]int64 // 查询类型统计 SourceIPs map[string]bool // 活跃来源IP CpuUsage float64 // CPU使用率(%) DNSSECQueries int64 // DNSSEC查询总数 DNSSECSuccess int64 // DNSSEC验证成功数 DNSSECFailed int64 // DNSSEC验证失败数 DNSSECEnabled bool // 是否启用了DNSSEC } // NewServer 创建DNS服务器实例 func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager) *Server { ctx, cancel := context.WithCancel(context.Background()) // 从配置中读取DNS缓存TTL值(分钟) cacheTTL := time.Duration(config.CacheTTL) * time.Minute server := &Server{ config: config, shieldConfig: shieldConfig, shieldManager: shieldManager, resolver: &dns.Client{ Net: "udp", UDPSize: 4096, // 增加UDP缓冲区大小,支持更大的DNSSEC响应 }, ctx: ctx, cancel: cancel, startTime: time.Now(), // 记录服务器启动时间 stats: &Stats{ Queries: 0, Blocked: 0, Allowed: 0, Errors: 0, AvgResponseTime: 0, TotalResponseTime: 0, QueryTypes: make(map[string]int64), SourceIPs: make(map[string]bool), CpuUsage: 0, DNSSECQueries: 0, DNSSECSuccess: 0, DNSSECFailed: 0, DNSSECEnabled: config.EnableDNSSEC, }, blockedDomains: make(map[string]*BlockedDomain), resolvedDomains: make(map[string]*BlockedDomain), clientStats: make(map[string]*ClientStats), hourlyStats: make(map[string]int64), dailyStats: make(map[string]int64), monthlyStats: make(map[string]int64), queryLogs: make([]QueryLog, 0, 1000), // 初始化查询日志切片,容量1000 maxQueryLogs: 10000, // 最大保存10000条日志 saveDone: make(chan struct{}), stopped: false, // 初始化为未停止状态 // IP地理位置缓存初始化 ipGeolocationCache: make(map[string]*IPGeolocation), ipGeolocationCacheTTL: 24 * time.Hour, // 缓存有效期24小时 // DNS查询缓存初始化 DnsCache: NewDNSCache(cacheTTL), // 初始化域名DNSSEC状态映射表 domainDNSSECStatus: make(map[string]bool), // 初始化服务器状态跟踪 serverStats: make(map[string]*ServerStats), } // 加载已保存的统计数据 server.loadStatsData() return server } // Start 启动DNS服务器 func (s *Server) Start() error { // 重新初始化上下文和取消函数 ctx, cancel := context.WithCancel(context.Background()) s.ctx = ctx s.cancel = cancel // 重新初始化saveDone通道 s.saveDone = make(chan struct{}) // 重置stopped标志 s.stoppedMutex.Lock() s.stopped = false s.stoppedMutex.Unlock() // 更新服务器启动时间 s.startTime = time.Now() s.server = &dns.Server{ Addr: fmt.Sprintf("0.0.0.0:%d", s.config.Port), Net: "udp", Handler: dns.HandlerFunc(s.handleDNSRequest), } // 保存TCP服务器实例,以便在Stop方法中关闭 s.tcpServer = &dns.Server{ Addr: fmt.Sprintf("0.0.0.0:%d", s.config.Port), Net: "tcp", Handler: dns.HandlerFunc(s.handleDNSRequest), } // 启动CPU使用率监控 go s.startCpuUsageMonitor() // 启动自动保存功能 go s.startAutoSave() // 启动UDP服务 go func() { logger.Info(fmt.Sprintf("DNS UDP服务器启动,监听端口: %d", s.config.Port)) if err := s.server.ListenAndServe(); err != nil { logger.Error("DNS UDP服务器启动失败", "error", err) s.cancel() } }() // 启动TCP服务 go func() { logger.Info(fmt.Sprintf("DNS TCP服务器启动,监听端口: %d", s.config.Port)) if err := s.tcpServer.ListenAndServe(); err != nil { logger.Error("DNS TCP服务器启动失败", "error", err) s.cancel() } }() // 等待停止信号 <-s.ctx.Done() return nil } // Stop 停止DNS服务器 func (s *Server) Stop() { // 检查服务器是否已经停止 s.stoppedMutex.Lock() if s.stopped { s.stoppedMutex.Unlock() return // 服务器已经停止,直接返回 } // 标记服务器为已停止状态 s.stopped = true s.stoppedMutex.Unlock() // 发送停止信号给保存协程 close(s.saveDone) // 最后保存一次数据 s.saveStatsData() // 停止服务器 s.cancel() if s.server != nil { s.server.Shutdown() } if s.tcpServer != nil { s.tcpServer.Shutdown() } logger.Info("DNS服务器已停止") } // handleDNSRequest 处理DNS请求 func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { startTime := time.Now() // 获取来源IP sourceIP := w.RemoteAddr().String() // 提取IP地址部分,去掉端口 if strings.HasPrefix(sourceIP, "[") { // IPv6地址格式: [::1]:53 if idx := strings.Index(sourceIP, "]"); idx >= 0 { sourceIP = sourceIP[1:idx] // 去掉方括号 } } else { // IPv4地址格式: 127.0.0.1:53 if idx := strings.LastIndex(sourceIP, ":"); idx >= 0 { sourceIP = sourceIP[:idx] } } // 更新来源IP统计 s.updateStats(func(stats *Stats) { stats.Queries++ stats.LastQuery = time.Now() stats.SourceIPs[sourceIP] = true }) // 更新客户端统计 s.updateClientStats(sourceIP) // 获取查询域名和类型 var domain string var queryType string var qType uint16 if len(r.Question) > 0 { domain = r.Question[0].Name // 移除末尾的点 if len(domain) > 0 && domain[len(domain)-1] == '.' { domain = domain[:len(domain)-1] } // 获取查询类型 queryType = dns.TypeToString[r.Question[0].Qtype] qType = r.Question[0].Qtype // 更新查询类型统计 s.updateStats(func(stats *Stats) { stats.QueryTypes[queryType]++ }) // 检查是否是AAAA记录查询且IPv6解析已禁用 if qType == dns.TypeAAAA && !s.config.EnableIPv6 { // 返回NXDOMAIN响应(域名不存在) response := new(dns.Msg) response.SetReply(r) response.SetRcode(r, dns.RcodeNameError) w.WriteMsg(response) // 更新统计信息 responseTime := int64(0) s.updateStats(func(stats *Stats) { stats.TotalResponseTime += responseTime if stats.Queries > 0 { stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries) } }) // 添加查询日志 s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "", false, false, true, "", "", nil, dns.RcodeNameError) logger.Debug("IPv6解析已禁用,拒绝AAAA记录查询", "domain", domain) return } } logger.Debug("接收到DNS查询", "domain", domain, "type", queryType, "client", w.RemoteAddr()) // 只处理递归查询 if r.RecursionDesired == false { response := new(dns.Msg) response.SetReply(r) // 不再硬编码RecursionAvailable,使用默认值或上游返回的值 response.SetRcode(r, dns.RcodeRefused) w.WriteMsg(response) // 缓存命中,响应时间设为0ms responseTime := int64(0) s.updateStats(func(stats *Stats) { stats.TotalResponseTime += responseTime if stats.Queries > 0 { stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries) } }) // 添加查询日志 s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "", false, false, true, "", "", nil, dns.RcodeRefused) return } // 检查hosts文件是否有匹配 if ip, exists := s.shieldManager.GetHostsIP(domain); exists { s.handleHostsResponse(w, r, ip) // 缓存命中,响应时间设为0ms responseTime := int64(0) s.updateStats(func(stats *Stats) { stats.TotalResponseTime += responseTime if stats.Queries > 0 { stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries) } }) // 该方法内部未直接调用addQueryLog,而是在handleDNSRequest中处理 return } // 检查是否被屏蔽 if s.shieldManager.IsBlocked(domain) { // 获取屏蔽详情 blockDetails := s.shieldManager.CheckDomainBlockDetails(domain) blockRule, _ := blockDetails["blockRule"].(string) blockType, _ := blockDetails["blockRuleType"].(string) s.handleBlockedResponse(w, r, domain) // 计算响应时间 responseTime := time.Since(startTime).Milliseconds() s.updateStats(func(stats *Stats) { stats.TotalResponseTime += responseTime if stats.Queries > 0 { stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries) } }) // 添加查询日志 - 被屏蔽域名 blockedAnswers := []DNSAnswer{} // 根据屏蔽方法确定响应代码 blockedRcode := dns.RcodeNameError // 默认NXDOMAIN if blockMethod := s.shieldConfig.BlockMethod; blockMethod == "refused" { blockedRcode = dns.RcodeRefused } else if blockMethod == "emptyIP" || blockMethod == "customIP" { blockedRcode = dns.RcodeSuccess } s.addQueryLog(sourceIP, domain, queryType, responseTime, "blocked", blockRule, blockType, false, false, true, "无", "无", blockedAnswers, blockedRcode) return } // 检查缓存中是否有响应(优先查找带DNSSEC的缓存项) var cachedResponse *dns.Msg var found bool var cachedDNSSEC bool // 1. 首先检查是否有普通缓存项 if tempResponse, tempFound := s.DnsCache.Get(r.Question[0].Name, qType); tempFound { cachedResponse = tempResponse found = tempFound cachedDNSSEC = s.hasDNSSECRecords(tempResponse) } // 2. 如果启用了DNSSEC且没有找到带DNSSEC的缓存项, // 尝试从所有缓存中查找是否有其他响应包含DNSSEC记录 // (这里可以进一步优化,比如在缓存中标记DNSSEC状态,快速查找) if s.config.EnableDNSSEC && !cachedDNSSEC { // 目前的缓存实现不支持按DNSSEC状态查找,所以这里暂时跳过 // 后续可以考虑改进缓存实现,添加DNSSEC状态标记 } if found { // 缓存命中,直接返回缓存的响应 cachedResponseCopy := cachedResponse.Copy() // 创建响应副本避免并发修改问题 cachedResponseCopy.Id = r.Id // 更新ID以匹配请求 cachedResponseCopy.Compress = true // 如果客户端请求包含EDNS记录,确保响应也包含EDNS if opt := r.IsEdns0(); opt != nil { // 检查响应是否已经包含EDNS记录 if respOpt := cachedResponseCopy.IsEdns0(); respOpt == nil { // 添加EDNS记录,使用客户端的UDP缓冲区大小 cachedResponseCopy.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC) } else { // 确保响应的UDP缓冲区大小不超过客户端请求的大小 if respOpt.UDPSize() > opt.UDPSize() { // 移除现有的EDNS记录 for i := range cachedResponseCopy.Extra { if cachedResponseCopy.Extra[i] == respOpt { cachedResponseCopy.Extra = append(cachedResponseCopy.Extra[:i], cachedResponseCopy.Extra[i+1:]...) break } } // 添加新的EDNS记录,使用客户端的UDP缓冲区大小 cachedResponseCopy.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC) } } } w.WriteMsg(cachedResponseCopy) // 计算响应时间 responseTime := time.Since(startTime).Milliseconds() s.updateStats(func(stats *Stats) { stats.TotalResponseTime += responseTime if stats.Queries > 0 { stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries) } }) // 如果缓存响应包含DNSSEC记录,更新DNSSEC查询计数 if cachedDNSSEC { s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ // 缓存响应视为DNSSEC成功 stats.DNSSECSuccess++ }) } // 从缓存响应中提取解析记录 cachedAnswers := []DNSAnswer{} if cachedResponse != nil { for _, rr := range cachedResponse.Answer { cachedAnswers = append(cachedAnswers, DNSAnswer{ Type: dns.TypeToString[rr.Header().Rrtype], Value: rr.String(), TTL: rr.Header().Ttl, }) } } // 添加查询日志 - 标记为缓存 // 从缓存响应中获取响应代码 cacheRcode := dns.RcodeSuccess // 默认成功 if cachedResponse != nil { cacheRcode = cachedResponse.Rcode } s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", true, cachedDNSSEC, true, "缓存", "无", cachedAnswers, cacheRcode) logger.Debug("从缓存返回DNS响应", "domain", domain, "type", queryType, "dnssec", cachedDNSSEC) return } // 缓存未命中,处理DNS请求 var response *dns.Msg var rtt time.Duration var queryAttempts []string var dnsServer string var dnssecServer string // 直接查询原始域名 queryAttempts = append(queryAttempts, domain) response, rtt, dnsServer, dnssecServer = s.forwardDNSRequestWithCache(r, domain) if response != nil { // 如果客户端请求包含EDNS记录,确保响应也包含EDNS if opt := r.IsEdns0(); opt != nil { // 检查响应是否已经包含EDNS记录 if respOpt := response.IsEdns0(); respOpt == nil { // 添加EDNS记录,使用客户端的UDP缓冲区大小 response.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC) } else { // 确保响应的UDP缓冲区大小不超过客户端请求的大小 if respOpt.UDPSize() > opt.UDPSize() { // 移除现有的EDNS记录 for i := range response.Extra { if response.Extra[i] == respOpt { response.Extra = append(response.Extra[:i], response.Extra[i+1:]...) break } } // 添加新的EDNS记录,使用客户端的UDP缓冲区大小 response.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC) } } } // 写入响应给客户端 w.WriteMsg(response) } // 使用上游服务器的实际响应时间(转换为毫秒) responseTime := int64(rtt.Milliseconds()) // 如果rtt为0(查询失败),则使用本地计算的时间 if responseTime == 0 { responseTime = time.Since(startTime).Milliseconds() } s.updateStats(func(stats *Stats) { stats.TotalResponseTime += responseTime if stats.Queries > 0 { stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries) } }) // 检查响应是否包含DNSSEC记录并验证结果 responseDNSSEC := false if response != nil { // 使用hasDNSSECRecords函数检查是否包含DNSSEC记录 responseDNSSEC = s.hasDNSSECRecords(response) // 检查AD标志,确认DNSSEC验证是否成功 if response.AuthenticatedData { responseDNSSEC = true } // 更新域名的DNSSEC状态 if responseDNSSEC { s.updateDomainDNSSECStatus(domain, true) } } // 如果响应成功,缓存结果(增强版缓存存储) if response != nil && response.Rcode == dns.RcodeSuccess { // 创建响应副本以避免后续修改影响缓存 responseCopy := response.Copy() // 设置合理的TTL,不超过默认的30分钟 defaultCacheTTL := 30 * time.Minute s.DnsCache.Set(r.Question[0].Name, qType, responseCopy, defaultCacheTTL) logger.Debug("DNS响应已缓存", "domain", domain, "type", queryType, "ttl", defaultCacheTTL, "dnssec", responseDNSSEC) } // 从响应中提取解析记录 responseAnswers := []DNSAnswer{} if response != nil { for _, rr := range response.Answer { responseAnswers = append(responseAnswers, DNSAnswer{ Type: dns.TypeToString[rr.Header().Rrtype], Value: rr.String(), TTL: rr.Header().Ttl, }) } } // 添加查询日志 - 标记为实时 // 从响应中获取响应代码 realRcode := dns.RcodeSuccess // 默认成功 if response != nil { realRcode = response.Rcode } s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", false, responseDNSSEC, true, dnsServer, dnssecServer, responseAnswers, realRcode) } // handleHostsResponse 处理hosts文件匹配的响应 func (s *Server) handleHostsResponse(w dns.ResponseWriter, r *dns.Msg, ip string) { response := new(dns.Msg) response.SetReply(r) // 不再硬编码RecursionAvailable,使用默认值或上游返回的值 if len(r.Question) > 0 { q := r.Question[0] answer := new(dns.A) answer.Hdr = dns.RR_Header{ Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300, } answer.A = net.ParseIP(ip) response.Answer = append(response.Answer, answer) } // 记录解析域名统计 domain := "" if len(r.Question) > 0 { domain = r.Question[0].Name if len(domain) > 0 && domain[len(domain)-1] == '.' { domain = domain[:len(domain)-1] } s.updateResolvedDomainStats(domain) } w.WriteMsg(response) s.updateStats(func(stats *Stats) { stats.Allowed++ }) } // handleBlockedResponse 处理被屏蔽的域名响应 func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain string) { logger.Info("域名被屏蔽", "domain", domain, "client", w.RemoteAddr()) // 更新被屏蔽域名统计 s.updateBlockedDomainStats(domain) response := new(dns.Msg) response.SetReply(r) // 不再硬编码RecursionAvailable,使用默认值或上游返回的值 // 获取屏蔽方法配置 blockMethod := "NXDOMAIN" // 默认值 customBlockIP := "" // 默认值 // 从Server结构体的shieldConfig字段获取配置 if s.shieldConfig != nil { blockMethod = s.shieldConfig.BlockMethod customBlockIP = s.shieldConfig.CustomBlockIP } // 根据屏蔽方法返回不同的响应 switch blockMethod { case "refused": // 返回拒绝查询响应 response.SetRcode(r, dns.RcodeRefused) case "emptyIP": // 返回空IP响应 if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA { answer := new(dns.A) answer.Hdr = dns.RR_Header{ Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300, } answer.A = net.ParseIP("0.0.0.0") // 空IP response.Answer = append(response.Answer, answer) } case "customIP": // 返回自定义IP响应 if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA && customBlockIP != "" { answer := new(dns.A) answer.Hdr = dns.RR_Header{ Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300, } answer.A = net.ParseIP(customBlockIP) response.Answer = append(response.Answer, answer) } case "NXDOMAIN", "": fallthrough // 默认使用NXDOMAIN default: // 返回NXDOMAIN响应(域名不存在) response.SetRcode(r, dns.RcodeNameError) } w.WriteMsg(response) s.updateStats(func(stats *Stats) { stats.Blocked++ }) } // forwardDNSRequest 转发DNS请求到上游服务器 // serverResponse 用于存储服务器响应的结构体 type serverResponse struct { response *dns.Msg rtt time.Duration server string error error } // forwardDNSRequestWithCache 转发DNS请求到上游服务器并返回响应 func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg, time.Duration, string, string) { // 始终支持EDNS var udpSize uint16 = 4096 var doFlag bool = s.config.EnableDNSSEC // 检查域名是否匹配不验证DNSSEC的模式 noDNSSEC := false for _, pattern := range s.config.NoDNSSECDomains { if strings.Contains(domain, pattern) { noDNSSEC = true doFlag = false logger.Debug("域名匹配到不验证DNSSEC的模式", "domain", domain, "pattern", pattern) break } } // 检查客户端请求是否包含EDNS记录 if opt := r.IsEdns0(); opt != nil { // 保留客户端的UDP缓冲区大小 udpSize = opt.UDPSize() // 移除现有的EDNS记录,以便重新添加 for i := range r.Extra { if r.Extra[i] == opt { r.Extra = append(r.Extra[:i], r.Extra[i+1:]...) break } } } // 添加EDNS记录,设置适当的UDPSize和DO标志 r.SetEdns0(udpSize, doFlag) // DNSSEC专用服务器列表,从配置中获取 dnssecServers := s.config.DNSSECUpstreamDNS // 选择合适的上游DNS服务器列表 // 1. 首先检查是否有域名特定的DNS服务器配置 var selectedUpstreamDNS []string var domainMatched bool for matchStr, dnsServers := range s.config.DomainSpecificDNS { if strings.Contains(domain, matchStr) { selectedUpstreamDNS = dnsServers domainMatched = true logger.Debug("域名匹配到特定DNS服务器配置", "domain", domain, "matchStr", matchStr, "dnsServers", dnsServers) break } } // 2. 如果没有匹配的域名特定配置 if !domainMatched { // 创建一个新的切片来存储最终的上游服务器列表 var finalUpstreamDNS []string // 首先添加用户配置的上游DNS服务器 finalUpstreamDNS = append(finalUpstreamDNS, s.config.UpstreamDNS...) logger.Debug("使用用户配置的上游DNS服务器", "servers", finalUpstreamDNS) // 如果启用了DNSSEC且有配置DNSSEC专用服务器,并且域名不匹配NoDNSSECDomains,则将DNSSEC专用服务器添加到列表中 if s.config.EnableDNSSEC && len(s.config.DNSSECUpstreamDNS) > 0 && !noDNSSEC { // 合并DNSSEC专用服务器到上游服务器列表,避免重复,并确保包含端口号 for _, dnssecServer := range s.config.DNSSECUpstreamDNS { hasDuplicate := false // 确保DNSSEC服务器地址包含端口号 normalizedDnssecServer := normalizeDNSServerAddress(dnssecServer) for _, upstream := range finalUpstreamDNS { if upstream == normalizedDnssecServer { hasDuplicate = true break } } if !hasDuplicate { finalUpstreamDNS = append(finalUpstreamDNS, normalizedDnssecServer) } } logger.Debug("合并DNSSEC专用服务器到上游服务器列表", "servers", finalUpstreamDNS) } // 使用最终合并后的服务器列表 selectedUpstreamDNS = finalUpstreamDNS } // 1. 首先尝试所有配置的上游DNS服务器 var bestResponse *dns.Msg var bestRtt time.Duration var hasBestResponse bool var hasDNSSECResponse bool var backupResponse *dns.Msg var backupRtt time.Duration var hasBackup bool var usedDNSServer string var usedDNSSECServer string // 根据查询模式处理请求 switch s.config.QueryMode { case "parallel": // 并行请求模式 - 优化版:添加超时处理和快速响应返回 responses := make(chan serverResponse, len(selectedUpstreamDNS)) var wg sync.WaitGroup // 向所有上游服务器并行发送请求 for _, upstream := range selectedUpstreamDNS { wg.Add(1) go func(server string) { defer wg.Done() // 发送请求并获取响应,确保服务器地址包含端口号 response, rtt, err := s.resolver.Exchange(r, normalizeDNSServerAddress(server)) responses <- serverResponse{response, rtt, server, err} }(upstream) } // 等待所有请求完成或超时 go func() { wg.Wait() close(responses) }() // 处理所有响应,实现快速响应返回 for resp := range responses { if resp.error == nil && resp.response != nil { // 更新服务器统计信息 s.updateServerStats(resp.server, true, resp.rtt) // 检查是否包含DNSSEC记录 containsDNSSEC := s.hasDNSSECRecords(resp.response) // 如果启用了DNSSEC且响应包含DNSSEC记录,验证DNSSEC签名 // 但如果域名匹配不验证DNSSEC的模式,则跳过验证 if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC { // 验证DNSSEC记录 signatureValid := s.verifyDNSSEC(resp.response) // 设置AD标志(Authenticated Data) resp.response.AuthenticatedData = signatureValid if signatureValid { // 更新DNSSEC验证成功计数 s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECSuccess++ }) } else { // 更新DNSSEC验证失败计数 s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECFailed++ }) } } else if noDNSSEC { // 对于不验证DNSSEC的域名,始终设置AD标志为false resp.response.AuthenticatedData = false } // 检查当前服务器是否是DNSSEC专用服务器 for _, dnssecServer := range dnssecServers { if dnssecServer == resp.server { usedDNSSECServer = resp.server break } } // 检查当前服务器是否是用户配置的上游DNS服务器 isUserUpstream := false for _, userServer := range s.config.UpstreamDNS { if userServer == resp.server { isUserUpstream = true break } } // 处理响应,优先选择用户配置的主DNS服务器 if resp.response.Rcode == dns.RcodeSuccess { // 成功响应,优先使用 if isUserUpstream { // 用户配置的主DNS服务器响应,直接设置为最佳响应 bestResponse = resp.response bestRtt = resp.rtt hasBestResponse = true hasDNSSECResponse = containsDNSSEC usedDNSServer = resp.server logger.Debug("使用用户配置的上游服务器响应", "domain", domain, "server", resp.server, "rtt", resp.rtt) // 快速返回:用户配置的主DNS服务器响应,立即返回 continue } else if containsDNSSEC { // 非用户配置服务器,但有DNSSEC记录 if !hasBestResponse || !isUserUpstream { // 如果还没有最佳响应,或者当前最佳响应不是用户配置的服务器,则更新 bestResponse = resp.response bestRtt = resp.rtt hasBestResponse = true hasDNSSECResponse = true usedDNSServer = resp.server logger.Debug("找到带DNSSEC的最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt) // 快速返回:找到带DNSSEC的响应,立即返回 continue } } else { // 非用户配置服务器,没有DNSSEC记录 if !hasBestResponse { // 如果还没有最佳响应,设置为最佳响应 bestResponse = resp.response bestRtt = resp.rtt hasBestResponse = true usedDNSServer = resp.server logger.Debug("找到最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt) // 快速返回:第一次找到成功响应,立即返回 continue } } } else if resp.response.Rcode == dns.RcodeNameError { // NXDOMAIN响应 if !hasBestResponse || bestResponse.Rcode == dns.RcodeNameError { // 如果还没有最佳响应,或者最佳响应也是NXDOMAIN if isUserUpstream { // 用户配置的服务器,直接使用 bestResponse = resp.response bestRtt = resp.rtt hasBestResponse = true usedDNSServer = resp.server logger.Debug("使用用户配置的上游服务器NXDOMAIN响应", "domain", domain, "server", resp.server, "rtt", resp.rtt) // 快速返回:用户配置的服务器NXDOMAIN响应,立即返回 continue } else if !hasBestResponse || resp.rtt < bestRtt { // 非用户配置服务器,选择更快的响应 bestResponse = resp.response bestRtt = resp.rtt hasBestResponse = true usedDNSServer = resp.server logger.Debug("找到NXDOMAIN最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt) // 快速返回:找到NXDOMAIN响应,立即返回 continue } } } // 更新备选响应,确保总有一个可用的响应 if resp.response != nil { if !hasBackup { // 第一次保存备选响应 backupResponse = resp.response backupRtt = resp.rtt hasBackup = true } else { // 后续响应,优先保存用户配置的服务器响应作为备选 if isUserUpstream { backupResponse = resp.response backupRtt = resp.rtt } } } // 即使响应不是成功或NXDOMAIN,也保存为最佳响应(如果还没有的话) // 确保总有一个响应返回给客户端 if !hasBestResponse { bestResponse = resp.response bestRtt = resp.rtt hasBestResponse = true usedDNSServer = resp.server logger.Debug("使用非成功响应作为最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt, "rcode", resp.response.Rcode) } } else { // 更新服务器统计信息(失败) s.updateServerStats(resp.server, false, 0) } } case "loadbalance": // 负载均衡模式 - 使用加权随机选择算法 // 1. 尝试所有可用的服务器,直到找到一个能正常工作的 var triedServers []string for len(triedServers) < len(selectedUpstreamDNS) { // 从剩余的服务器中选择一个加权随机服务器 var availableServers []string for _, server := range selectedUpstreamDNS { found := false for _, tried := range triedServers { if server == tried { found = true break } } if !found { availableServers = append(availableServers, server) } } selectedServer := s.selectWeightedRandomServer(availableServers) if selectedServer == "" { break } triedServers = append(triedServers, selectedServer) logger.Debug("在负载均衡模式下选择服务器", "domain", domain, "server", selectedServer, "triedServers", triedServers) // 使用带超时的方式执行Exchange resultChan := make(chan struct { response *dns.Msg rtt time.Duration err error }, 1) go func() { response, rtt, err := s.resolver.Exchange(r, normalizeDNSServerAddress(selectedServer)) resultChan <- struct { response *dns.Msg rtt time.Duration err error }{response, rtt, err} }() var response *dns.Msg var rtt time.Duration var err error // 直接获取结果,不使用上下文超时 result := <-resultChan response, rtt, err = result.response, result.rtt, result.err if err == nil && response != nil { // 更新服务器统计信息 s.updateServerStats(selectedServer, true, rtt) // 检查是否包含DNSSEC记录 containsDNSSEC := s.hasDNSSECRecords(response) // 如果启用了DNSSEC且响应包含DNSSEC记录,验证DNSSEC签名 // 但如果域名匹配不验证DNSSEC的模式,则跳过验证 if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC { // 验证DNSSEC记录 signatureValid := s.verifyDNSSEC(response) // 设置AD标志(Authenticated Data) response.AuthenticatedData = signatureValid if signatureValid { // 更新DNSSEC验证成功计数 s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECSuccess++ }) } else { // 更新DNSSEC验证失败计数 s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECFailed++ }) } } else if noDNSSEC { // 对于不验证DNSSEC的域名,始终设置AD标志为false response.AuthenticatedData = false } // 如果响应成功或为NXDOMAIN,根据DNSSEC状态选择最佳响应 if response.Rcode == dns.RcodeSuccess || response.Rcode == dns.RcodeNameError { if response.Rcode == dns.RcodeSuccess { // 优先选择带有DNSSEC记录的响应 if containsDNSSEC { bestResponse = response bestRtt = rtt hasBestResponse = true hasDNSSECResponse = true usedDNSServer = selectedServer // 如果当前使用的服务器是DNSSEC专用服务器,同时设置usedDNSSECServer for _, dnssecServer := range dnssecServers { if dnssecServer == selectedServer { usedDNSSECServer = selectedServer break } } logger.Debug("找到带DNSSEC的最佳响应", "domain", domain, "server", selectedServer, "rtt", rtt) } else { // 没有带DNSSEC的响应时,保存成功响应 bestResponse = response bestRtt = rtt hasBestResponse = true usedDNSServer = selectedServer // 如果当前使用的服务器是DNSSEC专用服务器,同时设置usedDNSSECServer for _, dnssecServer := range dnssecServers { if dnssecServer == selectedServer { usedDNSSECServer = selectedServer break } } logger.Debug("找到最佳响应", "domain", domain, "server", selectedServer, "rtt", rtt) } } else if response.Rcode == dns.RcodeNameError { // 处理NXDOMAIN响应 bestResponse = response bestRtt = rtt hasBestResponse = true usedDNSServer = selectedServer logger.Debug("找到NXDOMAIN响应", "domain", domain, "server", selectedServer, "rtt", rtt) } // 保存为备选响应 if !hasBackup { backupResponse = response backupRtt = rtt hasBackup = true } break // 找到有效响应,退出循环 } } else { // 更新服务器统计信息(失败) s.updateServerStats(selectedServer, false, 0) logger.Debug("服务器请求失败,尝试下一个", "domain", domain, "server", selectedServer, "error", err) } } case "fastest-ip": // 最快的IP地址模式 - 使用TCP连接速度测量选择最快服务器 // 1. 选择最快的服务器 fastestServer := s.selectFastestServer(selectedUpstreamDNS) if fastestServer != "" { // 使用带超时的方式执行Exchange resultChan := make(chan struct { response *dns.Msg rtt time.Duration err error }, 1) go func() { resp, r, e := s.resolver.Exchange(r, normalizeDNSServerAddress(fastestServer)) resultChan <- struct { response *dns.Msg rtt time.Duration err error }{resp, r, e} }() var response *dns.Msg var rtt time.Duration var err error // 直接获取结果,不使用上下文超时 result := <-resultChan response, rtt, err = result.response, result.rtt, result.err if err == nil && response != nil { // 更新服务器统计信息 s.updateServerStats(fastestServer, true, rtt) // 检查是否包含DNSSEC记录 containsDNSSEC := s.hasDNSSECRecords(response) // 如果启用了DNSSEC且响应包含DNSSEC记录,验证DNSSEC签名 // 但如果域名匹配不验证DNSSEC的模式,则跳过验证 if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC { // 验证DNSSEC记录 signatureValid := s.verifyDNSSEC(response) // 设置AD标志(Authenticated Data) response.AuthenticatedData = signatureValid if signatureValid { // 更新DNSSEC验证成功计数 s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECSuccess++ }) } else { // 更新DNSSEC验证失败计数 s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECFailed++ }) } } else if noDNSSEC { // 对于不验证DNSSEC的域名,始终设置AD标志为false response.AuthenticatedData = false } // 如果响应成功或为NXDOMAIN,根据DNSSEC状态选择最佳响应 if response.Rcode == dns.RcodeSuccess || response.Rcode == dns.RcodeNameError { if response.Rcode == dns.RcodeSuccess { // 优先选择带有DNSSEC记录的响应 if containsDNSSEC { bestResponse = response bestRtt = rtt hasBestResponse = true hasDNSSECResponse = true usedDNSServer = fastestServer // 如果当前使用的服务器是DNSSEC专用服务器,同时设置usedDNSSECServer for _, dnssecServer := range dnssecServers { if dnssecServer == fastestServer { usedDNSSECServer = fastestServer break } } logger.Debug("找到带DNSSEC的最佳响应", "domain", domain, "server", fastestServer, "rtt", rtt) } else { // 没有带DNSSEC的响应时,保存成功响应 bestResponse = response bestRtt = rtt hasBestResponse = true usedDNSServer = fastestServer // 如果当前使用的服务器是DNSSEC专用服务器,同时设置usedDNSSECServer for _, dnssecServer := range dnssecServers { if dnssecServer == fastestServer { usedDNSSECServer = fastestServer break } } logger.Debug("找到最佳响应", "domain", domain, "server", fastestServer, "rtt", rtt) } } else if response.Rcode == dns.RcodeNameError { // 处理NXDOMAIN响应 bestResponse = response bestRtt = rtt hasBestResponse = true usedDNSServer = fastestServer logger.Debug("找到NXDOMAIN响应", "domain", domain, "server", fastestServer, "rtt", rtt) } // 保存为备选响应 if !hasBackup { backupResponse = response backupRtt = rtt hasBackup = true } } } else { // 更新服务器统计信息(失败) s.updateServerStats(fastestServer, false, 0) } } default: // 默认使用并行请求模式 - 添加超时处理和快速响应返回 responses := make(chan serverResponse, len(selectedUpstreamDNS)) var wg sync.WaitGroup // 向所有上游服务器并行发送请求 for _, upstream := range selectedUpstreamDNS { wg.Add(1) go func(server string) { defer wg.Done() // 发送请求并获取响应 response, rtt, err := s.resolver.Exchange(r, normalizeDNSServerAddress(server)) responses <- serverResponse{response, rtt, server, err} }(upstream) } // 等待所有请求完成 go func() { wg.Wait() close(responses) }() // 处理所有响应,实现快速响应返回 for resp := range responses { if resp.error == nil && resp.response != nil { // 检查是否包含DNSSEC记录 containsDNSSEC := s.hasDNSSECRecords(resp.response) // 如果启用了DNSSEC且响应包含DNSSEC记录,验证DNSSEC签名 // 但如果域名匹配不验证DNSSEC的模式,则跳过验证 if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC { // 验证DNSSEC记录 signatureValid := s.verifyDNSSEC(resp.response) // 设置AD标志(Authenticated Data) resp.response.AuthenticatedData = signatureValid if signatureValid { // 更新DNSSEC验证成功计数 s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECSuccess++ }) } else { // 更新DNSSEC验证失败计数 s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECFailed++ }) } } else if noDNSSEC { // 对于不验证DNSSEC的域名,始终设置AD标志为false resp.response.AuthenticatedData = false } // 如果响应成功或为NXDOMAIN,根据DNSSEC状态选择最佳响应 if resp.response.Rcode == dns.RcodeSuccess || resp.response.Rcode == dns.RcodeNameError { if resp.response.Rcode == dns.RcodeSuccess { // 优先选择带有DNSSEC记录的响应 if containsDNSSEC { bestResponse = resp.response bestRtt = resp.rtt hasBestResponse = true hasDNSSECResponse = true usedDNSServer = resp.server // 如果当前使用的服务器是DNSSEC专用服务器,同时设置usedDNSSECServer for _, dnssecServer := range dnssecServers { if dnssecServer == resp.server { usedDNSSECServer = resp.server break } } logger.Debug("找到带DNSSEC的最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt) // 快速返回:找到带DNSSEC的响应,立即返回 continue } else if !hasBestResponse { // 没有带DNSSEC的响应时,保存第一个成功响应 bestResponse = resp.response bestRtt = resp.rtt hasBestResponse = true usedDNSServer = resp.server // 如果当前使用的服务器是DNSSEC专用服务器,同时设置usedDNSSECServer for _, dnssecServer := range dnssecServers { if dnssecServer == resp.server { usedDNSSECServer = resp.server break } } logger.Debug("找到最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt) // 快速返回:第一次找到成功响应,立即返回 continue } } else if resp.response.Rcode == dns.RcodeNameError { // 处理NXDOMAIN响应 // 如果还没有最佳响应,或者最佳响应也是NXDOMAIN,优先选择更快的NXDOMAIN响应 if !hasBestResponse || bestResponse.Rcode == dns.RcodeNameError { // 如果还没有最佳响应,或者当前响应更快,更新最佳响应 if !hasBestResponse || resp.rtt < bestRtt { bestResponse = resp.response bestRtt = resp.rtt hasBestResponse = true usedDNSServer = resp.server logger.Debug("找到NXDOMAIN最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt) // 快速返回:找到NXDOMAIN响应,立即返回 continue } } } // 保存为备选响应 if !hasBackup { backupResponse = resp.response backupRtt = resp.rtt hasBackup = true } } } } } // 2. 当启用DNSSEC且没有找到带DNSSEC的响应时,向DNSSEC专用服务器发送请求 // 但如果域名匹配了domainSpecificDNS配置或NoDNSSECDomains,则不使用DNSSEC专用服务器,只使用指定的DNS服务器 if s.config.EnableDNSSEC && !hasDNSSECResponse && !domainMatched && !noDNSSEC { logger.Debug("向DNSSEC专用服务器发送请求", "domain", domain) // 增加DNSSEC查询计数 s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ }) // 无论查询模式是什么,DNSSEC验证都只使用加权随机选择一个服务器 selectedDnssecServer := s.selectWeightedRandomServer(dnssecServers) if selectedDnssecServer != "" { // 使用带超时的方式执行Exchange resultChan := make(chan struct { response *dns.Msg rtt time.Duration err error }, 1) go func() { response, rtt, err := s.resolver.Exchange(r, normalizeDNSServerAddress(selectedDnssecServer)) resultChan <- struct { response *dns.Msg rtt time.Duration err error }{response, rtt, err} }() var response *dns.Msg var rtt time.Duration var err error // 直接获取结果,不使用上下文超时 result := <-resultChan response, rtt, err = result.response, result.rtt, result.err if err == nil && response != nil { // 更新服务器统计信息 s.updateServerStats(selectedDnssecServer, true, rtt) // 检查是否包含DNSSEC记录 containsDNSSEC := s.hasDNSSECRecords(response) if response.Rcode == dns.RcodeSuccess { // 无论响应是否包含DNSSEC记录,只要使用了DNSSEC专用服务器,就设置usedDNSSECServer usedDNSSECServer = selectedDnssecServer // 验证DNSSEC记录 signatureValid := s.verifyDNSSEC(response) // 设置AD标志(Authenticated Data) response.AuthenticatedData = signatureValid if signatureValid { // 更新DNSSEC验证成功计数 s.updateStats(func(stats *Stats) { stats.DNSSECSuccess++ }) } else { // 更新DNSSEC验证失败计数 s.updateStats(func(stats *Stats) { stats.DNSSECFailed++ }) } // 优先使用DNSSEC专用服务器的响应,尤其是带有DNSSEC记录的 if containsDNSSEC { bestResponse = response bestRtt = rtt hasBestResponse = true hasDNSSECResponse = true logger.Debug("DNSSEC专用服务器返回带DNSSEC的响应,优先使用", "domain", domain, "server", selectedDnssecServer, "rtt", rtt) } // 注意:如果DNSSEC专用服务器返回的响应不包含DNSSEC记录, // 我们不会覆盖之前从upstreamDNS获取的响应, // 这符合"本地解析指的是直接使用上游服务器upstreamDNS进行解析, 而不是dnssecUpstreamDNS"的要求 // 更新备选响应 if !hasBackup { backupResponse = response backupRtt = rtt hasBackup = true } } } else { // 更新服务器统计信息(失败) s.updateServerStats(selectedDnssecServer, false, 0) } } } // 3. 返回最佳响应 if hasBestResponse { // 检查最佳响应是否包含DNSSEC记录 bestHasDNSSEC := s.hasDNSSECRecords(bestResponse) // 如果启用了DNSSEC且最佳响应不包含DNSSEC记录,尝试使用本地解析(使用upstreamDNS服务器) // 但如果域名匹配了domainSpecificDNS配置,则不执行此逻辑,只使用指定的DNS服务器 if s.config.EnableDNSSEC && !bestHasDNSSEC && !domainMatched { logger.Debug("最佳响应不包含DNSSEC记录,尝试使用本地解析(upstreamDNS)", "domain", domain) // 选择一个upstreamDNS服务器进行解析(使用加权随机算法) localServer := s.selectWeightedRandomServer(s.config.UpstreamDNS) if localServer != "" { // 使用带超时的方式执行Exchange resultChan := make(chan struct { response *dns.Msg rtt time.Duration err error }, 1) go func() { resp, r, e := s.resolver.Exchange(r, normalizeDNSServerAddress(localServer)) resultChan <- struct { response *dns.Msg rtt time.Duration err error }{resp, r, e} }() var localResponse *dns.Msg var rtt time.Duration var err error // 直接获取结果,不使用上下文超时 result := <-resultChan localResponse, rtt, err = result.response, result.rtt, result.err if err == nil && localResponse != nil { // 更新服务器统计信息 s.updateServerStats(localServer, true, rtt) // 检查是否包含DNSSEC记录 localHasDNSSEC := s.hasDNSSECRecords(localResponse) // 验证DNSSEC记录(如果存在),但不影响最终响应 if localHasDNSSEC { signatureValid := s.verifyDNSSEC(localResponse) localResponse.AuthenticatedData = signatureValid if signatureValid { s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECSuccess++ }) } else { s.updateStats(func(stats *Stats) { stats.DNSSECQueries++ stats.DNSSECFailed++ }) } } // 记录解析域名统计 s.updateResolvedDomainStats(domain) // 更新域名的DNSSEC状态 s.updateDomainDNSSECStatus(domain, localHasDNSSEC) s.updateStats(func(stats *Stats) { stats.Allowed++ }) logger.Debug("使用本地解析结果(upstreamDNS)", "domain", domain, "server", localServer, "rtt", rtt) return localResponse, rtt, localServer, "" } else { // 更新服务器统计信息(失败) s.updateServerStats(localServer, false, 0) } } } // 记录解析域名统计 s.updateResolvedDomainStats(domain) // 更新域名的DNSSEC状态 if bestHasDNSSEC { s.updateDomainDNSSECStatus(domain, true) } else { s.updateDomainDNSSECStatus(domain, false) } s.updateStats(func(stats *Stats) { stats.Allowed++ }) return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer } // 如果有备选响应,返回该响应 if hasBackup { logger.Debug("使用备选响应,没有找到更好的结果", "domain", domain) // 记录解析域名统计 s.updateResolvedDomainStats(domain) // 更新统计信息 s.updateStats(func(stats *Stats) { stats.Allowed++ }) return backupResponse, backupRtt, "", "" } // 所有上游服务器都失败,返回服务器失败错误 response := new(dns.Msg) response.SetReply(r) response.SetRcode(r, dns.RcodeServerFailure) logger.Error("DNS查询失败", "domain", domain) s.updateStats(func(stats *Stats) { stats.Errors++ }) return response, 0, "", "" } // forwardDNSRequest 转发DNS请求到上游服务器 func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) { response, _, _, _ := s.forwardDNSRequestWithCache(r, domain) w.WriteMsg(response) } // updateBlockedDomainStats 更新被屏蔽域名统计 func (s *Server) updateBlockedDomainStats(domain string) { // 更新被屏蔽域名计数 s.blockedDomainsMutex.Lock() defer s.blockedDomainsMutex.Unlock() if entry, exists := s.blockedDomains[domain]; exists { entry.Count++ entry.LastSeen = time.Now() } else { s.blockedDomains[domain] = &BlockedDomain{ Domain: domain, Count: 1, LastSeen: time.Now(), } } // 更新统计数据 now := time.Now() // 更新小时统计 hourKey := now.Format("2006-01-02-15") s.hourlyStatsMutex.Lock() s.hourlyStats[hourKey]++ s.hourlyStatsMutex.Unlock() // 更新每日统计 dayKey := now.Format("2006-01-02") s.dailyStatsMutex.Lock() s.dailyStats[dayKey]++ s.dailyStatsMutex.Unlock() // 更新每月统计 monthKey := now.Format("2006-01") s.monthlyStatsMutex.Lock() s.monthlyStats[monthKey]++ s.monthlyStatsMutex.Unlock() } // updateClientStats 更新客户端统计 func (s *Server) updateClientStats(ip string) { s.clientStatsMutex.Lock() defer s.clientStatsMutex.Unlock() if entry, exists := s.clientStats[ip]; exists { entry.Count++ entry.LastSeen = time.Now() } else { s.clientStats[ip] = &ClientStats{ IP: ip, Count: 1, LastSeen: time.Now(), } } } // hasDNSSECRecords 检查响应是否包含DNSSEC记录 func (s *Server) hasDNSSECRecords(response *dns.Msg) bool { // 检查响应中是否包含DNSSEC相关记录(DNSKEY、RRSIG、DS、NSEC、NSEC3等) for _, rr := range response.Answer { if _, ok := rr.(*dns.DNSKEY); ok { return true } if _, ok := rr.(*dns.RRSIG); ok { return true } if _, ok := rr.(*dns.DS); ok { return true } if _, ok := rr.(*dns.NSEC); ok { return true } if _, ok := rr.(*dns.NSEC3); ok { return true } } for _, rr := range response.Ns { if _, ok := rr.(*dns.DNSKEY); ok { return true } if _, ok := rr.(*dns.RRSIG); ok { return true } if _, ok := rr.(*dns.DS); ok { return true } if _, ok := rr.(*dns.NSEC); ok { return true } if _, ok := rr.(*dns.NSEC3); ok { return true } } for _, rr := range response.Extra { if _, ok := rr.(*dns.DNSKEY); ok { return true } if _, ok := rr.(*dns.RRSIG); ok { return true } if _, ok := rr.(*dns.DS); ok { return true } if _, ok := rr.(*dns.NSEC); ok { return true } if _, ok := rr.(*dns.NSEC3); ok { return true } } return false } // verifyDNSSEC 验证DNSSEC签名 func (s *Server) verifyDNSSEC(response *dns.Msg) bool { // 提取DNSKEY和RRSIG记录 dnskeys := make(map[uint16]*dns.DNSKEY) // KeyTag -> DNSKEY rrsigs := make([]*dns.RRSIG, 0) // 从响应中提取所有DNSKEY和RRSIG记录 for _, rr := range response.Answer { if dnskey, ok := rr.(*dns.DNSKEY); ok { tag := dnskey.KeyTag() dnskeys[tag] = dnskey } else if rrsig, ok := rr.(*dns.RRSIG); ok { rrsigs = append(rrsigs, rrsig) } } for _, rr := range response.Ns { if dnskey, ok := rr.(*dns.DNSKEY); ok { tag := dnskey.KeyTag() dnskeys[tag] = dnskey } else if rrsig, ok := rr.(*dns.RRSIG); ok { rrsigs = append(rrsigs, rrsig) } } for _, rr := range response.Extra { if dnskey, ok := rr.(*dns.DNSKEY); ok { tag := dnskey.KeyTag() dnskeys[tag] = dnskey } else if rrsig, ok := rr.(*dns.RRSIG); ok { rrsigs = append(rrsigs, rrsig) } } // 如果没有RRSIG记录,验证失败 if len(rrsigs) == 0 { return false } // 验证所有RRSIG记录 signatureValid := true // 用于记录已经警告过的DNSKEY tag,避免重复警告 warnedKeyTags := make(map[uint16]bool) for _, rrsig := range rrsigs { // 查找对应的DNSKEY dnskey, exists := dnskeys[rrsig.KeyTag] if !exists { // 仅当该key_tag尚未警告过时才记录警告 if !warnedKeyTags[rrsig.KeyTag] { logger.Warn("DNSSEC验证失败:找不到对应的DNSKEY", "key_tag", rrsig.KeyTag) warnedKeyTags[rrsig.KeyTag] = true } signatureValid = false continue } // 收集需要验证的记录集 rrset := make([]dns.RR, 0) for _, rr := range response.Answer { if rr.Header().Name == rrsig.Header().Name && rr.Header().Rrtype == rrsig.TypeCovered { rrset = append(rrset, rr) } } for _, rr := range response.Ns { if rr.Header().Name == rrsig.Header().Name && rr.Header().Rrtype == rrsig.TypeCovered { rrset = append(rrset, rr) } } // 验证签名 if len(rrset) > 0 { err := rrsig.Verify(dnskey, rrset) if err != nil { logger.Warn("DNSSEC签名验证失败", "error", err, "key_tag", rrsig.KeyTag) signatureValid = false } else { logger.Debug("DNSSEC签名验证成功", "key_tag", rrsig.KeyTag) } } } return signatureValid } // updateDomainDNSSECStatus 更新域名的DNSSEC状态 func (s *Server) updateDomainDNSSECStatus(domain string, dnssec bool) { // 确保域名是小写 domain = strings.ToLower(domain) // 更新域名的DNSSEC状态 s.resolvedDomainsMutex.Lock() defer s.resolvedDomainsMutex.Unlock() // 更新resolvedDomains中的DNSSEC状态 if entry, exists := s.resolvedDomains[domain]; exists { entry.DNSSEC = dnssec } else { s.resolvedDomains[domain] = &BlockedDomain{ Domain: domain, Count: 1, LastSeen: time.Now(), DNSSEC: dnssec, } } // 更新domainDNSSECStatus映射 s.domainDNSSECStatus[domain] = dnssec } // updateResolvedDomainStats 更新解析域名统计 func (s *Server) updateResolvedDomainStats(domain string) { s.resolvedDomainsMutex.Lock() defer s.resolvedDomainsMutex.Unlock() if entry, exists := s.resolvedDomains[domain]; exists { entry.Count++ entry.LastSeen = time.Now() } else { s.resolvedDomains[domain] = &BlockedDomain{ Domain: domain, Count: 1, LastSeen: time.Now(), DNSSEC: false, } } } // getServerStats 获取服务器统计信息,如果不存在则创建 func (s *Server) getServerStats(server string) *ServerStats { s.serverStatsMutex.RLock() stats, exists := s.serverStats[server] s.serverStatsMutex.RUnlock() if !exists { // 创建新的服务器统计信息 stats = &ServerStats{ SuccessCount: 0, FailureCount: 0, LastResponse: time.Now(), ResponseTime: 0, ConnectionSpeed: 0, } // 加锁更新服务器统计信息 s.serverStatsMutex.Lock() s.serverStats[server] = stats s.serverStatsMutex.Unlock() } return stats } // updateServerStats 更新服务器统计信息 func (s *Server) updateServerStats(server string, success bool, rtt time.Duration) { stats := s.getServerStats(server) s.serverStatsMutex.Lock() defer s.serverStatsMutex.Unlock() // 更新统计信息 stats.LastResponse = time.Now() if success { stats.SuccessCount++ // 更新平均响应时间(简单移动平均) // 将所有值转换为纳秒进行计算,然后再转换回Duration if stats.SuccessCount == 1 { // 第一次成功,直接使用当前响应时间 stats.ResponseTime = rtt } else { // 使用纳秒进行计算以避免类型不匹配 prevTotal := stats.ResponseTime.Nanoseconds() * (stats.SuccessCount - 1) newTotal := prevTotal + rtt.Nanoseconds() stats.ResponseTime = time.Duration(newTotal / stats.SuccessCount) } } else { stats.FailureCount++ } } // selectWeightedRandomServer 加权随机选择服务器 func (s *Server) selectWeightedRandomServer(servers []string) string { if len(servers) == 0 { return "" } if len(servers) == 1 { return servers[0] } // 计算每个服务器的权重 type serverWeight struct { server string weight int64 } var totalWeight int64 weights := make([]serverWeight, 0, len(servers)) // 获取所有服务器的平均响应时间,用于归一化 var totalResponseTime time.Duration validServers := 0 for _, server := range servers { stats := s.getServerStats(server) if stats.ResponseTime > 0 { totalResponseTime += stats.ResponseTime validServers++ } } // 计算平均响应时间基准值 var avgResponseTime time.Duration if validServers > 0 { avgResponseTime = totalResponseTime / time.Duration(validServers) } else { avgResponseTime = 1 * time.Second // 默认基准值 } for _, server := range servers { stats := s.getServerStats(server) // 计算基础权重:成功次数 - 失败次数 * 2(失败权重更高) // 确保权重至少为1 baseWeight := stats.SuccessCount - stats.FailureCount*2 if baseWeight < 1 { baseWeight = 1 } // 计算响应时间调整因子:响应时间越短,因子越高 // 如果没有响应时间数据,使用默认值1 var responseFactor float64 = 1.0 if stats.ResponseTime > 0 { // 使用平均响应时间作为基准,计算调整因子 // 响应时间越短,因子越高,最高为2.0,最低为0.5 responseFactor = float64(avgResponseTime) / float64(stats.ResponseTime) // 限制调整因子的范围,避免权重波动过大 if responseFactor > 2.0 { responseFactor = 2.0 } else if responseFactor < 0.5 { responseFactor = 0.5 } } // 综合计算最终权重,四舍五入到整数 finalWeight := int64(float64(baseWeight) * responseFactor) // 确保最终权重至少为1 if finalWeight < 1 { finalWeight = 1 } weights = append(weights, serverWeight{server, finalWeight}) totalWeight += finalWeight } // 随机选择一个权重 random := time.Now().UnixNano() % totalWeight if random < 0 { random += totalWeight } // 选择对应的服务器 var currentWeight int64 for _, sw := range weights { currentWeight += sw.weight if random < currentWeight { return sw.server } } // 兜底返回第一个服务器 return servers[0] } // measureServerSpeed 测量服务器TCP连接速度 func (s *Server) measureServerSpeed(server string) time.Duration { // 提取服务器地址和端口 addr := server if !strings.Contains(server, ":") { addr = server + ":53" } // 测量TCP连接时间 startTime := time.Now() conn, err := net.DialTimeout("tcp", addr, 2*time.Second) if err != nil { // 连接失败,返回最大持续时间 return 2 * time.Second } defer conn.Close() // 计算连接建立时间 connTime := time.Since(startTime) // 更新服务器连接速度 stats := s.getServerStats(server) s.serverStatsMutex.Lock() // 使用指数移动平均更新连接速度 stats.ConnectionSpeed = (stats.ConnectionSpeed*3 + connTime) / 4 s.serverStatsMutex.Unlock() return connTime } // selectFastestServer 选择连接速度最快的服务器 func (s *Server) selectFastestServer(servers []string) string { if len(servers) == 0 { return "" } if len(servers) == 1 { return servers[0] } // 并行测量所有服务器的速度 type speedResult struct { server string speed time.Duration } results := make(chan speedResult, len(servers)) var wg sync.WaitGroup for _, server := range servers { wg.Add(1) go func(srv string) { defer wg.Done() speed := s.measureServerSpeed(srv) results <- speedResult{srv, speed} }(server) } // 等待所有测量完成 go func() { wg.Wait() close(results) }() // 找出最快的服务器 var fastestServer string var fastestSpeed time.Duration = 2 * time.Second for result := range results { if result.speed < fastestSpeed { fastestSpeed = result.speed fastestServer = result.server } } // 如果没有找到最快服务器(理论上不会发生),返回第一个服务器 if fastestServer == "" { fastestServer = servers[0] } return fastestServer } // updateStats 更新统计信息 func (s *Server) updateStats(update func(*Stats)) { s.statsMutex.Lock() defer s.statsMutex.Unlock() update(s.stats) } // addQueryLog 添加查询日志 func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string, fromCache, dnssec, edns bool, dnsServer, dnssecServer string, answers []DNSAnswer, responseCode int) { // 获取IP地理位置 location := s.getIpGeolocation(clientIP) // 创建日志记录 log := QueryLog{ Timestamp: time.Now(), ClientIP: clientIP, Location: location, Domain: domain, QueryType: queryType, ResponseTime: responseTime, Result: result, BlockRule: blockRule, BlockType: blockType, FromCache: fromCache, DNSSEC: dnssec, EDNS: edns, DNSServer: dnsServer, DNSSECServer: dnssecServer, Answers: answers, ResponseCode: responseCode, } // 添加到日志列表 s.queryLogsMutex.Lock() defer s.queryLogsMutex.Unlock() // 插入到列表开头 s.queryLogs = append([]QueryLog{log}, s.queryLogs...) // 限制日志数量 if len(s.queryLogs) > s.maxQueryLogs { s.queryLogs = s.queryLogs[:s.maxQueryLogs] } } // GetStartTime 获取服务器启动时间 func (s *Server) GetStartTime() time.Time { return s.startTime } // GetStats 获取DNS服务器统计信息 func (s *Server) GetStats() *Stats { s.statsMutex.Lock() defer s.statsMutex.Unlock() // 复制查询类型统计 queryTypesCopy := make(map[string]int64) for k, v := range s.stats.QueryTypes { queryTypesCopy[k] = v } // 复制来源IP统计 sourceIPsCopy := make(map[string]bool) for ip := range s.stats.SourceIPs { sourceIPsCopy[ip] = true } // 返回统计信息的副本 return &Stats{ Queries: s.stats.Queries, Blocked: s.stats.Blocked, Allowed: s.stats.Allowed, Errors: s.stats.Errors, LastQuery: s.stats.LastQuery, AvgResponseTime: s.stats.AvgResponseTime, TotalResponseTime: s.stats.TotalResponseTime, QueryTypes: queryTypesCopy, SourceIPs: sourceIPsCopy, CpuUsage: s.stats.CpuUsage, DNSSECQueries: s.stats.DNSSECQueries, DNSSECSuccess: s.stats.DNSSECSuccess, DNSSECFailed: s.stats.DNSSECFailed, DNSSECEnabled: s.stats.DNSSECEnabled, } } // GetQueryLogs 获取查询日志 func (s *Server) GetQueryLogs(limit, offset int, sortField, sortDirection, resultFilter, searchTerm string) []QueryLog { s.queryLogsMutex.RLock() defer s.queryLogsMutex.RUnlock() // 确保偏移量和限制值合理 if offset < 0 { offset = 0 } if limit <= 0 { limit = 100 // 默认返回100条日志 } // 创建日志副本用于过滤和排序 var logsCopy []QueryLog // 先过滤日志 for _, log := range s.queryLogs { // 应用结果过滤 if resultFilter != "" && log.Result != resultFilter { continue } // 应用搜索过滤 if searchTerm != "" { // 搜索域名或客户端IP if !strings.Contains(log.Domain, searchTerm) && !strings.Contains(log.ClientIP, searchTerm) { continue } } logsCopy = append(logsCopy, log) } // 排序日志 if sortField != "" { sort.Slice(logsCopy, func(i, j int) bool { var a, b interface{} switch sortField { case "time": a = logsCopy[i].Timestamp b = logsCopy[j].Timestamp case "clientIp": a = logsCopy[i].ClientIP b = logsCopy[j].ClientIP case "domain": a = logsCopy[i].Domain b = logsCopy[j].Domain case "responseTime": a = logsCopy[i].ResponseTime b = logsCopy[j].ResponseTime case "blockRule": a = logsCopy[i].BlockRule b = logsCopy[j].BlockRule default: // 默认按时间排序 a = logsCopy[i].Timestamp b = logsCopy[j].Timestamp } // 根据排序方向比较 if sortDirection == "asc" { return compareValues(a, b) < 0 } return compareValues(a, b) > 0 }) } // 计算返回范围 start := offset end := offset + limit if end > len(logsCopy) { end = len(logsCopy) } if start >= len(logsCopy) { return []QueryLog{} // 没有数据,返回空切片 } return logsCopy[start:end] } // compareValues 比较两个值 func compareValues(a, b interface{}) int { switch v1 := a.(type) { case time.Time: v2 := b.(time.Time) if v1.Before(v2) { return -1 } if v1.After(v2) { return 1 } return 0 case string: v2 := b.(string) if v1 < v2 { return -1 } if v1 > v2 { return 1 } return 0 case int64: v2 := b.(int64) if v1 < v2 { return -1 } if v1 > v2 { return 1 } return 0 default: return 0 } } // GetQueryLogsCount 获取查询日志总数 func (s *Server) GetQueryLogsCount() int { s.queryLogsMutex.RLock() defer s.queryLogsMutex.RUnlock() return len(s.queryLogs) } // GetQueryStats 获取查询统计信息 func (s *Server) GetQueryStats() map[string]interface{} { s.statsMutex.Lock() defer s.statsMutex.Unlock() // 计算统计数据 return map[string]interface{}{ "totalQueries": s.stats.Queries, "blockedQueries": s.stats.Blocked, "allowedQueries": s.stats.Allowed, "errorQueries": s.stats.Errors, "avgResponseTime": s.stats.AvgResponseTime, "activeIPs": len(s.stats.SourceIPs), } } // GetTopBlockedDomains 获取TOP屏蔽域名列表 func (s *Server) GetTopBlockedDomains(limit int) []BlockedDomain { s.blockedDomainsMutex.RLock() defer s.blockedDomainsMutex.RUnlock() // 转换为切片 domains := make([]BlockedDomain, 0, len(s.blockedDomains)) for _, entry := range s.blockedDomains { domains = append(domains, *entry) } // 按计数排序 sort.Slice(domains, func(i, j int) bool { return domains[i].Count > domains[j].Count }) // 返回限制数量 if len(domains) > limit { return domains[:limit] } return domains } // GetTopResolvedDomains 获取TOP解析域名 func (s *Server) GetTopResolvedDomains(limit int) []BlockedDomain { s.resolvedDomainsMutex.RLock() defer s.resolvedDomainsMutex.RUnlock() // 转换为切片 domains := make([]BlockedDomain, 0, len(s.resolvedDomains)) for _, entry := range s.resolvedDomains { domains = append(domains, *entry) } // 按数量排序 sort.Slice(domains, func(i, j int) bool { return domains[i].Count > domains[j].Count }) // 返回限制数量 if len(domains) > limit { return domains[:limit] } return domains } // GetRecentBlockedDomains 获取最近屏蔽的域名列表 func (s *Server) GetRecentBlockedDomains(limit int) []BlockedDomain { s.blockedDomainsMutex.RLock() defer s.blockedDomainsMutex.RUnlock() // 转换为切片 domains := make([]BlockedDomain, 0, len(s.blockedDomains)) for _, entry := range s.blockedDomains { domains = append(domains, *entry) } // 按时间排序 sort.Slice(domains, func(i, j int) bool { return domains[i].LastSeen.After(domains[j].LastSeen) }) // 返回限制数量 if len(domains) > limit { return domains[:limit] } return domains } // GetTopClients 获取TOP客户端列表 func (s *Server) GetTopClients(limit int) []ClientStats { s.clientStatsMutex.RLock() defer s.clientStatsMutex.RUnlock() // 转换为切片 clients := make([]ClientStats, 0, len(s.clientStats)) for _, entry := range s.clientStats { clients = append(clients, *entry) } // 按请求次数排序 sort.Slice(clients, func(i, j int) bool { return clients[i].Count > clients[j].Count }) // 返回限制数量 if len(clients) > limit { return clients[:limit] } return clients } // GetHourlyStats 获取每小时统计数据 func (s *Server) GetHourlyStats() map[string]int64 { s.hourlyStatsMutex.RLock() defer s.hourlyStatsMutex.RUnlock() // 返回副本 result := make(map[string]int64) for k, v := range s.hourlyStats { result[k] = v } return result } // GetDailyStats 获取每日统计数据 func (s *Server) GetDailyStats() map[string]int64 { s.dailyStatsMutex.RLock() defer s.dailyStatsMutex.RUnlock() // 返回副本 result := make(map[string]int64) for k, v := range s.dailyStats { result[k] = v } return result } // GetMonthlyStats 获取每月统计数据 func (s *Server) GetMonthlyStats() map[string]int64 { s.monthlyStatsMutex.RLock() defer s.monthlyStatsMutex.RUnlock() // 返回副本 result := make(map[string]int64) for k, v := range s.monthlyStats { result[k] = v } return result } // isPrivateIP 检测IP地址是否为内网IP func isPrivateIP(ip string) bool { // 解析IP地址 parsedIP := net.ParseIP(ip) if parsedIP == nil { return false } // 检查IPv4内网地址 if ipv4 := parsedIP.To4(); ipv4 != nil { // 10.0.0.0/8 if ipv4[0] == 10 { return true } // 172.16.0.0/12 if ipv4[0] == 172 && (ipv4[1] >= 16 && ipv4[1] <= 31) { return true } // 192.168.0.0/16 if ipv4[0] == 192 && ipv4[1] == 168 { return true } // 127.0.0.0/8 (localhost) if ipv4[0] == 127 { return true } // 169.254.0.0/16 (链路本地地址) if ipv4[0] == 169 && ipv4[1] == 254 { return true } return false } // 检查IPv6内网地址 // ::1/128 (localhost) if parsedIP.IsLoopback() { return true } // fc00::/7 (唯一本地地址) if parsedIP[0]&0xfc == 0xfc { return true } // fe80::/10 (链路本地地址) if parsedIP[0]&0xfe == 0xfe && parsedIP[1]&0xc0 == 0x80 { return true } return false } // getIpGeolocation 获取IP地址的地理位置信息 func (s *Server) getIpGeolocation(ip string) string { // 检查IP是否为本地或内网地址 if isPrivateIP(ip) { return "内网 内网" } // 先检查缓存 s.ipGeolocationCacheMutex.RLock() geo, exists := s.ipGeolocationCache[ip] s.ipGeolocationCacheMutex.RUnlock() // 如果缓存存在且未过期,直接返回 if exists && time.Now().Before(geo.Expiry) { return fmt.Sprintf("%s %s", geo.Country, geo.City) } // 缓存不存在或已过期,从API获取 geoInfo, err := s.fetchIpGeolocationFromAPI(ip) if err != nil { logger.Error("获取IP地理位置失败", "ip", ip, "error", err) return "未知 未知" } // 保存到缓存 s.ipGeolocationCacheMutex.Lock() s.ipGeolocationCache[ip] = &IPGeolocation{ Country: geoInfo["country"].(string), City: geoInfo["city"].(string), Expiry: time.Now().Add(s.ipGeolocationCacheTTL), } s.ipGeolocationCacheMutex.Unlock() // 返回格式化的地理位置 return fmt.Sprintf("%s %s", geoInfo["country"].(string), geoInfo["city"].(string)) } // fetchIpGeolocationFromAPI 从第三方API获取IP地理位置信息 func (s *Server) fetchIpGeolocationFromAPI(ip string) (map[string]interface{}, error) { // 使用ip-api.com获取IP地理位置信息 url := fmt.Sprintf("http://ip-api.com/json/%s?fields=country,city", ip) resp, err := http.Get(url) if err != nil { return nil, err } defer resp.Body.Close() // 读取响应内容 body, err := ioutil.ReadAll(resp.Body) if err != nil { return nil, err } // 解析JSON响应 var result map[string]interface{} err = json.Unmarshal(body, &result) if err != nil { return nil, err } // 检查API返回状态 status, ok := result["status"].(string) if !ok || status != "success" { return nil, fmt.Errorf("API返回错误状态: %v", result) } // 确保国家和城市字段存在 if _, ok := result["country"]; !ok { result["country"] = "未知" } if _, ok := result["city"]; !ok { result["city"] = "未知" } return result, nil } // loadStatsData 从文件加载统计数据 func (s *Server) loadStatsData() { // 检查文件是否存在 data, err := ioutil.ReadFile("data/stats.json") if err != nil { if !os.IsNotExist(err) { logger.Error("读取统计数据文件失败", "error", err) } return } var statsData StatsData err = json.Unmarshal(data, &statsData) if err != nil { logger.Error("解析统计数据失败", "error", err) return } // 恢复统计数据 s.statsMutex.Lock() if statsData.Stats != nil { s.stats = statsData.Stats // 确保使用当前配置中的EnableDNSSEC值,覆盖从文件加载的值 s.stats.DNSSECEnabled = s.config.EnableDNSSEC } s.statsMutex.Unlock() s.blockedDomainsMutex.Lock() if statsData.BlockedDomains != nil { s.blockedDomains = statsData.BlockedDomains } s.blockedDomainsMutex.Unlock() s.resolvedDomainsMutex.Lock() if statsData.ResolvedDomains != nil { s.resolvedDomains = statsData.ResolvedDomains } s.resolvedDomainsMutex.Unlock() s.hourlyStatsMutex.Lock() if statsData.HourlyStats != nil { s.hourlyStats = statsData.HourlyStats } s.hourlyStatsMutex.Unlock() s.dailyStatsMutex.Lock() if statsData.DailyStats != nil { s.dailyStats = statsData.DailyStats } s.dailyStatsMutex.Unlock() s.monthlyStatsMutex.Lock() if statsData.MonthlyStats != nil { s.monthlyStats = statsData.MonthlyStats } s.monthlyStatsMutex.Unlock() // 加载客户端统计数据 s.clientStatsMutex.Lock() if statsData.ClientStats != nil { s.clientStats = statsData.ClientStats } s.clientStatsMutex.Unlock() logger.Info("统计数据加载成功") // 加载查询日志 s.loadQueryLogs() } // loadQueryLogs 从文件加载查询日志 func (s *Server) loadQueryLogs() { // 获取绝对路径 statsFilePath, err := filepath.Abs("data/stats.json") if err != nil { logger.Error("获取统计文件绝对路径失败", "path", "data/stats.json", "error", err) return } // 构建查询日志文件路径 queryLogPath := filepath.Join(filepath.Dir(statsFilePath), "querylog.json") // 检查文件是否存在 if _, err := os.Stat(queryLogPath); os.IsNotExist(err) { logger.Info("查询日志文件不存在,将使用空列表", "file", queryLogPath) return } // 读取文件内容 data, err := ioutil.ReadFile(queryLogPath) if err != nil { logger.Error("读取查询日志文件失败", "error", err) return } // 解析数据 var logs []QueryLog err = json.Unmarshal(data, &logs) if err != nil { logger.Error("解析查询日志失败", "error", err) return } // 更新查询日志 s.queryLogsMutex.Lock() s.queryLogs = logs // 确保日志数量不超过限制 if len(s.queryLogs) > s.maxQueryLogs { s.queryLogs = s.queryLogs[:s.maxQueryLogs] } s.queryLogsMutex.Unlock() logger.Info("查询日志加载成功", "count", len(logs)) } // saveStatsData 保存统计数据到文件 func (s *Server) saveStatsData() { // 获取绝对路径以避免工作目录问题 statsFilePath, err := filepath.Abs("data/stats.json") if err != nil { logger.Error("获取统计文件绝对路径失败", "path", "data/stats.json", "error", err) return } // 创建数据目录 statsDir := filepath.Dir(statsFilePath) err = os.MkdirAll(statsDir, 0755) if err != nil { logger.Error("创建统计数据目录失败", "dir", statsDir, "error", err) return } // 收集所有统计数据 statsData := &StatsData{ Stats: s.GetStats(), LastSaved: time.Now(), } // 复制域名数据 s.blockedDomainsMutex.RLock() statsData.BlockedDomains = make(map[string]*BlockedDomain) for k, v := range s.blockedDomains { statsData.BlockedDomains[k] = v } s.blockedDomainsMutex.RUnlock() s.resolvedDomainsMutex.RLock() statsData.ResolvedDomains = make(map[string]*BlockedDomain) for k, v := range s.resolvedDomains { statsData.ResolvedDomains[k] = v } s.resolvedDomainsMutex.RUnlock() s.hourlyStatsMutex.RLock() statsData.HourlyStats = make(map[string]int64) for k, v := range s.hourlyStats { statsData.HourlyStats[k] = v } s.hourlyStatsMutex.RUnlock() s.dailyStatsMutex.RLock() statsData.DailyStats = make(map[string]int64) for k, v := range s.dailyStats { statsData.DailyStats[k] = v } s.dailyStatsMutex.RUnlock() s.monthlyStatsMutex.RLock() statsData.MonthlyStats = make(map[string]int64) for k, v := range s.monthlyStats { statsData.MonthlyStats[k] = v } s.monthlyStatsMutex.RUnlock() // 复制客户端统计数据 s.clientStatsMutex.RLock() statsData.ClientStats = make(map[string]*ClientStats) for k, v := range s.clientStats { statsData.ClientStats[k] = v } s.clientStatsMutex.RUnlock() // 序列化数据 jsonData, err := json.MarshalIndent(statsData, "", " ") if err != nil { logger.Error("序列化统计数据失败", "error", err) return } // 写入文件 err = os.WriteFile(statsFilePath, jsonData, 0644) if err != nil { logger.Error("保存统计数据到文件失败", "file", statsFilePath, "error", err) return } logger.Info("统计数据保存成功", "file", statsFilePath) // 保存查询日志到文件 s.saveQueryLogs(statsDir) } // saveQueryLogs 保存查询日志到文件 func (s *Server) saveQueryLogs(dataDir string) { // 构建查询日志文件路径 queryLogPath := filepath.Join(dataDir, "querylog.json") // 获取查询日志数据 s.queryLogsMutex.RLock() logsCopy := make([]QueryLog, len(s.queryLogs)) copy(logsCopy, s.queryLogs) s.queryLogsMutex.RUnlock() // 序列化数据 jsonData, err := json.MarshalIndent(logsCopy, "", " ") if err != nil { logger.Error("序列化查询日志失败", "error", err) return } // 写入文件 err = os.WriteFile(queryLogPath, jsonData, 0644) if err != nil { logger.Error("保存查询日志到文件失败", "file", queryLogPath, "error", err) return } logger.Info("查询日志保存成功", "file", queryLogPath) } // startCpuUsageMonitor 启动CPU使用率监控 func (s *Server) startCpuUsageMonitor() { ticker := time.NewTicker(time.Second * 5) // 每5秒更新一次CPU使用率 defer ticker.Stop() // 初始化 var memStats runtime.MemStats runtime.ReadMemStats(&memStats) // 存储上一次的CPU时间统计 var prevIdle, prevTotal uint64 for { select { case <-ticker.C: // 获取真实的系统级CPU使用率 cpuUsage, err := getSystemCpuUsage(&prevIdle, &prevTotal) if err != nil { // 如果获取失败,使用默认值 cpuUsage = 0.0 logger.Error("获取系统CPU使用率失败", "error", err) } s.updateStats(func(stats *Stats) { stats.CpuUsage = cpuUsage }) case <-s.ctx.Done(): return } } } // getSystemCpuUsage 获取系统CPU使用率 func getSystemCpuUsage(prevIdle, prevTotal *uint64) (float64, error) { // 读取/proc/stat文件获取CPU统计信息 file, err := os.Open("/proc/stat") if err != nil { return 0, err } defer file.Close() var cpuUser, cpuNice, cpuSystem, cpuIdle, cpuIowait, cpuIrq, cpuSoftirq, cpuSteal uint64 _, err = fmt.Fscanf(file, "cpu %d %d %d %d %d %d %d %d", &cpuUser, &cpuNice, &cpuSystem, &cpuIdle, &cpuIowait, &cpuIrq, &cpuSoftirq, &cpuSteal) if err != nil { return 0, err } // 计算总的CPU时间 total := cpuUser + cpuNice + cpuSystem + cpuIdle + cpuIowait + cpuIrq + cpuSoftirq + cpuSteal idle := cpuIdle + cpuIowait // 第一次调用时,只初始化值,不计算使用率 if *prevTotal == 0 || *prevIdle == 0 { *prevIdle = idle *prevTotal = total return 0, nil } // 计算CPU使用率 idleDelta := idle - *prevIdle totalDelta := total - *prevTotal utilization := float64(totalDelta-idleDelta) / float64(totalDelta) * 100 // 更新上一次的值 *prevIdle = idle *prevTotal = total return utilization, nil } // startAutoSave 启动自动保存功能 func (s *Server) startAutoSave() { if s.config.SaveInterval <= 0 { return } // 初始化定时器 s.saveTicker = time.NewTicker(time.Duration(s.config.SaveInterval) * time.Second) defer s.saveTicker.Stop() logger.Info("启动统计数据自动保存功能", "interval", s.config.SaveInterval, "file", "data/stats.json") // 定期保存数据 for { select { case <-s.saveTicker.C: s.saveStatsData() case <-s.saveDone: return } } }