Files
dns-server/dns/server.go.orig
T
2026-04-04 12:25:49 +08:00

3294 lines
96 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package dns
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"math"
"math/rand"
"net"
"os"
"path/filepath"
"runtime"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"dns-server/config"
"dns-server/gfw"
"dns-server/log"
"dns-server/logger"
"dns-server/shield"
"dns-server/threat"
"github.com/miekg/dns"
)
// 确保DNS服务器地址包含端口号,默认添加53端口
func normalizeDNSServerAddress(address string) string {
// 检查地址是否已经包含端口号
if _, _, err := net.SplitHostPort(address); err != nil {
// 如果没有端口号,添加默认的53端口
return net.JoinHostPort(address, "53")
}
// 已经有端口号,直接返回
return address
}
// BlockedDomain 屏蔽域名统计
type BlockedDomain struct {
Domain string
Count int64
LastSeen int64
DNSSEC bool // 是否使用了DNSSEC
}
// ClientStats 客户端统计
type ClientStats struct {
IP string
Count int64
LastSeen int64
}
// DNSAnswer DNS解析记录
type DNSAnswer struct {
Type string `json:"type"` // 记录类型
Value string `json:"value"` // 记录值
TTL uint32 `json:"ttl"` // 生存时间
}
// QueryLog 查询日志记录
type QueryLog struct {
Timestamp time.Time `json:"timestamp"` // 查询时间
ClientIP string `json:"clientIP"` // 客户端IP
Location string `json:"location"` // IP地理位置(国家 城市)
Domain string `json:"domain"` // 查询域名
QueryType string `json:"queryType"` // 查询类型
ResponseTime int64 `json:"responseTime"` // 响应时间(ms)
Result string `json:"result"` // 查询结果(allowed, blocked, error
BlockRule string `json:"blockRule"` // 屏蔽规则(如果被屏蔽)
BlockType string `json:"blockType"` // 屏蔽类型(如果被屏蔽)
FromCache bool `json:"fromCache"` // 是否来自缓存
DNSSEC bool `json:"dnssec"` // 是否使用了DNSSEC
EDNS bool `json:"edns"` // 是否使用了EDNS
DNSServer string `json:"dnsServer"` // 使用的DNS服务器
DNSSECServer string `json:"dnssecServer"` // 使用的DNSSEC专用服务器
Answers []DNSAnswer `json:"answers"` // 解析记录
ResponseCode int `json:"responseCode"` // DNS响应代码
}
// StatsData 用于持久化的统计数据结构
type StatsData struct {
Stats *Stats `json:"stats"`
BlockedDomains map[string]*BlockedDomain `json:"blockedDomains"`
ResolvedDomains map[string]*BlockedDomain `json:"resolvedDomains"`
ClientStats map[string]*ClientStats `json:"clientStats"`
HourlyStats map[string]int64 `json:"hourlyStats"`
DailyStats map[string]int64 `json:"dailyStats"`
MonthlyStats map[string]int64 `json:"monthlyStats"`
LastSaved time.Time `json:"lastSaved"`
}
// ServerStats 服务器统计信息
type ServerStats struct {
SuccessCount int64 // 成功查询次数
FailureCount int64 // 失败查询次数
LastResponse time.Time // 最后响应时间
ResponseTime time.Duration // 平均响应时间
ConnectionSpeed time.Duration // TCP连接速度
}
// Server DNS服务器
type Server struct {
globalConfig *config.Config
config *config.DNSConfig
shieldConfig *config.ShieldConfig
shieldManager *shield.ShieldManager
gfwConfig *config.GFWListConfig
gfwManager *gfw.GFWListManager
server *dns.Server
tcpServer *dns.Server
resolver *dns.Client
ctx context.Context
cancel context.CancelFunc
statsMutex sync.Mutex
stats *Stats
blockedDomainsMutex sync.RWMutex
blockedDomains map[string]*BlockedDomain
resolvedDomainsMutex sync.RWMutex
resolvedDomains map[string]*BlockedDomain // 用于记录解析的域名
clientStatsMutex sync.RWMutex
clientStats map[string]*ClientStats // 用于记录客户端统计
hourlyStatsMutex sync.RWMutex
hourlyStats map[string]int64 // 按小时统计屏蔽数量
dailyStatsMutex sync.RWMutex
dailyStats map[string]int64 // 按天统计屏蔽数量
monthlyStatsMutex sync.RWMutex
monthlyStats map[string]int64 // 按月统计屏蔽数量
// 新日志系统
logManager *log.LogManager // 日志管理器
archiveManager *log.ArchiveManager // 归档管理器
archiveQueryEngine *log.ArchiveQueryEngine // 归档查询引擎
// DNS查询缓存
DnsCache *DNSCache // DNS响应缓存
// 域名DNSSEC状态映射表
domainDNSSECStatus map[string]bool // 域名到DNSSEC状态的映射
domainDNSSECStatusMutex sync.RWMutex // 保护域名DNSSEC状态映射的互斥锁
// 上游服务器状态跟踪
serverStats map[string]*ServerStats // 服务器地址到状态的映射
serverStatsMutex sync.RWMutex // 保护服务器状态的互斥锁
// DNSSEC专用服务器映射,用于快速查找
dnssecServerMap map[string]bool // DNSSEC专用服务器地址到布尔值的映射
// DNS客户端实例池,用于并行查询
clientPool sync.Pool // 存储*dns.Client实例
// 威胁检测相关
threatEngine *threat.ThreatEngine
alertManager *threat.AlertManager
dbManager *threat.ThreatDatabaseManager
}
// Stats DNS服务器统计信息
type Stats struct {
Queries int64
Blocked int64
Allowed int64
Errors int64
LastQuery time.Time
AvgResponseTime float64 // 平均响应时间(ms)
TotalResponseTime int64 // 总响应时间
QueryTypes map[string]int64 // 查询类型统计
SourceIPs map[string]bool // 活跃来源IP
CpuUsage float64 // CPU使用率(%)
DNSSECQueries int64 // DNSSEC查询总数
DNSSECSuccess int64 // DNSSEC验证成功数
DNSSECFailed int64 // DNSSEC验证失败数
DNSSECEnabled bool // 是否启用了DNSSEC
}
// NewServer 创建DNS服务器实例
func NewServer(globalConfig *config.Config, shieldManager *shield.ShieldManager, gfwManager *gfw.GFWListManager) *Server {
ctx, cancel := context.WithCancel(context.Background())
// 从配置中读取DNS缓存TTL值(分钟)
cacheTTL := time.Duration(globalConfig.DNS.CacheTTL) * time.Minute
// 保存间隔(秒)
saveInterval := time.Duration(globalConfig.DNS.SaveInterval) * time.Second
// 最大和最小缓存TTL(分钟)
maxCacheTTL := time.Duration(globalConfig.DNS.MaxCacheTTL) * time.Minute
minCacheTTL := time.Duration(globalConfig.DNS.MinCacheTTL) * time.Minute
server := &Server{
globalConfig: globalConfig,
config: &globalConfig.DNS,
shieldConfig: &globalConfig.Shield,
shieldManager: shieldManager,
gfwConfig: &globalConfig.GFWList,
gfwManager: gfwManager,
resolver: &dns.Client{
Net: "udp",
UDPSize: 4096, // 增加UDP缓冲区大小,支持更大的DNSSEC响应
},
ctx: ctx,
cancel: cancel,
stats: &Stats{
Queries: 0,
Blocked: 0,
Allowed: 0,
Errors: 0,
AvgResponseTime: 0,
TotalResponseTime: 0,
QueryTypes: make(map[string]int64),
SourceIPs: make(map[string]bool),
CpuUsage: 0,
DNSSECQueries: 0,
DNSSECSuccess: 0,
DNSSECFailed: 0,
DNSSECEnabled: globalConfig.DNS.EnableDNSSEC,
},
blockedDomains: make(map[string]*BlockedDomain),
resolvedDomains: make(map[string]*BlockedDomain),
clientStats: make(map[string]*ClientStats),
hourlyStats: make(map[string]int64),
dailyStats: make(map[string]int64),
monthlyStats: make(map[string]int64),
// DNS 查询缓存初始化
DnsCache: NewDNSCache(cacheTTL, globalConfig.DNS.CacheMode, globalConfig.DNS.CacheSize, globalConfig.DNS.CacheFilePath, saveInterval, maxCacheTTL, minCacheTTL),
// 初始化域名 DNSSEC 状态映射表
domainDNSSECStatus: make(map[string]bool),
// 初始化服务器状态跟踪
serverStats: make(map[string]*ServerStats),
// 初始化 DNSSEC 专用服务器映射
dnssecServerMap: make(map[string]bool),
// 初始化 DNS 客户端实例池
clientPool: sync.Pool{
New: func() interface{} {
return &dns.Client{
Net: "udp",
UDPSize: 4096,
Timeout: 2 * time.Second, // 默认超时时间,会在使用时覆盖(2 秒是合理的 DNS 查询超时)
}
},
},
}
// 初始化新日志系统
logManager, err := log.NewLogManager(log.DefaultConfig())
if err != nil {
logger.Error("初始化日志管理器失败", "error", err)
} else {
server.logManager = logManager
logger.Info("新日志系统初始化成功", "ringBufferSize", log.DefaultConfig().RingBufferSize, "databasePath", log.DefaultConfig().DatabasePath)
}
// 初始化归档管理器
if globalConfig.QueryLog.ArchiveEnabled {
// 转换为 log.QueryLogConfig
logConfig := &log.QueryLogConfig{
Enabled: globalConfig.QueryLog.Enabled,
RingBufferSize: globalConfig.QueryLog.RingBufferSize,
DatabasePath: globalConfig.QueryLog.DatabasePath,
MaxDatabaseSizeMB: globalConfig.QueryLog.MaxDatabaseSizeMB,
EnableWAL: globalConfig.QueryLog.EnableWAL,
ArchiveEnabled: globalConfig.QueryLog.ArchiveEnabled,
ArchiveDir: globalConfig.QueryLog.ArchiveDir,
ArchivePrefix: globalConfig.QueryLog.ArchivePrefix,
CompressionLevel: globalConfig.QueryLog.CompressionLevel,
RetentionDays: globalConfig.QueryLog.RetentionDays,
RetentionMonths: globalConfig.QueryLog.RetentionMonths,
QueryTimeout: globalConfig.QueryLog.QueryTimeout,
EnableCache: globalConfig.QueryLog.EnableCache,
CacheTTL: globalConfig.QueryLog.CacheTTL,
}
archiveManager, err := log.NewArchiveManager(logConfig, globalConfig.QueryLog.DatabasePath)
if err != nil {
logger.Error("初始化归档管理器失败", "error", err)
} else {
server.archiveManager = archiveManager
logger.Info("归档管理器初始化成功", "archiveDir", globalConfig.QueryLog.ArchiveDir)
}
// 初始化归档查询引擎
if logManager != nil {
sqliteStore := logManager.GetSQLiteStore()
if sqliteStore != nil {
archiveQueryEngine, err := log.NewArchiveQueryEngine(sqliteStore, archiveManager, logConfig)
if err != nil {
logger.Error("初始化归档查询引擎失败", "error", err)
} else {
server.archiveQueryEngine = archiveQueryEngine
logger.Info("归档查询引擎初始化成功")
}
}
}
}
// 加载已保存的统计数据
server.loadStatsData()
return server
}
// initThreatDetection 初始化威胁检测相关组件
func (s *Server) initThreatDetection() {
// 从全局配置中获取威胁检测配置
threatConfig := &s.globalConfig.Threat
// 创建告警管理器
s.alertManager = threat.NewAlertManager(threatConfig)
// 加载已保存的告警
s.alertManager.LoadAlerts()
// 创建威胁域名数据库管理器
s.dbManager = threat.NewThreatDatabaseManager(threatConfig.ThreatDatabasePath)
// 加载威胁域名数据库
s.dbManager.LoadDatabase()
// 启动文件监听,自动检测数据库变更
if err := s.dbManager.StartWatching(); err != nil {
logger.Warn("启动威胁域名数据库监听失败", "error", err)
}
// 创建威胁检测引擎
s.threatEngine = threat.NewThreatEngine(threatConfig, s.alertManager, s.dbManager)
}
// Start 启动DNS服务器
func (s *Server) Start() error {
// 重新初始化上下文和取消函数
ctx, cancel := context.WithCancel(context.Background())
s.ctx = ctx
s.cancel = cancel
// 重新初始化saveDone通道
// 重置stopped标志
// 初始化威胁检测相关组件
s.initThreatDetection()
s.server = &dns.Server{
Addr: fmt.Sprintf("0.0.0.0:%d", s.config.Port),
Net: "udp",
Handler: dns.HandlerFunc(s.handleDNSRequest),
}
// 保存TCP服务器实例,以便在Stop方法中关闭
s.tcpServer = &dns.Server{
Addr: fmt.Sprintf("0.0.0.0:%d", s.config.Port),
Net: "tcp",
Handler: dns.HandlerFunc(s.handleDNSRequest),
}
// 启动CPU使用率监控
go s.startCpuUsageMonitor()
// 启动自动保存功能
go s.startAutoSave()
// 更新DNSSEC专用服务器映射
s.updateDNSSECServerMap()
// 启动日志处理协程(已移除,新日志系统使用 SQLite 存储)
// go s.processLogs()
// 启动统计数据定期重置功能(每 24 小时)
go func() {
ticker := time.NewTicker(24 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.resetStats()
case <-s.ctx.Done():
return
}
}
}()
// 启动归档监控和清理任务
if s.archiveManager != nil {
// 启动归档监控
s.archiveManager.StartWatching()
// 启动定期清理任务(每天执行)
go func() {
ticker := time.NewTicker(24 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if deleted, err := s.archiveManager.CleanupOldArchives(); err != nil {
logger.Error("清理归档失败", "error", err)
} else if deleted > 0 {
logger.Info("清理归档完成", "deleted", deleted)
}
case <-s.ctx.Done():
return
}
}
}()
logger.Info("归档监控和清理任务已启动")
}
// 启动UDP服务
go func() {
logger.Info(fmt.Sprintf("DNS UDP服务器启动,监听端口: %d", s.config.Port))
if err := s.server.ListenAndServe(); err != nil {
logger.Error("DNS UDP服务器启动失败", "error", err)
s.cancel()
}
}()
// 启动TCP服务
go func() {
logger.Info(fmt.Sprintf("DNS TCP服务器启动,监听端口: %d", s.config.Port))
if err := s.tcpServer.ListenAndServe(); err != nil {
logger.Error("DNS TCP服务器启动失败", "error", err)
s.cancel()
}
}()
// 等待停止信号
<-s.ctx.Done()
return nil
}
// resetStats 重置统计数据
func (s *Server) resetStats() {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
// 只重置累计值,保留配置相关值
s.stats.TotalResponseTime = 0
s.stats.AvgResponseTime = 0
s.stats.Queries = 0
s.stats.Blocked = 0
s.stats.Allowed = 0
s.stats.Errors = 0
s.stats.DNSSECQueries = 0
s.stats.DNSSECSuccess = 0
s.stats.DNSSECFailed = 0
s.stats.QueryTypes = make(map[string]int64)
s.stats.SourceIPs = make(map[string]bool)
logger.Info("统计数据已重置")
}
// Stop 停止DNS服务器
func (s *Server) Stop() {
// 检查服务器是否已经停止
// 标记服务器为已停止状态
// 停止威胁域名数据库文件监听
if s.dbManager != nil {
s.dbManager.StopWatching()
}
// 发送停止信号给保存协程
// 最后保存一次数据
s.saveStatsData()
// 停止服务器
s.cancel()
if s.server != nil {
s.server.Shutdown()
}
if s.tcpServer != nil {
s.tcpServer.Shutdown()
}
logger.Info("DNS服务器已停止")
}
// handleDNSRequest 处理DNS请求
func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
startTime := time.Now()
// 1. 初始化请求信息
reqInfo := s.initRequestInfo(w, r)
// 2. 检查基本请求条件
if earlyResponse := s.checkRequestConditions(w, r, startTime, reqInfo); earlyResponse {
return
}
// 3. 检查本地处理规则
if localHandled := s.handleLocalRules(w, r, startTime, reqInfo); localHandled {
return
}
// 4. 尝试从缓存获取响应
if cacheHandled := s.handleCacheResponse(w, r, startTime, reqInfo); cacheHandled {
return
}
// 5. 威胁检测
s.checkThreatDetection(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType)
// 6. 转发请求到上游服务器
s.handleUpstreamRequest(w, r, startTime, reqInfo)
}
// requestInfo 封装请求相关信息
type requestInfo struct {
sourceIP string
domain string
queryType string
qType uint16
queryAttempts []string
}
// initRequestInfo 初始化请求信息
func (s *Server) initRequestInfo(w dns.ResponseWriter, r *dns.Msg) *requestInfo {
// 获取来源IP
sourceIP := w.RemoteAddr().String()
// 提取IP地址部分,去掉端口
if strings.HasPrefix(sourceIP, "[") {
// IPv6地址格式: [::1]:53
if idx := strings.Index(sourceIP, "]"); idx >= 0 {
sourceIP = sourceIP[1:idx] // 去掉方括号
}
} else {
// IPv4地址格式: 127.0.0.1:53
if idx := strings.LastIndex(sourceIP, ":"); idx >= 0 {
sourceIP = sourceIP[:idx]
}
}
// 更新来源IP统计和Queries计数器
s.updateStats(func(stats *Stats) {
stats.Queries++
stats.LastQuery = time.Now()
stats.SourceIPs[sourceIP] = true
})
// 更新客户端统计
s.updateClientStats(sourceIP)
// 获取查询域名和类型
var domain string
var queryType string
var qType uint16
if len(r.Question) > 0 {
domain = r.Question[0].Name
// 移除末尾的点
if len(domain) > 0 && domain[len(domain)-1] == '.' {
domain = domain[:len(domain)-1]
}
// 获取查询类型
if t, ok := dns.TypeToString[r.Question[0].Qtype]; ok {
queryType = t
} else {
// 处理未知类型,使用数字表示
queryType = fmt.Sprintf("TYPE%d", r.Question[0].Qtype)
}
qType = r.Question[0].Qtype
// 更新查询类型统计
s.updateStats(func(stats *Stats) {
stats.QueryTypes[queryType]++
})
}
logger.Debug("接收到DNS查询", "domain", domain, "type", queryType, "client", w.RemoteAddr())
return &requestInfo{
sourceIP: sourceIP,
domain: domain,
queryType: queryType,
qType: qType,
queryAttempts: []string{domain},
}
}
// checkRequestConditions 检查请求条件,返回是否需要提前响应
func (s *Server) checkRequestConditions(w dns.ResponseWriter, r *dns.Msg, startTime time.Time, reqInfo *requestInfo) bool {
// 检查是否是AAAA记录查询且IPv6解析已禁用
if reqInfo.qType == dns.TypeAAAA && !s.config.EnableIPv6 {
// 返回空的成功响应,而不是NXDOMAIN
response := new(dns.Msg)
response.SetReply(r)
response.SetRcode(r, dns.RcodeSuccess)
w.WriteMsg(response)
// 更新统计信息 - 视为正常解析
responseTime := time.Since(startTime).Milliseconds()
s.updateStats(func(stats *Stats) {
stats.Allowed++
stats.TotalResponseTime += responseTime
stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries)
})
// 添加查询日志
s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, responseTime, "allowed", "", "", false, false, true, "", "", nil, dns.RcodeSuccess)
logger.Debug("IPv6解析已禁用,返回空的成功响应", "domain", reqInfo.domain)
return true
}
// 只处理递归查询
if r.RecursionDesired == false {
response := new(dns.Msg)
response.SetReply(r)
// 不再硬编码RecursionAvailable,使用默认值或上游返回的值
response.SetRcode(r, dns.RcodeRefused)
w.WriteMsg(response)
// 计算实际响应时间
responseTime := time.Since(startTime).Milliseconds()
// 更新统计信息 - 视为错误
s.updateStats(func(stats *Stats) {
stats.Errors++
stats.TotalResponseTime += responseTime
stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries)
})
// 添加查询日志
s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, responseTime, "error", "", "", false, false, true, "", "", nil, dns.RcodeRefused)
return true
}
return false
}
// handleLocalRules 处理本地规则(hosts文件、GFWList、屏蔽规则),返回是否已处理
func (s *Server) handleLocalRules(w dns.ResponseWriter, r *dns.Msg, startTime time.Time, reqInfo *requestInfo) bool {
// 本地规则匹配的响应时间极短,使用固定值1ms
const localResponseTime int64 = 1
// 检查 hosts 文件是否有匹配
if ip, exists := s.shieldManager.GetHostsIP(reqInfo.domain); exists {
s.handleHostsResponse(w, r, ip)
// 使用固定的短响应时间
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += localResponseTime
stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries)
})
// 添加查询日志 - hosts 文件匹配
hostsAnswers := []DNSAnswer{{
Type: "A",
Value: ip,
TTL: 300,
}}
s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, localResponseTime, "allowed", "", "", false, false, true, "hosts", "无", hostsAnswers, dns.RcodeSuccess)
return true
}
// 检查是否为GFWList域名(仅当GFWList功能启用时)
if s.gfwConfig.Enabled && s.gfwManager != nil && s.gfwManager.IsMatch(reqInfo.domain) {
s.handleGFWListResponse(w, r, reqInfo.domain)
// 使用固定的短响应时间
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += localResponseTime
stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries)
})
// 添加查询日志 - GFWList域名
gfwAnswers := []DNSAnswer{}
s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, localResponseTime, "gfwlist", "", "", false, false, true, "GFWList", "无", gfwAnswers, dns.RcodeSuccess)
return true
}
// 检查是否被屏蔽
if s.shieldManager.IsBlocked(reqInfo.domain) {
// 获取屏蔽详情
blockDetails := s.shieldManager.CheckDomainBlockDetails(reqInfo.domain)
blockRule, _ := blockDetails["blockRule"].(string)
blockType, _ := blockDetails["blockRuleType"].(string)
s.handleBlockedResponse(w, r, reqInfo.domain)
// 使用固定的短响应时间
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += localResponseTime
stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries)
})
// 添加查询日志 - 被屏蔽域名
blockedAnswers := []DNSAnswer{}
// 根据屏蔽方法确定响应代码
blockedRcode := dns.RcodeNameError // 默认NXDOMAIN
if blockMethod := s.shieldConfig.BlockMethod; blockMethod == "refused" {
blockedRcode = dns.RcodeRefused
} else if blockMethod == "emptyIP" || blockMethod == "customIP" {
blockedRcode = dns.RcodeSuccess
}
s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, localResponseTime, "blocked", blockRule, blockType, false, false, true, "无", "无", blockedAnswers, blockedRcode)
return true
}
return false
}
// handleCacheResponse 尝试从缓存获取响应,返回是否已处理
func (s *Server) handleCacheResponse(w dns.ResponseWriter, r *dns.Msg, startTime time.Time, reqInfo *requestInfo) bool {
// 检查缓存中是否有响应(优先查找带DNSSEC的缓存项)
var cachedResponse *dns.Msg
var found bool
var cachedDNSSEC bool
// 1. 首先检查是否有普通缓存项
if tempResponse, tempFound := s.DnsCache.Get(r.Question[0].Name, reqInfo.qType); tempFound {
cachedResponse = tempResponse
found = tempFound
cachedDNSSEC = s.hasDNSSECRecords(tempResponse)
}
// 2. 如果启用了DNSSEC且没有找到带DNSSEC的缓存项,
// 尝试从所有缓存中查找是否有其他响应包含DNSSEC记录
// (这里可以进一步优化,比如在缓存中标记DNSSEC状态,快速查找)
if s.config.EnableDNSSEC && !cachedDNSSEC {
// 目前的缓存实现不支持按DNSSEC状态查找,所以这里暂时跳过
// 后续可以考虑改进缓存实现,添加DNSSEC状态标记
}
if !found {
return false
}
// 缓存命中,直接返回缓存的响应
cachedResponseCopy := cachedResponse.Copy() // 创建响应副本避免并发修改问题
cachedResponseCopy.Id = r.Id // 更新ID以匹配请求
cachedResponseCopy.Compress = true
// 如果客户端请求包含EDNS记录,确保响应也包含EDNS
if opt := r.IsEdns0(); opt != nil {
// 检查响应是否已经包含EDNS记录
if respOpt := cachedResponseCopy.IsEdns0(); respOpt == nil {
// 添加EDNS记录,使用客户端的UDP缓冲区大小
cachedResponseCopy.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
} else {
// 确保响应的UDP缓冲区大小不超过客户端请求的大小
if respOpt.UDPSize() > opt.UDPSize() {
// 移除现有的EDNS记录
for i := range cachedResponseCopy.Extra {
if cachedResponseCopy.Extra[i] == respOpt {
cachedResponseCopy.Extra = append(cachedResponseCopy.Extra[:i], cachedResponseCopy.Extra[i+1:]...)
break
}
}
// 添加新的EDNS记录,使用客户端的UDP缓冲区大小
cachedResponseCopy.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
}
}
}
// 确保响应的Question部分与客户端请求的Question部分匹配
cachedResponseCopy.Question = r.Question
// 修复:如果响应包含记录,确保Rcode为成功
hasValidRecords := false
// 检查Answer部分
if len(cachedResponseCopy.Answer) > 0 {
hasValidRecords = true
} else if len(cachedResponseCopy.Ns) > 0 {
// 检查Ns部分
hasValidRecords = true
} else if len(cachedResponseCopy.Extra) > 0 {
// 检查Extra部分,排除OPT记录
for _, rr := range cachedResponseCopy.Extra {
if rr.Header().Rrtype != dns.TypeOPT {
hasValidRecords = true
break
}
}
}
if hasValidRecords {
cachedResponseCopy.Rcode = dns.RcodeSuccess
}
w.WriteMsg(cachedResponseCopy)
// 缓存命中的响应时间应该是极短的,使用固定值1ms而非实际处理时间
const cacheResponseTime int64 = 1
// 缓存命中的响应视为正常解析
s.updateStats(func(stats *Stats) {
stats.Allowed++
stats.TotalResponseTime += cacheResponseTime
stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries)
})
// 如果缓存响应包含DNSSEC记录,更新DNSSEC查询计数
if cachedDNSSEC {
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
// 缓存响应视为DNSSEC成功
stats.DNSSECSuccess++
})
}
// 从缓存响应中提取解析记录
cachedAnswers := []DNSAnswer{}
if cachedResponse != nil {
for _, rr := range cachedResponse.Answer {
cachedAnswers = append(cachedAnswers, DNSAnswer{
Type: dns.TypeToString[rr.Header().Rrtype],
Value: rr.String(),
TTL: rr.Header().Ttl,
})
}
}
// 添加查询日志 - 标记为缓存
// 从缓存响应中获取响应代码
cacheRcode := dns.RcodeSuccess // 默认成功
if cachedResponse != nil {
cacheRcode = cachedResponse.Rcode
}
s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, cacheResponseTime, "allowed", "", "", true, cachedDNSSEC, true, "缓存", "无", cachedAnswers, cacheRcode)
logger.Debug("从缓存返回DNS响应", "domain", reqInfo.domain, "type", reqInfo.queryType, "dnssec", cachedDNSSEC)
return true
}
// handleUpstreamRequest 处理上游请求
func (s *Server) handleUpstreamRequest(w dns.ResponseWriter, r *dns.Msg, startTime time.Time, reqInfo *requestInfo) {
logger.Debug("开始处理上游请求", "domain", reqInfo.domain, "type", reqInfo.queryType)
// 缓存未命中,处理 DNS 请求
var response *dns.Msg
var rtt time.Duration
var dnsServer string
var dnssecServer string
// 直接查询原始域名
response, rtt, dnsServer, dnssecServer = s.forwardDNSRequestWithCache(r, reqInfo.domain)
logger.Debug("上游请求返回", "domain", reqInfo.domain, "response", response != nil, "rtt", rtt)
if response != nil {
// 如果客户端请求包含EDNS记录,确保响应也包含EDNS
if opt := r.IsEdns0(); opt != nil {
// 检查响应是否已经包含EDNS记录
if respOpt := response.IsEdns0(); respOpt == nil {
// 添加EDNS记录,使用客户端的UDP缓冲区大小
response.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
} else {
// 确保响应的UDP缓冲区大小不超过客户端请求的大小
if respOpt.UDPSize() > opt.UDPSize() {
// 移除现有的EDNS记录
for i := range response.Extra {
if response.Extra[i] == respOpt {
response.Extra = append(response.Extra[:i], response.Extra[i+1:]...)
break
}
}
// 添加新的EDNS记录,使用客户端的UDP缓冲区大小
response.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
}
}
}
// 确保响应的Question部分与客户端请求的Question部分匹配
response.Question = r.Question
// 设置递归可用标志(因为我们的 DNS 服务器支持递归查询)
response.RecursionAvailable = true
response.RecursionDesired = r.RecursionDesired
// 修复:如果响应包含记录,确保Rcode为成功
hasValidRecords := false
// 检查Answer部分
if len(response.Answer) > 0 {
hasValidRecords = true
} else if len(response.Ns) > 0 {
// 检查Ns部分
hasValidRecords = true
} else if len(response.Extra) > 0 {
// 检查Extra部分,排除OPT记录
for _, rr := range response.Extra {
if rr.Header().Rrtype != dns.TypeOPT {
hasValidRecords = true
break
}
}
}
if hasValidRecords {
response.Rcode = dns.RcodeSuccess
}
// 写入响应给客户端
w.WriteMsg(response)
}
// 使用上游服务器的实际响应时间(转换为毫秒)
responseTime := int64(rtt.Milliseconds())
// 如果rtt为0(查询失败),则使用本地计算的时间
if responseTime == 0 {
responseTime = time.Since(startTime).Milliseconds()
}
// 添加合理性检查,避免异常大的响应时间影响统计
if responseTime > 60000 { // 超过60秒的响应时间视为异常
responseTime = 60000
}
// 更新基本统计
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
// 添加防御性编程,确保Queries大于0
if stats.Queries > 0 {
// 平均响应时间 = 总响应时间 / 总解析数量,四舍五入取整
avg := float64(stats.TotalResponseTime) / float64(stats.Queries)
stats.AvgResponseTime = float64(math.Round(avg))
// 限制平均响应时间的范围,避免显示异常大的值
if stats.AvgResponseTime > 60000 {
stats.AvgResponseTime = 60000
}
}
})
// 判断请求结果类型并更新相应统计
resultType := "allowed"
if response == nil {
// 响应为nil,视为错误
resultType = "error"
s.updateStats(func(stats *Stats) {
stats.Errors++
})
} else if response.Rcode != dns.RcodeSuccess {
// 响应代码不是成功,视为错误
resultType = "error"
s.updateStats(func(stats *Stats) {
stats.Errors++
})
} else {
// 成功响应,视为正常解析
resultType = "allowed"
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
}
// 检查响应是否包含DNSSEC记录并验证结果
responseDNSSEC := false
if response != nil {
// 使用hasDNSSECRecords函数检查是否包含DNSSEC记录
responseDNSSEC = s.hasDNSSECRecords(response)
// 检查AD标志,确认DNSSEC验证是否成功
if response.AuthenticatedData {
responseDNSSEC = true
}
// 更新域名的DNSSEC状态
if responseDNSSEC {
s.updateDomainDNSSECStatus(reqInfo.domain, true)
}
}
// 如果响应成功,缓存结果(增强版缓存存储)
if response != nil && response.Rcode == dns.RcodeSuccess {
// 创建响应副本以避免后续修改影响缓存
responseCopy := response.Copy()
// 设置合理的TTL,不超过默认的30分钟
defaultCacheTTL := 30 * time.Minute
// 1. 缓存原始域名的查询结果
s.DnsCache.Set(r.Question[0].Name, reqInfo.qType, responseCopy, defaultCacheTTL)
logger.Debug("DNS响应已缓存", "domain", reqInfo.domain, "type", reqInfo.queryType, "ttl", defaultCacheTTL, "dnssec", responseDNSSEC)
// 2. 如果响应包含CNAME记录,同时缓存CNAME指向的域名的查询结果
for _, rr := range response.Answer {
if cname, ok := rr.(*dns.CNAME); ok {
// 为CNAME指向的域名创建缓存
cnameQuery := r.Copy()
cnameQuery.Question[0].Name = cname.Target
s.DnsCache.Set(cname.Target, reqInfo.qType, responseCopy, defaultCacheTTL)
logger.Debug("CNAME响应已缓存", "domain", cname.Target, "type", reqInfo.queryType, "ttl", defaultCacheTTL, "dnssec", responseDNSSEC)
break
}
}
}
// 从响应中提取解析记录
responseAnswers := []DNSAnswer{}
if response != nil {
for _, rr := range response.Answer {
responseAnswers = append(responseAnswers, DNSAnswer{
Type: dns.TypeToString[rr.Header().Rrtype],
Value: rr.String(),
TTL: rr.Header().Ttl,
})
}
}
// 从响应中获取响应代码
realRcode := dns.RcodeSuccess // 默认成功
if response != nil {
realRcode = response.Rcode
}
logger.Debug("准备添加查询日志", "domain", reqInfo.domain, "result", resultType, "responseCode", realRcode)
// 添加查询日志
s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, responseTime, resultType, "", "", false, responseDNSSEC, true, dnsServer, dnssecServer, responseAnswers, realRcode)
logger.Debug("查询日志已添加", "domain", reqInfo.domain)
}
// handleHostsResponse 处理 hosts 文件匹配的响应
func (s *Server) handleHostsResponse(w dns.ResponseWriter, r *dns.Msg, ip string) {
response := new(dns.Msg)
response.SetReply(r)
// 不再硬编码 RecursionAvailable,使用默认值或上游返回的值
if len(r.Question) > 0 {
q := r.Question[0]
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
answer.A = net.ParseIP(ip)
response.Answer = append(response.Answer, answer)
}
// 记录解析域名统计
domain := ""
if len(r.Question) > 0 {
domain = r.Question[0].Name
if len(domain) > 0 && domain[len(domain)-1] == '.' {
domain = domain[:len(domain)-1]
}
s.updateResolvedDomainStats(domain)
}
w.WriteMsg(response)
// 本地 hosts 匹配响应时间极短,使用固定值 1ms
const localResponseTime int64 = 1
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
}
// handleGFWListResponse 处理GFWList域名响应
func (s *Server) handleGFWListResponse(w dns.ResponseWriter, r *dns.Msg, domain string) {
logger.Info("GFWList域名解析", "domain", domain, "client", w.RemoteAddr(), "ip", s.gfwConfig.IP)
// 更新解析域名统计
s.updateResolvedDomainStats(domain)
response := new(dns.Msg)
response.SetReply(r)
if len(r.Question) > 0 {
q := r.Question[0]
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
answer.A = net.ParseIP(s.gfwConfig.IP)
response.Answer = append(response.Answer, answer)
}
w.WriteMsg(response)
// GFWList域名匹配响应时间极短,使用固定值1ms
const localResponseTime int64 = 1
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
}
// handleBlockedResponse 处理被屏蔽的域名响应
func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain string) {
logger.Info("域名被屏蔽", "domain", domain, "client", w.RemoteAddr())
// 更新被屏蔽域名统计
s.updateBlockedDomainStats(domain)
response := new(dns.Msg)
response.SetReply(r)
// 不再硬编码RecursionAvailable,使用默认值或上游返回的值
// 获取屏蔽方法配置
blockMethod := "NXDOMAIN" // 默认值
customBlockIP := "" // 默认值
// 从Server结构体的shieldConfig字段获取配置
if s.shieldConfig != nil {
blockMethod = s.shieldConfig.BlockMethod
customBlockIP = s.shieldConfig.CustomBlockIP
}
// 根据屏蔽方法返回不同的响应
switch blockMethod {
case "refused":
// 返回拒绝查询响应
response.SetRcode(r, dns.RcodeRefused)
case "emptyIP":
// 返回空IP响应
if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA {
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: r.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
answer.A = net.ParseIP("0.0.0.0") // 空IP
response.Answer = append(response.Answer, answer)
}
case "customIP":
// 返回自定义IP响应
if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA {
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: r.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
// 使用自定义屏蔽IP
if customBlockIP != "" {
answer.A = net.ParseIP(customBlockIP)
} else {
// 如果没有配置,使用0.0.0.0
answer.A = net.ParseIP("0.0.0.0")
}
response.Answer = append(response.Answer, answer)
}
case "NXDOMAIN", "":
fallthrough // 默认使用NXDOMAIN
default:
// 返回NXDOMAIN响应(域名不存在)
response.SetRcode(r, dns.RcodeNameError)
}
w.WriteMsg(response)
// 屏蔽规则匹配响应时间极短,使用固定值1ms
const localResponseTime int64 = 1
s.updateStats(func(stats *Stats) {
stats.Blocked++
})
}
// forwardDNSRequest 转发DNS请求到上游服务器
// serverResponse 用于存储服务器响应的结构体
type serverResponse struct {
response *dns.Msg
rtt time.Duration
server string
error error
}
// updateDNSSECServerMap 更新DNSSEC专用服务器映射,用于快速查找
func (s *Server) updateDNSSECServerMap() {
// 清空现有映射
for k := range s.dnssecServerMap {
delete(s.dnssecServerMap, k)
}
// 添加所有DNSSEC专用服务器到映射
for _, server := range s.config.DNSSECUpstreamDNS {
s.dnssecServerMap[server] = true
}
}
// forwardDNSRequestWithCache 转发DNS请求到上游服务器并返回响应
func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg, time.Duration, string, string) {
// 始终支持EDNS
var udpSize uint16 = 4096
var doFlag bool = s.config.EnableDNSSEC
// 检查域名是否匹配不验证DNSSEC的模式
noDNSSEC := false
for _, pattern := range s.config.NoDNSSECDomains {
if strings.Contains(domain, pattern) {
noDNSSEC = true
doFlag = false
logger.Debug("域名匹配到不验证DNSSEC的模式", "domain", domain, "pattern", pattern)
break
}
}
// 检查客户端请求是否包含EDNS记录
if opt := r.IsEdns0(); opt != nil {
// 保留客户端的UDP缓冲区大小
udpSize = opt.UDPSize()
// 移除现有的EDNS记录,以便重新添加
for i := range r.Extra {
if r.Extra[i] == opt {
r.Extra = append(r.Extra[:i], r.Extra[i+1:]...)
break
}
}
}
// 添加EDNS记录,设置适当的UDPSize和DO标志
r.SetEdns0(udpSize, doFlag)
// DNSSEC专用服务器列表,从配置中获取
dnssecServers := s.config.DNSSECUpstreamDNS
// 选择合适的上游DNS服务器列表
// 1. 首先检查是否有域名特定的DNS服务器配置
var selectedUpstreamDNS []string
var domainMatched bool
for matchStr, dnsServers := range s.config.DomainSpecificDNS {
if strings.Contains(domain, matchStr) {
selectedUpstreamDNS = dnsServers
domainMatched = true
logger.Debug("域名匹配到特定DNS服务器配置", "domain", domain, "matchStr", matchStr, "dnsServers", dnsServers)
break
}
}
// 2. 如果没有匹配的域名特定配置
if !domainMatched {
// 创建一个新的切片来存储最终的上游服务器列表
var finalUpstreamDNS []string
// 首先添加用户配置的上游DNS服务器
finalUpstreamDNS = append(finalUpstreamDNS, s.config.UpstreamDNS...)
logger.Debug("使用用户配置的上游DNS服务器", "servers", finalUpstreamDNS)
// 如果启用了DNSSEC且有配置DNSSEC专用服务器,并且域名不匹配NoDNSSECDomains,则将DNSSEC专用服务器添加到列表中
if s.config.EnableDNSSEC && len(s.config.DNSSECUpstreamDNS) > 0 && !noDNSSEC {
// 合并DNSSEC专用服务器到上游服务器列表,避免重复,并确保包含端口号
for _, dnssecServer := range s.config.DNSSECUpstreamDNS {
hasDuplicate := false
// 确保DNSSEC服务器地址包含端口号
normalizedDnssecServer := normalizeDNSServerAddress(dnssecServer)
for _, upstream := range finalUpstreamDNS {
if upstream == normalizedDnssecServer {
hasDuplicate = true
break
}
}
if !hasDuplicate {
finalUpstreamDNS = append(finalUpstreamDNS, normalizedDnssecServer)
}
}
logger.Debug("合并DNSSEC专用服务器到上游服务器列表", "servers", finalUpstreamDNS)
}
// 使用最终合并后的服务器列表
selectedUpstreamDNS = finalUpstreamDNS
}
// 1. 首先尝试所有配置的上游DNS服务器
var bestResponse *dns.Msg
var bestRtt time.Duration
var hasBestResponse bool
var hasDNSSECResponse bool
var backupResponse *dns.Msg
var backupRtt time.Duration
var hasBackup bool
var usedDNSServer string
var usedDNSSECServer string
// 使用配置中的超时时间
defaultTimeout := time.Duration(s.config.QueryTimeout) * time.Millisecond
// 根据查询模式处理请求
switch s.config.QueryMode {
case "parallel":
// 并行请求模式 - 返回第一个成功响应
responses := make(chan serverResponse, len(selectedUpstreamDNS))
var wg sync.WaitGroup
// 向所有上游服务器并行发送请求
for _, upstream := range selectedUpstreamDNS {
wg.Add(1)
go func(server string) {
defer wg.Done()
// 从池中获取客户端实例
client := s.clientPool.Get().(*dns.Client)
// 设置客户端参数(确保在 Exchange 之前设置,避免竞态条件)
client.Net = s.resolver.Net
client.UDPSize = s.resolver.UDPSize
client.Timeout = defaultTimeout // 使用配置的超时时间
// 发送请求并获取响应,确保服务器地址包含端口号
response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(server))
responses <- serverResponse{response, rtt, server, err}
// 将客户端实例放回池中(不重置 Timeout,因为下次使用时会重新设置)
s.clientPool.Put(client)
}(upstream)
}
// 等待所有请求完成
go func() {
wg.Wait()
close(responses)
}()
// 处理响应,只返回第一个成功响应
var lastErrorResponse *dns.Msg
var lastErrorRtt time.Duration
var lastErrorServer string
for i := 0; i < len(selectedUpstreamDNS); i++ {
resp := <-responses
if resp.error == nil && resp.response != nil {
// 更新服务器统计信息
s.updateServerStats(resp.server, true, resp.rtt)
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(resp.response)
// 对于不验证DNSSEC的域名,始终设置AD标志为false
if noDNSSEC {
resp.response.AuthenticatedData = false
}
// 检查当前服务器是否是DNSSEC专用服务器
if _, isDNSSECServer := s.dnssecServerMap[resp.server]; isDNSSECServer {
usedDNSSECServer = resp.server
}
// 如果是成功响应,立即返回
if resp.response.Rcode == dns.RcodeSuccess {
// 验证DNSSEC记录(如果需要)
if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(resp.response)
resp.response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
bestResponse = resp.response
bestRtt = resp.rtt
usedDNSServer = resp.server
hasBestResponse = true
hasDNSSECResponse = containsDNSSEC
logger.Debug("返回第一个成功响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer
} else {
// 保存最后一个错误响应
lastErrorResponse = resp.response
lastErrorRtt = resp.rtt
lastErrorServer = resp.server
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(resp.server, false, 0)
}
}
// 如果所有服务器都失败,返回最后一个错误
if lastErrorResponse != nil {
bestResponse = lastErrorResponse
bestRtt = lastErrorRtt
usedDNSServer = lastErrorServer
hasBestResponse = true
logger.Debug("所有服务器都失败,返回最后一个错误响应", "domain", domain, "server", lastErrorServer)
}
return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer
case "fastest-ip":
// 最快的IP地址模式 - 通过ping测试选择最快服务器,只向一个服务器发送请求
// 1. 选择最快的服务器
fastestServer := s.selectFastestServer(selectedUpstreamDNS)
if fastestServer != "" {
// 从池中获取客户端实例
client := s.clientPool.Get().(*dns.Client)
// 设置客户端参数
client.Net = s.resolver.Net
client.UDPSize = s.resolver.UDPSize
client.Timeout = defaultTimeout
// 只向一个服务器发送请求
response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(fastestServer))
// 将客户端实例放回池中
s.clientPool.Put(client)
if err == nil && response != nil {
// 更新服务器统计信息
s.updateServerStats(fastestServer, true, rtt)
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(response)
// 对于不验证DNSSEC的域名,始终设置AD标志为false
if noDNSSEC {
response.AuthenticatedData = false
}
// 验证DNSSEC记录(如果需要)
if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(response)
response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
// 检查响应是否包含有效的记录,如果包含,将Rcode设置为成功
hasValidRecords := false
if len(response.Answer) > 0 {
hasValidRecords = true
} else if len(response.Ns) > 0 {
hasValidRecords = true
} else if len(response.Extra) > 0 {
for _, rr := range response.Extra {
if rr.Header().Rrtype != dns.TypeOPT {
hasValidRecords = true
break
}
}
}
if hasValidRecords {
response.Rcode = dns.RcodeSuccess
}
// 设置最佳响应
bestResponse = response
bestRtt = rtt
hasBestResponse = true
usedDNSServer = fastestServer
if containsDNSSEC {
hasDNSSECResponse = true
}
if _, isDNSSECServer := s.dnssecServerMap[normalizeDNSServerAddress(fastestServer)]; isDNSSECServer {
usedDNSSECServer = fastestServer
}
logger.Debug("使用最快服务器返回响应", "domain", domain, "server", fastestServer, "rtt", rtt)
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(fastestServer, false, 0)
logger.Debug("最快服务器请求失败", "domain", domain, "server", fastestServer, "error", err)
}
}
default:
// 默认使用并行请求模式 - 实现快速返回和超时机制
responses := make(chan serverResponse, len(selectedUpstreamDNS))
resultChan := make(chan struct {
response *dns.Msg
rtt time.Duration
usedServer string
usedDnssecServer string
}, 1)
var wg sync.WaitGroup
// 向所有上游服务器并行发送请求
for _, upstream := range selectedUpstreamDNS {
wg.Add(1)
go func(server string) {
defer wg.Done()
// 创建带有超时的resolver
client := &dns.Client{
Net: s.resolver.Net,
UDPSize: s.resolver.UDPSize,
Timeout: defaultTimeout,
}
// 发送请求并获取响应,确保服务器地址包含端口号
response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(server))
responses <- serverResponse{response, rtt, server, err}
}(upstream)
}
// 处理响应的协程
go func() {
var fastestResponse *dns.Msg
var fastestRtt time.Duration = defaultTimeout
var fastestServer string
var fastestDnssecServer string
var fastestHasDnssec bool
var successResponses []*dns.Msg
var nxdomainResponses []*dns.Msg
// 等待所有请求完成或超时
timer := time.NewTimer(defaultTimeout)
defer timer.Stop()
// 处理所有响应
for {
select {
case resp, ok := <-responses:
if !ok {
// 所有响应都已处理
goto doneProcessing
}
if resp.error == nil && resp.response != nil {
// 更新服务器统计信息
s.updateServerStats(resp.server, true, resp.rtt)
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(resp.response)
// 对于不验证DNSSEC的域名,始终设置AD标志为false
if noDNSSEC {
resp.response.AuthenticatedData = false
}
dnssecServerForResponse := ""
if _, isDNSSECServer := s.dnssecServerMap[normalizeDNSServerAddress(resp.server)]; isDNSSECServer {
dnssecServerForResponse = resp.server
}
// 如果响应成功或为NXDOMAIN
if resp.response.Rcode == dns.RcodeSuccess || resp.response.Rcode == dns.RcodeNameError {
// 按Rcode分类添加到不同列表
if resp.response.Rcode == dns.RcodeSuccess {
successResponses = append(successResponses, resp.response)
} else {
nxdomainResponses = append(nxdomainResponses, resp.response)
}
// 快速返回逻辑:找到第一个有效响应或更快的响应
if resp.response.Rcode == dns.RcodeSuccess {
// 优先选择带有DNSSEC的响应
if containsDNSSEC {
// 如果这是第一个DNSSEC响应,或者比当前最快的DNSSEC响应更快
if !fastestHasDnssec || resp.rtt < fastestRtt {
fastestResponse = resp.response
fastestRtt = resp.rtt
fastestServer = resp.server
fastestDnssecServer = dnssecServerForResponse
fastestHasDnssec = true
// 只对将要返回的响应进行DNSSEC验证
if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(fastestResponse)
// 设置AD标志(Authenticated Data
fastestResponse.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
// 发送结果,快速返回
resultChan <- struct {
response *dns.Msg
rtt time.Duration
usedServer string
usedDnssecServer string
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
}
} else {
// 非DNSSEC响应,只有在还没有找到DNSSEC响应且当前响应更快时才更新
if !fastestHasDnssec && resp.rtt < fastestRtt {
fastestResponse = resp.response
fastestRtt = resp.rtt
fastestServer = resp.server
fastestDnssecServer = dnssecServerForResponse
// 检查是否包含DNSSEC记录
respContainsDNSSEC := s.hasDNSSECRecords(fastestResponse)
// 只对将要返回的响应进行DNSSEC验证
if s.config.EnableDNSSEC && respContainsDNSSEC && !noDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(fastestResponse)
// 设置AD标志(Authenticated Data
fastestResponse.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
// 发送结果,快速返回
resultChan <- struct {
response *dns.Msg
rtt time.Duration
usedServer string
usedDnssecServer string
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
}
}
} else if resp.response.Rcode == dns.RcodeNameError {
// NXDOMAIN响应,只有在还没有找到响应或当前响应更快时才更新
if !fastestHasDnssec && resp.rtt < fastestRtt {
fastestResponse = resp.response
fastestRtt = resp.rtt
fastestServer = resp.server
fastestDnssecServer = dnssecServerForResponse
// 检查是否包含DNSSEC记录
respContainsDNSSEC := s.hasDNSSECRecords(fastestResponse)
// 只对将要返回的响应进行DNSSEC验证
if s.config.EnableDNSSEC && respContainsDNSSEC && !noDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(fastestResponse)
// 设置AD标志(Authenticated Data
fastestResponse.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
// 发送结果,快速返回
resultChan <- struct {
response *dns.Msg
rtt time.Duration
usedServer string
usedDnssecServer string
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
}
}
} else {
// 更新备选响应,确保总有一个可用的响应
if resp.response != nil {
if !hasBackup {
// 第一次保存备选响应
backupResponse = resp.response
backupRtt = resp.rtt
hasBackup = true
}
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(resp.server, false, 0)
}
case <-timer.C:
// 超时,停止等待更多响应
goto doneProcessing
}
}
doneProcessing:
// 如果还没有发送结果,发送最快的响应
if fastestResponse != nil {
resultChan <- struct {
response *dns.Msg
rtt time.Duration
usedServer string
usedDnssecServer string
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
}
close(resultChan)
}()
// 等待所有请求完成(不阻塞主流程)
go func() {
wg.Wait()
close(responses)
}()
// 等待结果或超时
select {
case result := <-resultChan:
// 快速返回结果
bestResponse = result.response
bestRtt = result.rtt
usedDNSServer = result.usedServer
usedDNSSECServer = result.usedDnssecServer
hasBestResponse = true
hasDNSSECResponse = s.hasDNSSECRecords(result.response)
logger.Debug("快速返回DNS响应", "domain", domain, "server", result.usedServer, "rtt", result.rtt, "dnssec", hasDNSSECResponse)
case <-time.After(defaultTimeout):
// 超时,使用备选响应
logger.Debug("并行请求超时", "domain", domain, "timeout", defaultTimeout)
}
}
// 2. 当启用DNSSEC且没有找到带DNSSEC的响应时,向DNSSEC专用服务器发送请求
// 但如果域名匹配了domainSpecificDNS配置或NoDNSSECDomains,则不使用DNSSEC专用服务器,只使用指定的DNS服务器
if s.config.EnableDNSSEC && !hasDNSSECResponse && !domainMatched && !noDNSSEC {
logger.Debug("向DNSSEC专用服务器发送请求", "domain", domain)
// 增加DNSSEC查询计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
})
// 无论查询模式是什么,DNSSEC验证都只使用加权随机选择一个服务器
selectedDnssecServer := s.selectWeightedRandomServer(dnssecServers)
if selectedDnssecServer != "" {
// 使用带超时的方式执行Exchange
resultChan := make(chan struct {
response *dns.Msg
rtt time.Duration
err error
}, 1)
go func() {
// 创建带有超时的resolver
client := &dns.Client{
Net: s.resolver.Net,
UDPSize: s.resolver.UDPSize,
Timeout: defaultTimeout,
}
response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(selectedDnssecServer))
resultChan <- struct {
response *dns.Msg
rtt time.Duration
err error
}{response, rtt, err}
}()
var response *dns.Msg
var rtt time.Duration
var err error
// 使用超时获取结果
select {
case result := <-resultChan:
response, rtt, err = result.response, result.rtt, result.err
case <-time.After(defaultTimeout):
// 超时,不再等待
logger.Debug("DNSSEC专用服务器请求超时", "domain", domain, "server", selectedDnssecServer, "timeout", defaultTimeout)
return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer
}
if err == nil && response != nil {
// 更新服务器统计信息
s.updateServerStats(selectedDnssecServer, true, rtt)
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(response)
if response.Rcode == dns.RcodeSuccess {
// 无论响应是否包含DNSSEC记录,只要使用了DNSSEC专用服务器,就设置usedDNSSECServer
usedDNSSECServer = selectedDnssecServer
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(response)
// 设置AD标志(Authenticated Data
response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECFailed++
})
}
// 优先使用DNSSEC专用服务器的响应,尤其是带有DNSSEC记录的
if containsDNSSEC {
bestResponse = response
bestRtt = rtt
hasBestResponse = true
hasDNSSECResponse = true
logger.Debug("DNSSEC专用服务器返回带DNSSEC的响应,优先使用", "domain", domain, "server", selectedDnssecServer, "rtt", rtt)
}
// 注意:如果DNSSEC专用服务器返回的响应不包含DNSSEC记录,
// 我们不会覆盖之前从upstreamDNS获取的响应,
// 这符合"本地解析指的是直接使用上游服务器upstreamDNS进行解析, 而不是dnssecUpstreamDNS"的要求
// 更新备选响应
if !hasBackup {
backupResponse = response
backupRtt = rtt
hasBackup = true
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(selectedDnssecServer, false, 0)
}
}
}
// 3. 返回最佳响应
if hasBestResponse {
// 检查最佳响应是否包含DNSSEC记录
bestHasDNSSEC := s.hasDNSSECRecords(bestResponse)
// 如果启用了DNSSEC且最佳响应不包含DNSSEC记录,尝试使用本地解析(使用upstreamDNS服务器)
// 但如果域名匹配了domainSpecificDNS配置,则不执行此逻辑,只使用指定的DNS服务器
if s.config.EnableDNSSEC && !bestHasDNSSEC && !domainMatched {
logger.Debug("最佳响应不包含DNSSEC记录,尝试使用本地解析(upstreamDNS", "domain", domain)
// 选择一个upstreamDNS服务器进行解析(使用加权随机算法)
localServer := s.selectWeightedRandomServer(s.config.UpstreamDNS)
if localServer != "" {
// 使用带超时的方式执行Exchange
resultChan := make(chan struct {
response *dns.Msg
rtt time.Duration
err error
}, 1)
go func() {
// 创建临时的 resolver,设置超时时间
tempResolver := &dns.Client{
Net: s.resolver.Net,
UDPSize: s.resolver.UDPSize,
Timeout: defaultTimeout, // 使用配置的超时时间
}
resp, rtt, e := tempResolver.Exchange(r, normalizeDNSServerAddress(localServer))
resultChan <- struct {
response *dns.Msg
rtt time.Duration
err error
}{resp, rtt, e}
}()
var localResponse *dns.Msg
var rtt time.Duration
var err error
// 使用超时获取结果
select {
case result := <-resultChan:
localResponse, rtt, err = result.response, result.rtt, result.err
case <-time.After(defaultTimeout):
// 超时
logger.Debug("本地解析超时", "domain", domain, "server", localServer, "timeout", defaultTimeout)
// 超时后跳过本地解析
localResponse = nil
err = fmt.Errorf("timeout")
}
if err == nil && localResponse != nil {
// 更新服务器统计信息
s.updateServerStats(localServer, true, rtt)
// 检查是否包含DNSSEC记录
localHasDNSSEC := s.hasDNSSECRecords(localResponse)
// 验证DNSSEC记录(如果存在),但不影响最终响应
if localHasDNSSEC {
signatureValid := s.verifyDNSSEC(localResponse)
localResponse.AuthenticatedData = signatureValid
if signatureValid {
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
// 记录解析域名统计
s.updateResolvedDomainStats(domain)
// 更新域名的DNSSEC状态
s.updateDomainDNSSECStatus(domain, localHasDNSSEC)
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
logger.Debug("使用本地解析结果(upstreamDNS", "domain", domain, "server", localServer, "rtt", rtt)
return localResponse, rtt, localServer, ""
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(localServer, false, 0)
}
}
}
// 记录解析域名统计
s.updateResolvedDomainStats(domain)
// 更新域名的DNSSEC状态
if bestHasDNSSEC {
s.updateDomainDNSSECStatus(domain, true)
} else {
s.updateDomainDNSSECStatus(domain, false)
}
// 检查响应是否包含CNAME记录,需要确保返回完整的解析链
if bestResponse != nil && bestResponse.Rcode == dns.RcodeSuccess {
// 处理多级CNAME,直到获取到最终的A/AAAA记录
maxCNAMELevels := 5 // 限制最大CNAME解析级数,防止循环解析
currentLevel := 0
// 循环处理CNAME记录
for currentLevel < maxCNAMELevels {
// 检查是否包含CNAME记录
var hasCNAME bool
var cnameTarget string
// 检查Answer部分,查找CNAME记录
for _, rr := range bestResponse.Answer {
if cname, ok := rr.(*dns.CNAME); ok {
hasCNAME = true
cnameTarget = cname.Target
}
}
// 如果不包含CNAME记录,或者已经包含最终的A/AAAA记录,退出循环
var hasFinalRecord bool
for _, rr := range bestResponse.Answer {
switch rr.Header().Rrtype {
case dns.TypeA, dns.TypeAAAA:
hasFinalRecord = true
break
}
}
if !hasCNAME || hasFinalRecord {
break // 没有CNAME记录,或者已经有最终记录,退出循环
}
// 如果包含CNAME记录但没有最终IP,继续查询
logger.Debug("响应包含CNAME但没有最终IP,继续查询", "domain", domain, "cname", cnameTarget, "level", currentLevel)
// 创建新的查询请求,查询CNAME指向的域名
cnameQuery := r.Copy()
cnameQuery.Question[0].Name = cnameTarget
// 继续查询CNAME指向的域名
cnameResponse, _, cnameDnsServer, cnameDnssecServer := s.forwardDNSRequestWithCache(cnameQuery, cnameTarget)
if cnameResponse != nil && cnameResponse.Rcode == dns.RcodeSuccess {
// 合并CNAME响应的Answer部分到主响应
bestResponse.Answer = append(bestResponse.Answer, cnameResponse.Answer...)
// 合并CNAME响应的Ns部分到主响应
bestResponse.Ns = append(bestResponse.Ns, cnameResponse.Ns...)
// 合并CNAME响应的Extra部分到主响应,排除OPT记录
for _, rr := range cnameResponse.Extra {
if rr.Header().Rrtype != dns.TypeOPT {
bestResponse.Extra = append(bestResponse.Extra, rr)
}
}
// 更新使用的DNS服务器信息
if cnameDnsServer != "" {
usedDNSServer = cnameDnsServer
}
if cnameDnssecServer != "" {
usedDNSSECServer = cnameDnssecServer
}
} else {
// 查询失败,退出循环
break
}
// 增加CNAME解析级数
currentLevel++
}
if currentLevel >= maxCNAMELevels {
logger.Warn("CNAME解析级数超过限制,可能存在循环解析", "domain", domain, "maxLevels", maxCNAMELevels)
}
}
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer
}
// 如果有备选响应,返回该响应
if hasBackup {
logger.Debug("使用备选响应,没有找到更好的结果", "domain", domain)
// 记录解析域名统计
s.updateResolvedDomainStats(domain)
// 更新统计信息
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
return backupResponse, backupRtt, "", ""
}
// 所有上游服务器都失败,返回服务器失败错误
response := new(dns.Msg)
response.SetReply(r)
response.SetRcode(r, dns.RcodeServerFailure)
logger.Error("DNS查询失败", "domain", domain)
s.updateStats(func(stats *Stats) {
stats.Errors++
})
return response, 0, "", ""
}
// forwardDNSRequest 转发DNS请求到上游服务器
func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) {
response, _, _, _ := s.forwardDNSRequestWithCache(r, domain)
w.WriteMsg(response)
}
// updateBlockedDomainStats 更新被屏蔽域名统计
func (s *Server) updateBlockedDomainStats(domain string) {
// 先尝试读锁,检查条目是否存在
s.blockedDomainsMutex.RLock()
entry, exists := s.blockedDomains[domain]
s.blockedDomainsMutex.RUnlock()
if exists {
// 使用原子操作更新计数和时间戳
atomic.AddInt64(&entry.Count, 1)
atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano())
} else {
// 获取写锁,创建新条目
s.blockedDomainsMutex.Lock()
// 再次检查,避免竞态条件
if entry, exists := s.blockedDomains[domain]; exists {
atomic.AddInt64(&entry.Count, 1)
atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano())
} else {
s.blockedDomains[domain] = &BlockedDomain{
Domain: domain,
Count: 1,
LastSeen: time.Now().UnixNano(),
DNSSEC: false,
}
}
s.blockedDomainsMutex.Unlock()
}
// 更新统计数据
now := time.Now()
// 更新小时统计
hourKey := now.Format("2006-01-02-15")
s.hourlyStatsMutex.Lock()
s.hourlyStats[hourKey]++
s.hourlyStatsMutex.Unlock()
// 更新每日统计
dayKey := now.Format("2006-01-02")
s.dailyStatsMutex.Lock()
s.dailyStats[dayKey]++
s.dailyStatsMutex.Unlock()
// 更新每月统计
monthKey := now.Format("2006-01")
s.monthlyStatsMutex.Lock()
s.monthlyStats[monthKey]++
s.monthlyStatsMutex.Unlock()
}
// updateClientStats 更新客户端统计
func (s *Server) updateClientStats(ip string) {
// 先尝试读锁,检查条目是否存在
s.clientStatsMutex.RLock()
entry, exists := s.clientStats[ip]
s.clientStatsMutex.RUnlock()
if exists {
// 使用原子操作更新计数和时间戳
atomic.AddInt64(&entry.Count, 1)
atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano())
} else {
// 获取写锁,创建新条目
s.clientStatsMutex.Lock()
// 再次检查,避免竞态条件
if entry, exists := s.clientStats[ip]; exists {
atomic.AddInt64(&entry.Count, 1)
atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano())
} else {
s.clientStats[ip] = &ClientStats{
IP: ip,
Count: 1,
LastSeen: time.Now().UnixNano(),
}
}
s.clientStatsMutex.Unlock()
}
}
// hasDNSSECRecords 检查响应是否包含DNSSEC记录
func (s *Server) hasDNSSECRecords(response *dns.Msg) bool {
// 直接调用包内的hasDNSSECRecords函数,避免重复代码
return hasDNSSECRecords(response)
}
// verifyDNSSEC 验证DNSSEC签名
func (s *Server) verifyDNSSEC(response *dns.Msg) bool {
// 提取DNSKEY和RRSIG记录,并按类型和名称组织记录
dnskeys := make(map[uint16]*dns.DNSKEY) // KeyTag -> DNSKEY
rrsigs := make([]*dns.RRSIG, 0)
// 按 (名称, 类型) 组织记录集,用于快速查找
rrSets := make(map[string]map[uint16][]dns.RR) // name -> type -> records
// 定义处理单个记录的函数
processRecord := func(rr dns.RR) {
num := rr.Header().Rrtype
name := rr.Header().Name
// 组织记录集
if _, exists := rrSets[name]; !exists {
rrSets[name] = make(map[uint16][]dns.RR)
}
if _, exists := rrSets[name][num]; !exists {
rrSets[name][num] = make([]dns.RR, 0)
}
rrSets[name][num] = append(rrSets[name][num], rr)
// 特别处理DNSKEY和RRSIG
if dnskey, ok := rr.(*dns.DNSKEY); ok {
tag := dnskey.KeyTag()
dnskeys[tag] = dnskey
} else if rrsig, ok := rr.(*dns.RRSIG); ok {
rrsigs = append(rrsigs, rrsig)
}
}
// 一次遍历所有响应部分,同时完成记录收集和组织
for _, rr := range response.Answer {
processRecord(rr)
}
for _, rr := range response.Ns {
processRecord(rr)
}
for _, rr := range response.Extra {
processRecord(rr)
}
// 如果没有RRSIG记录,验证失败
if len(rrsigs) == 0 {
return false
}
// 验证所有RRSIG记录
signatureValid := true
// 用于记录已经警告过的DNSKEY tag,避免重复警告
warnedKeyTags := make(map[uint16]bool)
for _, rrsig := range rrsigs {
// 查找对应的DNSKEY
dnskey, exists := dnskeys[rrsig.KeyTag]
if !exists {
// 仅当该key_tag尚未警告过时才记录警告
if !warnedKeyTags[rrsig.KeyTag] {
logger.Warn("DNSSEC验证失败:找不到对应的DNSKEY", "key_tag", rrsig.KeyTag)
warnedKeyTags[rrsig.KeyTag] = true
}
signatureValid = false
continue
}
// 快速查找需要验证的记录集
name := rrsig.Header().Name
typeCovered := rrsig.TypeCovered
rrset := rrSets[name][typeCovered]
// 验证签名
if len(rrset) > 0 {
err := rrsig.Verify(dnskey, rrset)
if err != nil {
logger.Warn("DNSSEC签名验证失败", "error", err, "key_tag", rrsig.KeyTag)
signatureValid = false
} else {
logger.Debug("DNSSEC签名验证成功", "key_tag", rrsig.KeyTag)
}
}
}
return signatureValid
}
// updateDomainDNSSECStatus 更新域名的DNSSEC状态
func (s *Server) updateDomainDNSSECStatus(domain string, dnssec bool) {
// 确保域名是小写
domain = strings.ToLower(domain)
// 更新域名的DNSSEC状态
s.resolvedDomainsMutex.Lock()
defer s.resolvedDomainsMutex.Unlock()
// 更新resolvedDomains中的DNSSEC状态
if entry, exists := s.resolvedDomains[domain]; exists {
entry.DNSSEC = dnssec
} else {
s.resolvedDomains[domain] = &BlockedDomain{
Domain: domain,
Count: 1,
LastSeen: time.Now().UnixNano(),
DNSSEC: dnssec,
}
}
// 更新domainDNSSECStatus映射(使用单独的锁)
s.domainDNSSECStatusMutex.Lock()
s.domainDNSSECStatus[domain] = dnssec
s.domainDNSSECStatusMutex.Unlock()
}
// updateResolvedDomainStats 更新解析域名统计
func (s *Server) updateResolvedDomainStats(domain string) {
// 先尝试读锁,检查条目是否存在
s.resolvedDomainsMutex.RLock()
entry, exists := s.resolvedDomains[domain]
s.resolvedDomainsMutex.RUnlock()
if exists {
// 使用原子操作更新计数和时间戳
atomic.AddInt64(&entry.Count, 1)
atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano())
} else {
// 获取写锁,创建新条目
s.resolvedDomainsMutex.Lock()
// 再次检查,避免竞态条件
if entry, exists := s.resolvedDomains[domain]; exists {
atomic.AddInt64(&entry.Count, 1)
atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano())
} else {
s.resolvedDomains[domain] = &BlockedDomain{
Domain: domain,
Count: 1,
LastSeen: time.Now().UnixNano(),
DNSSEC: false,
}
}
s.resolvedDomainsMutex.Unlock()
}
}
// getServerStats 获取服务器统计信息,如果不存在则创建
func (s *Server) getServerStats(server string) *ServerStats {
s.serverStatsMutex.RLock()
stats, exists := s.serverStats[server]
s.serverStatsMutex.RUnlock()
if exists {
return stats
}
s.serverStatsMutex.Lock()
defer s.serverStatsMutex.Unlock()
if stats, exists := s.serverStats[server]; exists {
return stats
}
stats = &ServerStats{
SuccessCount: 0,
FailureCount: 0,
LastResponse: time.Now(),
ResponseTime: 0,
ConnectionSpeed: 0,
}
s.serverStats[server] = stats
return stats
}
// updateServerStats 更新服务器统计信息
func (s *Server) updateServerStats(server string, success bool, rtt time.Duration) {
stats := s.getServerStats(server)
// 使用原子操作更新成功和失败计数
if success {
successCount := atomic.AddInt64(&stats.SuccessCount, 1)
// 只在需要更新平均响应时间时获取锁
s.serverStatsMutex.Lock()
stats.LastResponse = time.Now()
// 更新平均响应时间(简单移动平均)
if successCount == 1 {
// 第一次成功,直接使用当前响应时间
stats.ResponseTime = rtt
} else {
// 使用纳秒进行计算以避免类型不匹配
prevTotal := stats.ResponseTime.Nanoseconds() * (successCount - 1)
newTotal := prevTotal + rtt.Nanoseconds()
stats.ResponseTime = time.Duration(newTotal / successCount)
}
s.serverStatsMutex.Unlock()
} else {
atomic.AddInt64(&stats.FailureCount, 1)
// 只更新LastResponse时获取锁
s.serverStatsMutex.Lock()
stats.LastResponse = time.Now()
s.serverStatsMutex.Unlock()
}
}
// selectWeightedRandomServer 加权随机选择服务器
func (s *Server) selectWeightedRandomServer(servers []string) string {
if len(servers) == 0 {
return ""
}
if len(servers) == 1 {
return servers[0]
}
type serverWeight struct {
server string
weight int64
responseTime time.Duration
successCount int64
failureCount int64
}
var totalWeight int64
var totalResponseTime time.Duration
var validServers int
var currentWeight int64
serversInfo := make([]serverWeight, len(servers))
for i, server := range servers {
stats := s.getServerStats(server)
serversInfo[i] = serverWeight{
server: server,
responseTime: stats.ResponseTime,
successCount: atomic.LoadInt64(&stats.SuccessCount),
failureCount: atomic.LoadInt64(&stats.FailureCount),
}
if stats.ResponseTime > 0 {
totalResponseTime += stats.ResponseTime
validServers++
}
}
var avgResponseTime time.Duration
if validServers > 0 {
avgResponseTime = totalResponseTime / time.Duration(validServers)
} else {
avgResponseTime = 1 * time.Second
}
var randomGen = rand.New(rand.NewSource(time.Now().UnixNano()))
for i := range serversInfo {
baseWeight := serversInfo[i].successCount - serversInfo[i].failureCount*2
if baseWeight < 1 {
baseWeight = 1
}
var responseFactor int64 = 100
if serversInfo[i].responseTime > 0 {
if serversInfo[i].responseTime < avgResponseTime {
factor := (avgResponseTime.Nanoseconds() * 200) / serversInfo[i].responseTime.Nanoseconds()
if factor > 200 {
factor = 200
}
responseFactor = factor
} else {
factor := (avgResponseTime.Nanoseconds() * 200) / serversInfo[i].responseTime.Nanoseconds()
if factor < 50 {
factor = 50
}
responseFactor = factor
}
}
finalWeight := (baseWeight * responseFactor) / 100
if finalWeight < 1 {
finalWeight = 1
}
serversInfo[i].weight = finalWeight
totalWeight += finalWeight
}
random := randomGen.Int63n(totalWeight)
for _, sw := range serversInfo {
currentWeight += sw.weight
if random < currentWeight {
return sw.server
}
}
// 兜底返回第一个服务器
return servers[0]
}
// measureServerSpeed 测量服务器TCP连接速度
func (s *Server) measureServerSpeed(server string) time.Duration {
addr := server
if !strings.Contains(server, ":") {
addr = server + ":53"
}
startTime := time.Now()
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
if err != nil {
return 2 * time.Second
}
defer conn.Close()
connTime := time.Since(startTime)
stats := s.getServerStats(server)
s.serverStatsMutex.Lock()
stats.ConnectionSpeed = (stats.ConnectionSpeed*3 + connTime) / 4
s.serverStatsMutex.Unlock()
return connTime
}
// selectFastestServer 选择连接速度最快的服务器
func (s *Server) selectFastestServer(servers []string) string {
if len(servers) == 0 {
return ""
}
if len(servers) == 1 {
return servers[0]
}
// 并行测量所有服务器的速度
type speedResult struct {
server string
speed time.Duration
}
results := make(chan speedResult, len(servers))
var wg sync.WaitGroup
for _, server := range servers {
wg.Add(1)
go func(srv string) {
defer wg.Done()
speed := s.measureServerSpeed(srv)
results <- speedResult{srv, speed}
}(server)
}
// 等待所有测量完成
go func() {
wg.Wait()
close(results)
}()
// 找出最快的服务器
var fastestServer string
var fastestSpeed time.Duration = 2 * time.Second
for result := range results {
if result.speed < fastestSpeed {
fastestSpeed = result.speed
fastestServer = result.server
}
}
// 如果没有找到最快服务器(理论上不会发生),返回第一个服务器
if fastestServer == "" {
fastestServer = servers[0]
}
return fastestServer
}
// calculateAvgResponseTime 计算平均响应时间
func calculateAvgResponseTime(totalResponseTime int64, queries int64) float64 {
if queries <= 0 {
return 0
}
avg := float64(totalResponseTime) / float64(queries)
avg = float64(math.Round(avg))
// 限制平均响应时间的范围
if avg > 60000 {
avg = 60000
}
return avg
}
// updateStats 更新统计信息
func (s *Server) updateStats(update func(*Stats)) {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
update(s.stats)
}
// addQueryLog 添加查询日志
func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string, fromCache, dnssec, edns bool, dnsServer, dnssecServer string, answers []DNSAnswer, responseCode int) {
// 创建日志记录
queryLog := QueryLog{
Timestamp: time.Now(),
ClientIP: clientIP,
Domain: domain,
QueryType: queryType,
ResponseTime: responseTime,
Result: result,
BlockRule: blockRule,
BlockType: blockType,
FromCache: fromCache,
DNSSEC: dnssec,
EDNS: edns,
DNSServer: dnsServer,
DNSSECServer: dnssecServer,
Answers: answers,
ResponseCode: responseCode,
}
// 使用新日志系统记录(如果已初始化)
if s.logManager != nil {
// 将 Answers 转换为 JSON 字符串
answersJSON := ""
if len(queryLog.Answers) > 0 {
data, err := json.Marshal(queryLog.Answers)
if err == nil {
answersJSON = string(data)
}
}
newLog := log.QueryLog{
Timestamp: queryLog.Timestamp,
ClientIP: queryLog.ClientIP,
Domain: queryLog.Domain,
QueryType: queryLog.QueryType,
ResponseTime: queryLog.ResponseTime,
Result: queryLog.Result,
BlockRule: queryLog.BlockRule,
BlockType: queryLog.BlockType,
FromCache: queryLog.FromCache,
DNSSEC: queryLog.DNSSEC,
EDNS: queryLog.EDNS,
DNSServer: queryLog.DNSServer,
DNSSECServer: queryLog.DNSSECServer,
Answers: answersJSON,
ResponseCode: queryLog.ResponseCode,
}
err := s.logManager.Log(newLog)
if err != nil {
logger.Error("新日志系统记录失败", "domain", queryLog.Domain, "error", err)
}
}
// 同时使用旧日志系统记录(兼容性)
// 发送到日志处理通道(阻塞式,确保日志不会丢失)
}
// GetStartTime 获取服务器启动时间
// GetStats 获取DNS服务器统计信息
func (s *Server) GetStats() *Stats {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
// 复制查询类型统计
queryTypesCopy := make(map[string]int64)
for k, v := range s.stats.QueryTypes {
queryTypesCopy[k] = v
}
// 复制来源IP统计
sourceIPsCopy := make(map[string]bool)
for ip := range s.stats.SourceIPs {
sourceIPsCopy[ip] = true
}
// 返回统计信息的副本
return &Stats{
Queries: s.stats.Queries,
Blocked: s.stats.Blocked,
Allowed: s.stats.Allowed,
Errors: s.stats.Errors,
LastQuery: s.stats.LastQuery,
AvgResponseTime: s.stats.AvgResponseTime,
TotalResponseTime: s.stats.TotalResponseTime,
QueryTypes: queryTypesCopy,
SourceIPs: sourceIPsCopy,
CpuUsage: s.stats.CpuUsage,
DNSSECQueries: s.stats.DNSSECQueries,
DNSSECSuccess: s.stats.DNSSECSuccess,
DNSSECFailed: s.stats.DNSSECFailed,
DNSSECEnabled: s.stats.DNSSECEnabled,
}
}
// GetQueryLogs 获取查询日志
func (s *Server) GetQueryLogs(limit, offset int, sortField, sortDirection, resultFilter, searchTerm, queryType string) []QueryLog {
// 优先使用新日志系统查询(如果已初始化)
if s.logManager != nil {
logger.Debug("使用新日志系统查询", "filter", resultFilter, "search", searchTerm, "queryType", queryType)
// 转换排序字段名称以匹配数据库字段
dbSortField := sortField
if sortField == "time" {
dbSortField = "timestamp"
} else if sortField == "clientIp" {
dbSortField = "client_ip"
} else if sortField == "responseTime" {
dbSortField = "response_time"
} else if sortField == "blockRule" {
dbSortField = "block_rule"
}
filter := log.LogFilter{
Result: resultFilter,
SearchTerm: searchTerm,
QueryType: queryType,
}
page := log.PageParams{
Limit: limit,
Offset: offset,
SortField: dbSortField,
SortDirection: sortDirection,
}
logs, total, err := s.logManager.QueryLogs(filter, page)
logger.Debug("新日志系统查询结果", "logs", len(logs), "total", total, "error", err)
if err == nil && len(logs) > 0 {
// 将新日志格式转换为旧格式
result := make([]QueryLog, 0, len(logs))
for _, newLog := range logs {
// 将 JSON 字符串解析为 DNSAnswer 数组
var answers []DNSAnswer
if newLog.Answers != "" {
json.Unmarshal([]byte(newLog.Answers), &answers)
}
oldLog := QueryLog{
Timestamp: newLog.Timestamp,
ClientIP: newLog.ClientIP,
Domain: newLog.Domain,
QueryType: newLog.QueryType,
ResponseTime: newLog.ResponseTime,
Result: newLog.Result,
BlockRule: newLog.BlockRule,
BlockType: newLog.BlockType,
FromCache: newLog.FromCache,
DNSSEC: newLog.DNSSEC,
EDNS: newLog.EDNS,
DNSServer: newLog.DNSServer,
DNSSECServer: newLog.DNSSECServer,
Answers: answers,
ResponseCode: newLog.ResponseCode,
}
result = append(result, oldLog)
}
return result
}
// 如果新系统查询失败或没有数据,返回空列表
logger.Debug("新日志系统查询失败或无数据", "error", err, "logs", len(logs))
}
// 返回空列表
return []QueryLog{}
}
// GetQueryLogsCount 获取查询日志总数
func (s *Server) GetQueryLogsCount() int {
// 使用新日志系统获取总数
if s.logManager != nil {
stats, err := s.logManager.GetStats(log.TimeRange{})
if err == nil {
return int(stats.TotalQueries)
}
}
return 0
}
// GetQueryLogsCountWithFilter 获取带过滤条件的查询日志总数
func (s *Server) GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType string) int {
// 优先使用新日志系统查询(如果已初始化)
if s.logManager != nil {
filter := log.LogFilter{
Result: resultFilter,
SearchTerm: searchTerm,
QueryType: queryType,
}
page := log.PageParams{
Limit: 1, // 只需要总数,获取 1 条数据即可
Offset: 0,
SortField: "timestamp",
SortDirection: "desc",
}
_, total, err := s.logManager.QueryLogs(filter, page)
if err == nil {
return int(total)
}
// 如果新系统查询失败,返回 0
logger.Debug("新日志系统查询失败", "error", err)
}
return 0
}
func (s *Server) GetQueryStats() map[string]interface{} {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
// 计算统计数据
return map[string]interface{}{
"totalQueries": s.stats.Queries,
"blockedQueries": s.stats.Blocked,
"allowedQueries": s.stats.Allowed,
"errorQueries": s.stats.Errors,
"avgResponseTime": s.stats.AvgResponseTime,
"activeIPs": len(s.stats.SourceIPs),
}
}
// GetTopBlockedDomains 获取TOP屏蔽域名列表
func (s *Server) GetTopBlockedDomains(limit int) []BlockedDomain {
s.blockedDomainsMutex.RLock()
defer s.blockedDomainsMutex.RUnlock()
// 计算30天前的时间戳
thirtyDaysAgo := time.Now().Add(-30 * 24 * time.Hour).Unix()
// 转换为切片并过滤最近30天的数据
domains := make([]BlockedDomain, 0, len(s.blockedDomains))
for _, entry := range s.blockedDomains {
// 只包含最近30天的数据
if entry.LastSeen >= thirtyDaysAgo {
domains = append(domains, *entry)
}
}
// 按计数排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].Count > domains[j].Count
})
// 返回限制数量
if len(domains) > limit {
return domains[:limit]
}
return domains
}
// GetTopResolvedDomains 获取TOP解析域名
func (s *Server) GetTopResolvedDomains(limit int) []BlockedDomain {
s.resolvedDomainsMutex.RLock()
defer s.resolvedDomainsMutex.RUnlock()
// 计算30天前的时间戳
thirtyDaysAgo := time.Now().Add(-30 * 24 * time.Hour).Unix()
// 转换为切片并过滤最近30天的数据
domains := make([]BlockedDomain, 0, len(s.resolvedDomains))
for _, entry := range s.resolvedDomains {
// 只包含最近30天的数据
if entry.LastSeen >= thirtyDaysAgo {
domains = append(domains, *entry)
}
}
// 按数量排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].Count > domains[j].Count
})
// 返回限制数量
if len(domains) > limit {
return domains[:limit]
}
return domains
}
// GetRecentBlockedDomains 获取最近屏蔽的域名列表
func (s *Server) GetRecentBlockedDomains(limit int) []BlockedDomain {
s.blockedDomainsMutex.RLock()
defer s.blockedDomainsMutex.RUnlock()
// 转换为切片
domains := make([]BlockedDomain, 0, len(s.blockedDomains))
for _, entry := range s.blockedDomains {
domains = append(domains, *entry)
}
// 按时间排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].LastSeen > domains[j].LastSeen
})
// 返回限制数量
if len(domains) > limit {
return domains[:limit]
}
return domains
}
// GetTopClients 获取TOP客户端列表
func (s *Server) GetTopClients(limit int) []ClientStats {
s.clientStatsMutex.RLock()
defer s.clientStatsMutex.RUnlock()
// 转换为切片
clients := make([]ClientStats, 0, len(s.clientStats))
for _, entry := range s.clientStats {
clients = append(clients, *entry)
}
// 按请求次数排序
sort.Slice(clients, func(i, j int) bool {
return clients[i].Count > clients[j].Count
})
// 返回限制数量
if len(clients) > limit {
return clients[:limit]
}
return clients
}
// GetHourlyStats 获取每小时统计数据
func (s *Server) GetHourlyStats() map[string]int64 {
s.hourlyStatsMutex.RLock()
defer s.hourlyStatsMutex.RUnlock()
// 返回副本
result := make(map[string]int64)
for k, v := range s.hourlyStats {
result[k] = v
}
return result
}
// GetDailyStats 获取每日统计数据
func (s *Server) GetDailyStats() map[string]int64 {
s.dailyStatsMutex.RLock()
defer s.dailyStatsMutex.RUnlock()
// 返回副本
result := make(map[string]int64)
for k, v := range s.dailyStats {
result[k] = v
}
return result
}
// GetMonthlyStats 获取每月统计数据
func (s *Server) GetMonthlyStats() map[string]int64 {
s.monthlyStatsMutex.RLock()
defer s.monthlyStatsMutex.RUnlock()
// 返回副本
result := make(map[string]int64)
for k, v := range s.monthlyStats {
result[k] = v
}
return result
}
// checkThreatDetection 检查DNS查询是否存在威胁
func (s *Server) checkThreatDetection(sourceIP, domain, queryType string) {
if s.threatEngine == nil {
return
}
// 调用威胁检测引擎检查查询
alerts := s.threatEngine.CheckQuery(sourceIP, domain, queryType)
// 处理检测到的威胁
for _, alert := range alerts {
// 添加告警到告警管理器
s.alertManager.AddAlert(alert)
}
}
// GetAlerts 获取告警列表
func (s *Server) GetAlerts(limit, offset int, level string) []*threat.ThreatAlert {
if s.alertManager == nil {
return []*threat.ThreatAlert{}
}
return s.alertManager.GetAlerts(limit, offset, level)
}
// GetAlertCount 获取告警数量
func (s *Server) GetAlertCount(level string) int {
if s.alertManager == nil {
return 0
}
return s.alertManager.GetAlertCount(level)
}
// ResolveAlert 解决告警
func (s *Server) ResolveAlert(alertID, action string) bool {
if s.alertManager == nil {
return false
}
return s.alertManager.ResolveAlert(alertID, action)
}
// GetThreatDomains 获取所有威胁域名信息
func (s *Server) GetThreatDomains() []*threat.ThreatInfo {
if s.dbManager == nil {
return []*threat.ThreatInfo{}
}
return s.dbManager.GetAllThreatDomains()
}
// AddThreatDomain 添加威胁域名
func (s *Server) AddThreatDomain(threatType, name string, riskLevel int, domain string) error {
if s.dbManager == nil {
return nil
}
return s.dbManager.AddThreatDomain(threatType, name, riskLevel, domain)
}
// RemoveThreatDomain 删除威胁域名
func (s *Server) RemoveThreatDomain(domain string) error {
if s.dbManager == nil {
return nil
}
return s.dbManager.RemoveThreatDomain(domain)
}
// isPrivateIP 检测IP地址是否为内网IP
func isPrivateIP(ip string) bool {
// 解析IP地址
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return false
}
// 检查IPv4内网地址
if ipv4 := parsedIP.To4(); ipv4 != nil {
// 10.0.0.0/8
if ipv4[0] == 10 {
return true
}
// 172.16.0.0/12
if ipv4[0] == 172 && (ipv4[1] >= 16 && ipv4[1] <= 31) {
return true
}
// 192.168.0.0/16
if ipv4[0] == 192 && ipv4[1] == 168 {
return true
}
// 127.0.0.0/8 (localhost)
if ipv4[0] == 127 {
return true
}
// 169.254.0.0/16 (链路本地地址)
if ipv4[0] == 169 && ipv4[1] == 254 {
return true
}
return false
}
// 检查IPv6内网地址
// ::1/128 (localhost)
if parsedIP.IsLoopback() {
return true
}
// fc00::/7 (唯一本地地址)
if parsedIP[0]&0xfc == 0xfc {
return true
}
// fe80::/10 (链路本地地址)
if parsedIP[0]&0xfe == 0xfe && parsedIP[1]&0xc0 == 0x80 {
return true
}
return false
}
// loadStatsData 从文件加载统计数据
func (s *Server) loadStatsData() {
// 检查文件是否存在
data, err := ioutil.ReadFile("data/stats.json")
if err != nil {
if !os.IsNotExist(err) {
logger.Error("读取统计数据文件失败", "error", err)
}
return
}
var statsData StatsData
err = json.Unmarshal(data, &statsData)
if err != nil {
logger.Error("解析统计数据失败", "error", err)
return
}
// 恢复统计数据
s.statsMutex.Lock()
if statsData.Stats != nil {
// 只恢复有效数据,避免破坏统计关系
s.stats.Queries += statsData.Stats.Queries
s.stats.Blocked += statsData.Stats.Blocked
s.stats.Allowed += statsData.Stats.Allowed
s.stats.Errors += statsData.Stats.Errors
s.stats.TotalResponseTime += statsData.Stats.TotalResponseTime
s.stats.DNSSECQueries += statsData.Stats.DNSSECQueries
s.stats.DNSSECSuccess += statsData.Stats.DNSSECSuccess
s.stats.DNSSECFailed += statsData.Stats.DNSSECFailed
// 重新计算平均响应时间,确保一致性
if s.stats.Queries > 0 {
s.stats.AvgResponseTime = float64(s.stats.TotalResponseTime) / float64(s.stats.Queries)
// 限制平均响应时间的范围,避免显示异常大的值
if s.stats.AvgResponseTime > 60000 {
s.stats.AvgResponseTime = 60000
}
}
// 合并查询类型统计
for k, v := range statsData.Stats.QueryTypes {
s.stats.QueryTypes[k] += v
}
// 合并来源IP统计
for ip := range statsData.Stats.SourceIPs {
s.stats.SourceIPs[ip] = true
}
// 确保使用当前配置中的EnableDNSSEC值
s.stats.DNSSECEnabled = s.config.EnableDNSSEC
}
s.statsMutex.Unlock()
s.blockedDomainsMutex.Lock()
if statsData.BlockedDomains != nil {
s.blockedDomains = statsData.BlockedDomains
}
s.blockedDomainsMutex.Unlock()
s.resolvedDomainsMutex.Lock()
if statsData.ResolvedDomains != nil {
s.resolvedDomains = statsData.ResolvedDomains
}
s.resolvedDomainsMutex.Unlock()
s.hourlyStatsMutex.Lock()
if statsData.HourlyStats != nil {
s.hourlyStats = statsData.HourlyStats
}
s.hourlyStatsMutex.Unlock()
s.dailyStatsMutex.Lock()
if statsData.DailyStats != nil {
s.dailyStats = statsData.DailyStats
}
s.dailyStatsMutex.Unlock()
s.monthlyStatsMutex.Lock()
if statsData.MonthlyStats != nil {
s.monthlyStats = statsData.MonthlyStats
}
s.monthlyStatsMutex.Unlock()
// 加载客户端统计数据
s.clientStatsMutex.Lock()
if statsData.ClientStats != nil {
s.clientStats = statsData.ClientStats
}
s.clientStatsMutex.Unlock()
logger.Info("统计数据加载成功")
}
// processLogs 异步处理日志记录
// saveStatsData 保存统计数据到文件
func (s *Server) saveStatsData() {
// 获取绝对路径以避免工作目录问题
statsFilePath, err := filepath.Abs("data/stats.json")
if err != nil {
logger.Error("获取统计文件绝对路径失败", "path", "data/stats.json", "error", err)
return
}
// 创建数据目录
statsDir := filepath.Dir(statsFilePath)
err = os.MkdirAll(statsDir, 0755)
if err != nil {
logger.Error("创建统计数据目录失败", "dir", statsDir, "error", err)
return
}
// 收集所有统计数据
statsData := &StatsData{
Stats: s.GetStats(),
LastSaved: time.Now(),
}
// 复制域名数据
s.blockedDomainsMutex.RLock()
statsData.BlockedDomains = make(map[string]*BlockedDomain)
for k, v := range s.blockedDomains {
statsData.BlockedDomains[k] = v
}
s.blockedDomainsMutex.RUnlock()
s.resolvedDomainsMutex.RLock()
statsData.ResolvedDomains = make(map[string]*BlockedDomain)
for k, v := range s.resolvedDomains {
statsData.ResolvedDomains[k] = v
}
s.resolvedDomainsMutex.RUnlock()
s.hourlyStatsMutex.RLock()
statsData.HourlyStats = make(map[string]int64)
for k, v := range s.hourlyStats {
statsData.HourlyStats[k] = v
}
s.hourlyStatsMutex.RUnlock()
s.dailyStatsMutex.RLock()
statsData.DailyStats = make(map[string]int64)
for k, v := range s.dailyStats {
statsData.DailyStats[k] = v
}
s.dailyStatsMutex.RUnlock()
s.monthlyStatsMutex.RLock()
statsData.MonthlyStats = make(map[string]int64)
for k, v := range s.monthlyStats {
statsData.MonthlyStats[k] = v
}
s.monthlyStatsMutex.RUnlock()
// 复制客户端统计数据
s.clientStatsMutex.RLock()
statsData.ClientStats = make(map[string]*ClientStats)
for k, v := range s.clientStats {
statsData.ClientStats[k] = v
}
s.clientStatsMutex.RUnlock()
// 序列化数据
jsonData, err := json.MarshalIndent(statsData, "", " ")
if err != nil {
logger.Error("序列化统计数据失败", "error", err)
return
}
// 写入文件
err = os.WriteFile(statsFilePath, jsonData, 0644)
if err != nil {
logger.Error("保存统计数据到文件失败", "file", statsFilePath, "error", err)
return
}
logger.Info("统计数据保存成功", "file", statsFilePath)
}
// startCpuUsageMonitor 启动 CPU 使用率监控
func (s *Server) startCpuUsageMonitor() {
ticker := time.NewTicker(time.Second * 5) // 每5秒更新一次CPU使用率
defer ticker.Stop()
// 初始化
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
// 存储上一次的CPU时间统计
var prevIdle, prevTotal uint64
for {
select {
case <-ticker.C:
// 获取真实的系统级CPU使用率
cpuUsage, err := getSystemCpuUsage(&prevIdle, &prevTotal)
if err != nil {
// 如果获取失败,使用默认值
cpuUsage = 0.0
logger.Error("获取系统CPU使用率失败", "error", err)
}
s.updateStats(func(stats *Stats) {
stats.CpuUsage = cpuUsage
})
case <-s.ctx.Done():
return
}
}
}
// getSystemCpuUsage 获取系统CPU使用率
func getSystemCpuUsage(prevIdle, prevTotal *uint64) (float64, error) {
// 读取/proc/stat文件获取CPU统计信息
file, err := os.Open("/proc/stat")
if err != nil {
return 0, err
}
defer file.Close()
var cpuUser, cpuNice, cpuSystem, cpuIdle, cpuIowait, cpuIrq, cpuSoftirq, cpuSteal uint64
_, err = fmt.Fscanf(file, "cpu %d %d %d %d %d %d %d %d",
&cpuUser, &cpuNice, &cpuSystem, &cpuIdle, &cpuIowait, &cpuIrq, &cpuSoftirq, &cpuSteal)
if err != nil {
return 0, err
}
// 计算总的CPU时间
total := cpuUser + cpuNice + cpuSystem + cpuIdle + cpuIowait + cpuIrq + cpuSoftirq + cpuSteal
idle := cpuIdle + cpuIowait
// 第一次调用时,只初始化值,不计算使用率
if *prevTotal == 0 || *prevIdle == 0 {
*prevIdle = idle
*prevTotal = total
return 0, nil
}
// 计算CPU使用率
idleDelta := idle - *prevIdle
totalDelta := total - *prevTotal
utilization := float64(totalDelta-idleDelta) / float64(totalDelta) * 100
// 更新上一次的值
*prevIdle = idle
*prevTotal = total
return utilization, nil
}
// startAutoSave 启动自动保存功能
func (s *Server) startAutoSave() {
if s.config.SaveInterval <= 0 {
return
}
// 初始化定时器
}
// GetArchiveQueryEngine 获取归档查询引擎
func (s *Server) GetArchiveQueryEngine() *log.ArchiveQueryEngine {
return s.archiveQueryEngine
}
// GetArchiveManager 获取归档管理器
func (s *Server) GetArchiveManager() *log.ArchiveManager {
return s.archiveManager
}