Files
dns-server/dns/server.go
Alex Yang 073f1961b1 更新web
2026-01-21 09:46:49 +08:00

3187 lines
91 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package dns
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"math"
"math/rand"
"net"
"os"
"path/filepath"
"runtime"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"dns-server/config"
"dns-server/gfw"
"dns-server/logger"
"dns-server/shield"
"github.com/miekg/dns"
)
// 确保DNS服务器地址包含端口号默认添加53端口
func normalizeDNSServerAddress(address string) string {
// 检查地址是否已经包含端口号
if _, _, err := net.SplitHostPort(address); err != nil {
// 如果没有端口号添加默认的53端口
return net.JoinHostPort(address, "53")
}
// 已经有端口号,直接返回
return address
}
// BlockedDomain 屏蔽域名统计
type BlockedDomain struct {
Domain string
Count int64
LastSeen int64
DNSSEC bool // 是否使用了DNSSEC
}
// ClientStats 客户端统计
type ClientStats struct {
IP string
Count int64
LastSeen int64
}
// DNSAnswer DNS解析记录
type DNSAnswer struct {
Type string `json:"type"` // 记录类型
Value string `json:"value"` // 记录值
TTL uint32 `json:"ttl"` // 生存时间
}
// QueryLog 查询日志记录
type QueryLog struct {
Timestamp time.Time `json:"timestamp"` // 查询时间
ClientIP string `json:"clientIP"` // 客户端IP
Location string `json:"location"` // IP地理位置国家 城市)
Domain string `json:"domain"` // 查询域名
QueryType string `json:"queryType"` // 查询类型
ResponseTime int64 `json:"responseTime"` // 响应时间(ms)
Result string `json:"result"` // 查询结果allowed, blocked, error
BlockRule string `json:"blockRule"` // 屏蔽规则(如果被屏蔽)
BlockType string `json:"blockType"` // 屏蔽类型(如果被屏蔽)
FromCache bool `json:"fromCache"` // 是否来自缓存
DNSSEC bool `json:"dnssec"` // 是否使用了DNSSEC
EDNS bool `json:"edns"` // 是否使用了EDNS
DNSServer string `json:"dnsServer"` // 使用的DNS服务器
DNSSECServer string `json:"dnssecServer"` // 使用的DNSSEC专用服务器
Answers []DNSAnswer `json:"answers"` // 解析记录
ResponseCode int `json:"responseCode"` // DNS响应代码
}
// StatsData 用于持久化的统计数据结构
type StatsData struct {
Stats *Stats `json:"stats"`
BlockedDomains map[string]*BlockedDomain `json:"blockedDomains"`
ResolvedDomains map[string]*BlockedDomain `json:"resolvedDomains"`
ClientStats map[string]*ClientStats `json:"clientStats"`
HourlyStats map[string]int64 `json:"hourlyStats"`
DailyStats map[string]int64 `json:"dailyStats"`
MonthlyStats map[string]int64 `json:"monthlyStats"`
LastSaved time.Time `json:"lastSaved"`
}
// ServerStats 服务器统计信息
type ServerStats struct {
SuccessCount int64 // 成功查询次数
FailureCount int64 // 失败查询次数
LastResponse time.Time // 最后响应时间
ResponseTime time.Duration // 平均响应时间
ConnectionSpeed time.Duration // TCP连接速度
}
// Server DNS服务器
type Server struct {
config *config.DNSConfig
shieldConfig *config.ShieldConfig
shieldManager *shield.ShieldManager
gfwConfig *config.GFWListConfig
gfwManager *gfw.GFWListManager
server *dns.Server
tcpServer *dns.Server
resolver *dns.Client
ctx context.Context
cancel context.CancelFunc
statsMutex sync.Mutex
stats *Stats
blockedDomainsMutex sync.RWMutex
blockedDomains map[string]*BlockedDomain
resolvedDomainsMutex sync.RWMutex
resolvedDomains map[string]*BlockedDomain // 用于记录解析的域名
clientStatsMutex sync.RWMutex
clientStats map[string]*ClientStats // 用于记录客户端统计
hourlyStatsMutex sync.RWMutex
hourlyStats map[string]int64 // 按小时统计屏蔽数量
dailyStatsMutex sync.RWMutex
dailyStats map[string]int64 // 按天统计屏蔽数量
monthlyStatsMutex sync.RWMutex
monthlyStats map[string]int64 // 按月统计屏蔽数量
queryLogsMutex sync.RWMutex
queryLogs []QueryLog // 查询日志列表
maxQueryLogs int // 最大保存日志数量
logChannel chan QueryLog // 日志处理通道
saveTicker *time.Ticker // 用于定时保存数据
startTime time.Time // 服务器启动时间
saveDone chan struct{} // 用于通知保存协程停止
stopped bool // 服务器是否已经停止
stoppedMutex sync.Mutex // 保护stopped标志的互斥锁
// DNS查询缓存
DnsCache *DNSCache // DNS响应缓存
// 域名DNSSEC状态映射表
domainDNSSECStatus map[string]bool // 域名到DNSSEC状态的映射
domainDNSSECStatusMutex sync.RWMutex // 保护域名DNSSEC状态映射的互斥锁
// 上游服务器状态跟踪
serverStats map[string]*ServerStats // 服务器地址到状态的映射
serverStatsMutex sync.RWMutex // 保护服务器状态的互斥锁
// DNSSEC专用服务器映射用于快速查找
dnssecServerMap map[string]bool // DNSSEC专用服务器地址到布尔值的映射
// DNS客户端实例池用于并行查询
clientPool sync.Pool // 存储*dns.Client实例
}
// Stats DNS服务器统计信息
type Stats struct {
Queries int64
Blocked int64
Allowed int64
Errors int64
LastQuery time.Time
AvgResponseTime float64 // 平均响应时间(ms)
TotalResponseTime int64 // 总响应时间
QueryTypes map[string]int64 // 查询类型统计
SourceIPs map[string]bool // 活跃来源IP
CpuUsage float64 // CPU使用率(%)
DNSSECQueries int64 // DNSSEC查询总数
DNSSECSuccess int64 // DNSSEC验证成功数
DNSSECFailed int64 // DNSSEC验证失败数
DNSSECEnabled bool // 是否启用了DNSSEC
}
// NewServer 创建DNS服务器实例
func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager, gfwConfig *config.GFWListConfig, gfwManager *gfw.GFWListManager) *Server {
ctx, cancel := context.WithCancel(context.Background())
// 从配置中读取DNS缓存TTL值分钟
cacheTTL := time.Duration(config.CacheTTL) * time.Minute
// 保存间隔(秒)
saveInterval := time.Duration(config.SaveInterval) * time.Second
// 最大和最小缓存TTL分钟
maxCacheTTL := time.Duration(config.MaxCacheTTL) * time.Minute
minCacheTTL := time.Duration(config.MinCacheTTL) * time.Minute
server := &Server{
config: config,
shieldConfig: shieldConfig,
shieldManager: shieldManager,
gfwConfig: gfwConfig,
gfwManager: gfwManager,
resolver: &dns.Client{
Net: "udp",
UDPSize: 4096, // 增加UDP缓冲区大小支持更大的DNSSEC响应
},
ctx: ctx,
cancel: cancel,
startTime: time.Now(), // 记录服务器启动时间
stats: &Stats{
Queries: 0,
Blocked: 0,
Allowed: 0,
Errors: 0,
AvgResponseTime: 0,
TotalResponseTime: 0,
QueryTypes: make(map[string]int64),
SourceIPs: make(map[string]bool),
CpuUsage: 0,
DNSSECQueries: 0,
DNSSECSuccess: 0,
DNSSECFailed: 0,
DNSSECEnabled: config.EnableDNSSEC,
},
blockedDomains: make(map[string]*BlockedDomain),
resolvedDomains: make(map[string]*BlockedDomain),
clientStats: make(map[string]*ClientStats),
hourlyStats: make(map[string]int64),
dailyStats: make(map[string]int64),
monthlyStats: make(map[string]int64),
queryLogs: make([]QueryLog, 0, 1000), // 初始化查询日志切片容量1000
maxQueryLogs: 10000, // 最大保存10000条日志
logChannel: make(chan QueryLog, 1000), // 日志处理通道缓冲区大小1000
saveDone: make(chan struct{}),
stopped: false, // 初始化为未停止状态
// DNS查询缓存初始化
DnsCache: NewDNSCache(cacheTTL, config.CacheMode, config.CacheSize, config.CacheFilePath, saveInterval, maxCacheTTL, minCacheTTL),
// 初始化域名DNSSEC状态映射表
domainDNSSECStatus: make(map[string]bool),
// 初始化服务器状态跟踪
serverStats: make(map[string]*ServerStats),
// 初始化DNSSEC专用服务器映射
dnssecServerMap: make(map[string]bool),
// 初始化DNS客户端实例池
clientPool: sync.Pool{
New: func() interface{} {
return &dns.Client{
Net: "udp",
UDPSize: 4096,
Timeout: 5 * time.Second, // 默认超时时间,会在使用时覆盖
}
},
},
}
// 加载已保存的统计数据
server.loadStatsData()
return server
}
// Start 启动DNS服务器
func (s *Server) Start() error {
// 重新初始化上下文和取消函数
ctx, cancel := context.WithCancel(context.Background())
s.ctx = ctx
s.cancel = cancel
// 重新初始化saveDone通道
s.saveDone = make(chan struct{})
// 重置stopped标志
s.stoppedMutex.Lock()
s.stopped = false
s.stoppedMutex.Unlock()
// 更新服务器启动时间
s.startTime = time.Now()
s.server = &dns.Server{
Addr: fmt.Sprintf("0.0.0.0:%d", s.config.Port),
Net: "udp",
Handler: dns.HandlerFunc(s.handleDNSRequest),
}
// 保存TCP服务器实例以便在Stop方法中关闭
s.tcpServer = &dns.Server{
Addr: fmt.Sprintf("0.0.0.0:%d", s.config.Port),
Net: "tcp",
Handler: dns.HandlerFunc(s.handleDNSRequest),
}
// 启动CPU使用率监控
go s.startCpuUsageMonitor()
// 启动自动保存功能
go s.startAutoSave()
// 更新DNSSEC专用服务器映射
s.updateDNSSECServerMap()
// 启动日志处理协程
go s.processLogs()
// 启动统计数据定期重置功能每24小时
go func() {
ticker := time.NewTicker(24 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.resetStats()
case <-s.ctx.Done():
return
}
}
}()
// 启动UDP服务
go func() {
logger.Info(fmt.Sprintf("DNS UDP服务器启动监听端口: %d", s.config.Port))
if err := s.server.ListenAndServe(); err != nil {
logger.Error("DNS UDP服务器启动失败", "error", err)
s.cancel()
}
}()
// 启动TCP服务
go func() {
logger.Info(fmt.Sprintf("DNS TCP服务器启动监听端口: %d", s.config.Port))
if err := s.tcpServer.ListenAndServe(); err != nil {
logger.Error("DNS TCP服务器启动失败", "error", err)
s.cancel()
}
}()
// 等待停止信号
<-s.ctx.Done()
return nil
}
// resetStats 重置统计数据
func (s *Server) resetStats() {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
// 只重置累计值,保留配置相关值
s.stats.TotalResponseTime = 0
s.stats.AvgResponseTime = 0
s.stats.Queries = 0
s.stats.Blocked = 0
s.stats.Allowed = 0
s.stats.Errors = 0
s.stats.DNSSECQueries = 0
s.stats.DNSSECSuccess = 0
s.stats.DNSSECFailed = 0
s.stats.QueryTypes = make(map[string]int64)
s.stats.SourceIPs = make(map[string]bool)
logger.Info("统计数据已重置")
}
// Stop 停止DNS服务器
func (s *Server) Stop() {
// 检查服务器是否已经停止
s.stoppedMutex.Lock()
if s.stopped {
s.stoppedMutex.Unlock()
return // 服务器已经停止,直接返回
}
// 标记服务器为已停止状态
s.stopped = true
s.stoppedMutex.Unlock()
// 发送停止信号给保存协程
close(s.saveDone)
// 最后保存一次数据
s.saveStatsData()
// 停止服务器
s.cancel()
if s.server != nil {
s.server.Shutdown()
}
if s.tcpServer != nil {
s.tcpServer.Shutdown()
}
logger.Info("DNS服务器已停止")
}
// handleDNSRequest 处理DNS请求
func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
startTime := time.Now()
// 1. 初始化请求信息
reqInfo := s.initRequestInfo(w, r)
// 2. 检查基本请求条件
if earlyResponse := s.checkRequestConditions(w, r, startTime, reqInfo); earlyResponse {
return
}
// 3. 检查本地处理规则
if localHandled := s.handleLocalRules(w, r, startTime, reqInfo); localHandled {
return
}
// 4. 尝试从缓存获取响应
if cacheHandled := s.handleCacheResponse(w, r, startTime, reqInfo); cacheHandled {
return
}
// 5. 转发请求到上游服务器
s.handleUpstreamRequest(w, r, startTime, reqInfo)
}
// requestInfo 封装请求相关信息
type requestInfo struct {
sourceIP string
domain string
queryType string
qType uint16
queryAttempts []string
}
// initRequestInfo 初始化请求信息
func (s *Server) initRequestInfo(w dns.ResponseWriter, r *dns.Msg) *requestInfo {
// 获取来源IP
sourceIP := w.RemoteAddr().String()
// 提取IP地址部分去掉端口
if strings.HasPrefix(sourceIP, "[") {
// IPv6地址格式: [::1]:53
if idx := strings.Index(sourceIP, "]"); idx >= 0 {
sourceIP = sourceIP[1:idx] // 去掉方括号
}
} else {
// IPv4地址格式: 127.0.0.1:53
if idx := strings.LastIndex(sourceIP, ":"); idx >= 0 {
sourceIP = sourceIP[:idx]
}
}
// 更新来源IP统计和Queries计数器
s.updateStats(func(stats *Stats) {
stats.Queries++
stats.LastQuery = time.Now()
stats.SourceIPs[sourceIP] = true
})
// 更新客户端统计
s.updateClientStats(sourceIP)
// 获取查询域名和类型
var domain string
var queryType string
var qType uint16
if len(r.Question) > 0 {
domain = r.Question[0].Name
// 移除末尾的点
if len(domain) > 0 && domain[len(domain)-1] == '.' {
domain = domain[:len(domain)-1]
}
// 获取查询类型
queryType = dns.TypeToString[r.Question[0].Qtype]
qType = r.Question[0].Qtype
// 更新查询类型统计
s.updateStats(func(stats *Stats) {
stats.QueryTypes[queryType]++
})
}
logger.Debug("接收到DNS查询", "domain", domain, "type", queryType, "client", w.RemoteAddr())
return &requestInfo{
sourceIP: sourceIP,
domain: domain,
queryType: queryType,
qType: qType,
queryAttempts: []string{domain},
}
}
// checkRequestConditions 检查请求条件,返回是否需要提前响应
func (s *Server) checkRequestConditions(w dns.ResponseWriter, r *dns.Msg, startTime time.Time, reqInfo *requestInfo) bool {
// 检查是否是AAAA记录查询且IPv6解析已禁用
if reqInfo.qType == dns.TypeAAAA && !s.config.EnableIPv6 {
// 返回空的成功响应而不是NXDOMAIN
response := new(dns.Msg)
response.SetReply(r)
response.SetRcode(r, dns.RcodeSuccess)
w.WriteMsg(response)
// 更新统计信息 - 视为正常解析
responseTime := time.Since(startTime).Milliseconds()
s.updateStats(func(stats *Stats) {
stats.Allowed++
stats.TotalResponseTime += responseTime
stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries)
})
// 添加查询日志
s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, responseTime, "allowed", "", "", false, false, true, "", "", nil, dns.RcodeSuccess)
logger.Debug("IPv6解析已禁用返回空的成功响应", "domain", reqInfo.domain)
return true
}
// 只处理递归查询
if r.RecursionDesired == false {
response := new(dns.Msg)
response.SetReply(r)
// 不再硬编码RecursionAvailable使用默认值或上游返回的值
response.SetRcode(r, dns.RcodeRefused)
w.WriteMsg(response)
// 计算实际响应时间
responseTime := time.Since(startTime).Milliseconds()
// 更新统计信息 - 视为错误
s.updateStats(func(stats *Stats) {
stats.Errors++
stats.TotalResponseTime += responseTime
stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries)
})
// 添加查询日志
s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, responseTime, "error", "", "", false, false, true, "", "", nil, dns.RcodeRefused)
return true
}
return false
}
// handleLocalRules 处理本地规则hosts文件、GFWList、屏蔽规则返回是否已处理
func (s *Server) handleLocalRules(w dns.ResponseWriter, r *dns.Msg, startTime time.Time, reqInfo *requestInfo) bool {
// 本地规则匹配的响应时间极短使用固定值1ms
const localResponseTime int64 = 1
// 检查hosts文件是否有匹配
if ip, exists := s.shieldManager.GetHostsIP(reqInfo.domain); exists {
s.handleHostsResponse(w, r, ip)
// 使用固定的短响应时间
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += localResponseTime
stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries)
})
return true
}
// 检查是否为GFWList域名仅当GFWList功能启用时
if s.gfwConfig.Enabled && s.gfwManager != nil && s.gfwManager.IsMatch(reqInfo.domain) {
s.handleGFWListResponse(w, r, reqInfo.domain)
// 使用固定的短响应时间
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += localResponseTime
stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries)
})
// 添加查询日志 - GFWList域名
gfwAnswers := []DNSAnswer{}
s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, localResponseTime, "gfwlist", "", "", false, false, true, "GFWList", "无", gfwAnswers, dns.RcodeSuccess)
return true
}
// 检查是否被屏蔽
if s.shieldManager.IsBlocked(reqInfo.domain) {
// 获取屏蔽详情
blockDetails := s.shieldManager.CheckDomainBlockDetails(reqInfo.domain)
blockRule, _ := blockDetails["blockRule"].(string)
blockType, _ := blockDetails["blockRuleType"].(string)
s.handleBlockedResponse(w, r, reqInfo.domain)
// 使用固定的短响应时间
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += localResponseTime
stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries)
})
// 添加查询日志 - 被屏蔽域名
blockedAnswers := []DNSAnswer{}
// 根据屏蔽方法确定响应代码
blockedRcode := dns.RcodeNameError // 默认NXDOMAIN
if blockMethod := s.shieldConfig.BlockMethod; blockMethod == "refused" {
blockedRcode = dns.RcodeRefused
} else if blockMethod == "emptyIP" || blockMethod == "customIP" {
blockedRcode = dns.RcodeSuccess
}
s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, localResponseTime, "blocked", blockRule, blockType, false, false, true, "无", "无", blockedAnswers, blockedRcode)
return true
}
return false
}
// handleCacheResponse 尝试从缓存获取响应,返回是否已处理
func (s *Server) handleCacheResponse(w dns.ResponseWriter, r *dns.Msg, startTime time.Time, reqInfo *requestInfo) bool {
// 检查缓存中是否有响应优先查找带DNSSEC的缓存项
var cachedResponse *dns.Msg
var found bool
var cachedDNSSEC bool
// 1. 首先检查是否有普通缓存项
if tempResponse, tempFound := s.DnsCache.Get(r.Question[0].Name, reqInfo.qType); tempFound {
cachedResponse = tempResponse
found = tempFound
cachedDNSSEC = s.hasDNSSECRecords(tempResponse)
}
// 2. 如果启用了DNSSEC且没有找到带DNSSEC的缓存项
// 尝试从所有缓存中查找是否有其他响应包含DNSSEC记录
// 这里可以进一步优化比如在缓存中标记DNSSEC状态快速查找
if s.config.EnableDNSSEC && !cachedDNSSEC {
// 目前的缓存实现不支持按DNSSEC状态查找所以这里暂时跳过
// 后续可以考虑改进缓存实现添加DNSSEC状态标记
}
if !found {
return false
}
// 缓存命中,直接返回缓存的响应
cachedResponseCopy := cachedResponse.Copy() // 创建响应副本避免并发修改问题
cachedResponseCopy.Id = r.Id // 更新ID以匹配请求
cachedResponseCopy.Compress = true
// 如果客户端请求包含EDNS记录确保响应也包含EDNS
if opt := r.IsEdns0(); opt != nil {
// 检查响应是否已经包含EDNS记录
if respOpt := cachedResponseCopy.IsEdns0(); respOpt == nil {
// 添加EDNS记录使用客户端的UDP缓冲区大小
cachedResponseCopy.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
} else {
// 确保响应的UDP缓冲区大小不超过客户端请求的大小
if respOpt.UDPSize() > opt.UDPSize() {
// 移除现有的EDNS记录
for i := range cachedResponseCopy.Extra {
if cachedResponseCopy.Extra[i] == respOpt {
cachedResponseCopy.Extra = append(cachedResponseCopy.Extra[:i], cachedResponseCopy.Extra[i+1:]...)
break
}
}
// 添加新的EDNS记录使用客户端的UDP缓冲区大小
cachedResponseCopy.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
}
}
}
// 确保响应的Question部分与客户端请求的Question部分匹配
cachedResponseCopy.Question = r.Question
// 修复如果响应包含记录确保Rcode为成功
hasValidRecords := false
// 检查Answer部分
if len(cachedResponseCopy.Answer) > 0 {
hasValidRecords = true
} else if len(cachedResponseCopy.Ns) > 0 {
// 检查Ns部分
hasValidRecords = true
} else if len(cachedResponseCopy.Extra) > 0 {
// 检查Extra部分排除OPT记录
for _, rr := range cachedResponseCopy.Extra {
if rr.Header().Rrtype != dns.TypeOPT {
hasValidRecords = true
break
}
}
}
if hasValidRecords {
cachedResponseCopy.Rcode = dns.RcodeSuccess
}
w.WriteMsg(cachedResponseCopy)
// 缓存命中的响应时间应该是极短的使用固定值1ms而非实际处理时间
const cacheResponseTime int64 = 1
// 缓存命中的响应视为正常解析
s.updateStats(func(stats *Stats) {
stats.Allowed++
stats.TotalResponseTime += cacheResponseTime
stats.AvgResponseTime = calculateAvgResponseTime(stats.TotalResponseTime, stats.Queries)
})
// 如果缓存响应包含DNSSEC记录更新DNSSEC查询计数
if cachedDNSSEC {
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
// 缓存响应视为DNSSEC成功
stats.DNSSECSuccess++
})
}
// 从缓存响应中提取解析记录
cachedAnswers := []DNSAnswer{}
if cachedResponse != nil {
for _, rr := range cachedResponse.Answer {
cachedAnswers = append(cachedAnswers, DNSAnswer{
Type: dns.TypeToString[rr.Header().Rrtype],
Value: rr.String(),
TTL: rr.Header().Ttl,
})
}
}
// 添加查询日志 - 标记为缓存
// 从缓存响应中获取响应代码
cacheRcode := dns.RcodeSuccess // 默认成功
if cachedResponse != nil {
cacheRcode = cachedResponse.Rcode
}
s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, cacheResponseTime, "allowed", "", "", true, cachedDNSSEC, true, "缓存", "无", cachedAnswers, cacheRcode)
logger.Debug("从缓存返回DNS响应", "domain", reqInfo.domain, "type", reqInfo.queryType, "dnssec", cachedDNSSEC)
return true
}
// handleUpstreamRequest 处理上游请求
func (s *Server) handleUpstreamRequest(w dns.ResponseWriter, r *dns.Msg, startTime time.Time, reqInfo *requestInfo) {
// 缓存未命中处理DNS请求
var response *dns.Msg
var rtt time.Duration
var dnsServer string
var dnssecServer string
// 直接查询原始域名
response, rtt, dnsServer, dnssecServer = s.forwardDNSRequestWithCache(r, reqInfo.domain)
if response != nil {
// 如果客户端请求包含EDNS记录确保响应也包含EDNS
if opt := r.IsEdns0(); opt != nil {
// 检查响应是否已经包含EDNS记录
if respOpt := response.IsEdns0(); respOpt == nil {
// 添加EDNS记录使用客户端的UDP缓冲区大小
response.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
} else {
// 确保响应的UDP缓冲区大小不超过客户端请求的大小
if respOpt.UDPSize() > opt.UDPSize() {
// 移除现有的EDNS记录
for i := range response.Extra {
if response.Extra[i] == respOpt {
response.Extra = append(response.Extra[:i], response.Extra[i+1:]...)
break
}
}
// 添加新的EDNS记录使用客户端的UDP缓冲区大小
response.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
}
}
}
// 确保响应的Question部分与客户端请求的Question部分匹配
response.Question = r.Question
// 修复如果响应包含记录确保Rcode为成功
hasValidRecords := false
// 检查Answer部分
if len(response.Answer) > 0 {
hasValidRecords = true
} else if len(response.Ns) > 0 {
// 检查Ns部分
hasValidRecords = true
} else if len(response.Extra) > 0 {
// 检查Extra部分排除OPT记录
for _, rr := range response.Extra {
if rr.Header().Rrtype != dns.TypeOPT {
hasValidRecords = true
break
}
}
}
if hasValidRecords {
response.Rcode = dns.RcodeSuccess
}
// 写入响应给客户端
w.WriteMsg(response)
}
// 使用上游服务器的实际响应时间(转换为毫秒)
responseTime := int64(rtt.Milliseconds())
// 如果rtt为0查询失败则使用本地计算的时间
if responseTime == 0 {
responseTime = time.Since(startTime).Milliseconds()
}
// 添加合理性检查,避免异常大的响应时间影响统计
if responseTime > 60000 { // 超过60秒的响应时间视为异常
responseTime = 60000
}
// 更新基本统计
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
// 添加防御性编程确保Queries大于0
if stats.Queries > 0 {
// 平均响应时间 = 总响应时间 / 总解析数量,四舍五入取整
avg := float64(stats.TotalResponseTime) / float64(stats.Queries)
stats.AvgResponseTime = float64(math.Round(avg))
// 限制平均响应时间的范围,避免显示异常大的值
if stats.AvgResponseTime > 60000 {
stats.AvgResponseTime = 60000
}
}
})
// 判断请求结果类型并更新相应统计
resultType := "allowed"
if response == nil {
// 响应为nil视为错误
resultType = "error"
s.updateStats(func(stats *Stats) {
stats.Errors++
})
} else if response.Rcode != dns.RcodeSuccess {
// 响应代码不是成功,视为错误
resultType = "error"
s.updateStats(func(stats *Stats) {
stats.Errors++
})
} else {
// 成功响应,视为正常解析
resultType = "allowed"
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
}
// 检查响应是否包含DNSSEC记录并验证结果
responseDNSSEC := false
if response != nil {
// 使用hasDNSSECRecords函数检查是否包含DNSSEC记录
responseDNSSEC = s.hasDNSSECRecords(response)
// 检查AD标志确认DNSSEC验证是否成功
if response.AuthenticatedData {
responseDNSSEC = true
}
// 更新域名的DNSSEC状态
if responseDNSSEC {
s.updateDomainDNSSECStatus(reqInfo.domain, true)
}
}
// 如果响应成功,缓存结果(增强版缓存存储)
if response != nil && response.Rcode == dns.RcodeSuccess {
// 创建响应副本以避免后续修改影响缓存
responseCopy := response.Copy()
// 设置合理的TTL不超过默认的30分钟
defaultCacheTTL := 30 * time.Minute
// 1. 缓存原始域名的查询结果
s.DnsCache.Set(r.Question[0].Name, reqInfo.qType, responseCopy, defaultCacheTTL)
logger.Debug("DNS响应已缓存", "domain", reqInfo.domain, "type", reqInfo.queryType, "ttl", defaultCacheTTL, "dnssec", responseDNSSEC)
// 2. 如果响应包含CNAME记录同时缓存CNAME指向的域名的查询结果
for _, rr := range response.Answer {
if cname, ok := rr.(*dns.CNAME); ok {
// 为CNAME指向的域名创建缓存
cnameQuery := r.Copy()
cnameQuery.Question[0].Name = cname.Target
s.DnsCache.Set(cname.Target, reqInfo.qType, responseCopy, defaultCacheTTL)
logger.Debug("CNAME响应已缓存", "domain", cname.Target, "type", reqInfo.queryType, "ttl", defaultCacheTTL, "dnssec", responseDNSSEC)
break
}
}
}
// 从响应中提取解析记录
responseAnswers := []DNSAnswer{}
if response != nil {
for _, rr := range response.Answer {
responseAnswers = append(responseAnswers, DNSAnswer{
Type: dns.TypeToString[rr.Header().Rrtype],
Value: rr.String(),
TTL: rr.Header().Ttl,
})
}
}
// 从响应中获取响应代码
realRcode := dns.RcodeSuccess // 默认成功
if response != nil {
realRcode = response.Rcode
}
// 添加查询日志
s.addQueryLog(reqInfo.sourceIP, reqInfo.domain, reqInfo.queryType, responseTime, resultType, "", "", false, responseDNSSEC, true, dnsServer, dnssecServer, responseAnswers, realRcode)
}
// handleHostsResponse 处理hosts文件匹配的响应
func (s *Server) handleHostsResponse(w dns.ResponseWriter, r *dns.Msg, ip string) {
response := new(dns.Msg)
response.SetReply(r)
// 不再硬编码RecursionAvailable使用默认值或上游返回的值
if len(r.Question) > 0 {
q := r.Question[0]
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
answer.A = net.ParseIP(ip)
response.Answer = append(response.Answer, answer)
}
// 记录解析域名统计
domain := ""
if len(r.Question) > 0 {
domain = r.Question[0].Name
if len(domain) > 0 && domain[len(domain)-1] == '.' {
domain = domain[:len(domain)-1]
}
s.updateResolvedDomainStats(domain)
}
w.WriteMsg(response)
// 本地hosts匹配响应时间极短使用固定值1ms
const localResponseTime int64 = 1
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
}
// handleGFWListResponse 处理GFWList域名响应
func (s *Server) handleGFWListResponse(w dns.ResponseWriter, r *dns.Msg, domain string) {
logger.Info("GFWList域名解析", "domain", domain, "client", w.RemoteAddr(), "ip", s.gfwConfig.IP)
// 更新解析域名统计
s.updateResolvedDomainStats(domain)
response := new(dns.Msg)
response.SetReply(r)
if len(r.Question) > 0 {
q := r.Question[0]
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
answer.A = net.ParseIP(s.gfwConfig.IP)
response.Answer = append(response.Answer, answer)
}
w.WriteMsg(response)
// GFWList域名匹配响应时间极短使用固定值1ms
const localResponseTime int64 = 1
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
}
// handleBlockedResponse 处理被屏蔽的域名响应
func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain string) {
logger.Info("域名被屏蔽", "domain", domain, "client", w.RemoteAddr())
// 更新被屏蔽域名统计
s.updateBlockedDomainStats(domain)
response := new(dns.Msg)
response.SetReply(r)
// 不再硬编码RecursionAvailable使用默认值或上游返回的值
// 获取屏蔽方法配置
blockMethod := "NXDOMAIN" // 默认值
customBlockIP := "" // 默认值
// 从Server结构体的shieldConfig字段获取配置
if s.shieldConfig != nil {
blockMethod = s.shieldConfig.BlockMethod
customBlockIP = s.shieldConfig.CustomBlockIP
}
// 根据屏蔽方法返回不同的响应
switch blockMethod {
case "refused":
// 返回拒绝查询响应
response.SetRcode(r, dns.RcodeRefused)
case "emptyIP":
// 返回空IP响应
if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA {
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: r.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
answer.A = net.ParseIP("0.0.0.0") // 空IP
response.Answer = append(response.Answer, answer)
}
case "customIP":
// 返回自定义IP响应
if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA {
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: r.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
// 使用自定义屏蔽IP
if customBlockIP != "" {
answer.A = net.ParseIP(customBlockIP)
} else {
// 如果没有配置使用0.0.0.0
answer.A = net.ParseIP("0.0.0.0")
}
response.Answer = append(response.Answer, answer)
}
case "NXDOMAIN", "":
fallthrough // 默认使用NXDOMAIN
default:
// 返回NXDOMAIN响应域名不存在
response.SetRcode(r, dns.RcodeNameError)
}
w.WriteMsg(response)
// 屏蔽规则匹配响应时间极短使用固定值1ms
const localResponseTime int64 = 1
s.updateStats(func(stats *Stats) {
stats.Blocked++
})
}
// forwardDNSRequest 转发DNS请求到上游服务器
// serverResponse 用于存储服务器响应的结构体
type serverResponse struct {
response *dns.Msg
rtt time.Duration
server string
error error
}
// updateDNSSECServerMap 更新DNSSEC专用服务器映射用于快速查找
func (s *Server) updateDNSSECServerMap() {
// 清空现有映射
for k := range s.dnssecServerMap {
delete(s.dnssecServerMap, k)
}
// 添加所有DNSSEC专用服务器到映射
for _, server := range s.config.DNSSECUpstreamDNS {
s.dnssecServerMap[server] = true
}
}
// forwardDNSRequestWithCache 转发DNS请求到上游服务器并返回响应
func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg, time.Duration, string, string) {
// 始终支持EDNS
var udpSize uint16 = 4096
var doFlag bool = s.config.EnableDNSSEC
// 检查域名是否匹配不验证DNSSEC的模式
noDNSSEC := false
for _, pattern := range s.config.NoDNSSECDomains {
if strings.Contains(domain, pattern) {
noDNSSEC = true
doFlag = false
logger.Debug("域名匹配到不验证DNSSEC的模式", "domain", domain, "pattern", pattern)
break
}
}
// 检查客户端请求是否包含EDNS记录
if opt := r.IsEdns0(); opt != nil {
// 保留客户端的UDP缓冲区大小
udpSize = opt.UDPSize()
// 移除现有的EDNS记录以便重新添加
for i := range r.Extra {
if r.Extra[i] == opt {
r.Extra = append(r.Extra[:i], r.Extra[i+1:]...)
break
}
}
}
// 添加EDNS记录设置适当的UDPSize和DO标志
r.SetEdns0(udpSize, doFlag)
// DNSSEC专用服务器列表从配置中获取
dnssecServers := s.config.DNSSECUpstreamDNS
// 选择合适的上游DNS服务器列表
// 1. 首先检查是否有域名特定的DNS服务器配置
var selectedUpstreamDNS []string
var domainMatched bool
for matchStr, dnsServers := range s.config.DomainSpecificDNS {
if strings.Contains(domain, matchStr) {
selectedUpstreamDNS = dnsServers
domainMatched = true
logger.Debug("域名匹配到特定DNS服务器配置", "domain", domain, "matchStr", matchStr, "dnsServers", dnsServers)
break
}
}
// 2. 如果没有匹配的域名特定配置
if !domainMatched {
// 创建一个新的切片来存储最终的上游服务器列表
var finalUpstreamDNS []string
// 首先添加用户配置的上游DNS服务器
finalUpstreamDNS = append(finalUpstreamDNS, s.config.UpstreamDNS...)
logger.Debug("使用用户配置的上游DNS服务器", "servers", finalUpstreamDNS)
// 如果启用了DNSSEC且有配置DNSSEC专用服务器并且域名不匹配NoDNSSECDomains则将DNSSEC专用服务器添加到列表中
if s.config.EnableDNSSEC && len(s.config.DNSSECUpstreamDNS) > 0 && !noDNSSEC {
// 合并DNSSEC专用服务器到上游服务器列表避免重复并确保包含端口号
for _, dnssecServer := range s.config.DNSSECUpstreamDNS {
hasDuplicate := false
// 确保DNSSEC服务器地址包含端口号
normalizedDnssecServer := normalizeDNSServerAddress(dnssecServer)
for _, upstream := range finalUpstreamDNS {
if upstream == normalizedDnssecServer {
hasDuplicate = true
break
}
}
if !hasDuplicate {
finalUpstreamDNS = append(finalUpstreamDNS, normalizedDnssecServer)
}
}
logger.Debug("合并DNSSEC专用服务器到上游服务器列表", "servers", finalUpstreamDNS)
}
// 使用最终合并后的服务器列表
selectedUpstreamDNS = finalUpstreamDNS
}
// 1. 首先尝试所有配置的上游DNS服务器
var bestResponse *dns.Msg
var bestRtt time.Duration
var hasBestResponse bool
var hasDNSSECResponse bool
var backupResponse *dns.Msg
var backupRtt time.Duration
var hasBackup bool
var usedDNSServer string
var usedDNSSECServer string
// 使用配置中的超时时间
defaultTimeout := time.Duration(s.config.QueryTimeout) * time.Millisecond
// 根据查询模式处理请求
switch s.config.QueryMode {
case "parallel":
// 并行请求模式 - 返回第一个成功响应
responses := make(chan serverResponse, len(selectedUpstreamDNS))
var wg sync.WaitGroup
// 向所有上游服务器并行发送请求
for _, upstream := range selectedUpstreamDNS {
wg.Add(1)
go func(server string) {
defer wg.Done()
// 从池中获取客户端实例
client := s.clientPool.Get().(*dns.Client)
// 设置客户端参数
client.Net = s.resolver.Net
client.UDPSize = s.resolver.UDPSize
client.Timeout = defaultTimeout
// 发送请求并获取响应,确保服务器地址包含端口号
response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(server))
responses <- serverResponse{response, rtt, server, err}
// 将客户端实例放回池中
s.clientPool.Put(client)
}(upstream)
}
// 等待所有请求完成
go func() {
wg.Wait()
close(responses)
}()
// 处理响应,只返回第一个成功响应
var lastErrorResponse *dns.Msg
var lastErrorRtt time.Duration
var lastErrorServer string
for i := 0; i < len(selectedUpstreamDNS); i++ {
resp := <-responses
if resp.error == nil && resp.response != nil {
// 更新服务器统计信息
s.updateServerStats(resp.server, true, resp.rtt)
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(resp.response)
// 对于不验证DNSSEC的域名始终设置AD标志为false
if noDNSSEC {
resp.response.AuthenticatedData = false
}
// 检查当前服务器是否是DNSSEC专用服务器
if _, isDNSSECServer := s.dnssecServerMap[resp.server]; isDNSSECServer {
usedDNSSECServer = resp.server
}
// 如果是成功响应,立即返回
if resp.response.Rcode == dns.RcodeSuccess {
// 验证DNSSEC记录如果需要
if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(resp.response)
resp.response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
bestResponse = resp.response
bestRtt = resp.rtt
usedDNSServer = resp.server
hasBestResponse = true
hasDNSSECResponse = containsDNSSEC
logger.Debug("返回第一个成功响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer
} else {
// 保存最后一个错误响应
lastErrorResponse = resp.response
lastErrorRtt = resp.rtt
lastErrorServer = resp.server
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(resp.server, false, 0)
}
}
// 如果所有服务器都失败,返回最后一个错误
if lastErrorResponse != nil {
bestResponse = lastErrorResponse
bestRtt = lastErrorRtt
usedDNSServer = lastErrorServer
hasBestResponse = true
logger.Debug("所有服务器都失败,返回最后一个错误响应", "domain", domain, "server", lastErrorServer)
}
return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer
case "fastest-ip":
// 最快的IP地址模式 - 通过ping测试选择最快服务器只向一个服务器发送请求
// 1. 选择最快的服务器
fastestServer := s.selectFastestServer(selectedUpstreamDNS)
if fastestServer != "" {
// 从池中获取客户端实例
client := s.clientPool.Get().(*dns.Client)
// 设置客户端参数
client.Net = s.resolver.Net
client.UDPSize = s.resolver.UDPSize
client.Timeout = defaultTimeout
// 只向一个服务器发送请求
response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(fastestServer))
// 将客户端实例放回池中
s.clientPool.Put(client)
if err == nil && response != nil {
// 更新服务器统计信息
s.updateServerStats(fastestServer, true, rtt)
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(response)
// 对于不验证DNSSEC的域名始终设置AD标志为false
if noDNSSEC {
response.AuthenticatedData = false
}
// 验证DNSSEC记录如果需要
if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(response)
response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
// 检查响应是否包含有效的记录如果包含将Rcode设置为成功
hasValidRecords := false
if len(response.Answer) > 0 {
hasValidRecords = true
} else if len(response.Ns) > 0 {
hasValidRecords = true
} else if len(response.Extra) > 0 {
for _, rr := range response.Extra {
if rr.Header().Rrtype != dns.TypeOPT {
hasValidRecords = true
break
}
}
}
if hasValidRecords {
response.Rcode = dns.RcodeSuccess
}
// 设置最佳响应
bestResponse = response
bestRtt = rtt
hasBestResponse = true
usedDNSServer = fastestServer
if containsDNSSEC {
hasDNSSECResponse = true
}
if _, isDNSSECServer := s.dnssecServerMap[normalizeDNSServerAddress(fastestServer)]; isDNSSECServer {
usedDNSSECServer = fastestServer
}
logger.Debug("使用最快服务器返回响应", "domain", domain, "server", fastestServer, "rtt", rtt)
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(fastestServer, false, 0)
logger.Debug("最快服务器请求失败", "domain", domain, "server", fastestServer, "error", err)
}
}
default:
// 默认使用并行请求模式 - 实现快速返回和超时机制
responses := make(chan serverResponse, len(selectedUpstreamDNS))
resultChan := make(chan struct {
response *dns.Msg
rtt time.Duration
usedServer string
usedDnssecServer string
}, 1)
var wg sync.WaitGroup
// 向所有上游服务器并行发送请求
for _, upstream := range selectedUpstreamDNS {
wg.Add(1)
go func(server string) {
defer wg.Done()
// 创建带有超时的resolver
client := &dns.Client{
Net: s.resolver.Net,
UDPSize: s.resolver.UDPSize,
Timeout: defaultTimeout,
}
// 发送请求并获取响应,确保服务器地址包含端口号
response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(server))
responses <- serverResponse{response, rtt, server, err}
}(upstream)
}
// 处理响应的协程
go func() {
var fastestResponse *dns.Msg
var fastestRtt time.Duration = defaultTimeout
var fastestServer string
var fastestDnssecServer string
var fastestHasDnssec bool
var successResponses []*dns.Msg
var nxdomainResponses []*dns.Msg
// 等待所有请求完成或超时
timer := time.NewTimer(defaultTimeout)
defer timer.Stop()
// 处理所有响应
for {
select {
case resp, ok := <-responses:
if !ok {
// 所有响应都已处理
goto doneProcessing
}
if resp.error == nil && resp.response != nil {
// 更新服务器统计信息
s.updateServerStats(resp.server, true, resp.rtt)
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(resp.response)
// 对于不验证DNSSEC的域名始终设置AD标志为false
if noDNSSEC {
resp.response.AuthenticatedData = false
}
dnssecServerForResponse := ""
if _, isDNSSECServer := s.dnssecServerMap[normalizeDNSServerAddress(resp.server)]; isDNSSECServer {
dnssecServerForResponse = resp.server
}
// 如果响应成功或为NXDOMAIN
if resp.response.Rcode == dns.RcodeSuccess || resp.response.Rcode == dns.RcodeNameError {
// 按Rcode分类添加到不同列表
if resp.response.Rcode == dns.RcodeSuccess {
successResponses = append(successResponses, resp.response)
} else {
nxdomainResponses = append(nxdomainResponses, resp.response)
}
// 快速返回逻辑:找到第一个有效响应或更快的响应
if resp.response.Rcode == dns.RcodeSuccess {
// 优先选择带有DNSSEC的响应
if containsDNSSEC {
// 如果这是第一个DNSSEC响应或者比当前最快的DNSSEC响应更快
if !fastestHasDnssec || resp.rtt < fastestRtt {
fastestResponse = resp.response
fastestRtt = resp.rtt
fastestServer = resp.server
fastestDnssecServer = dnssecServerForResponse
fastestHasDnssec = true
// 只对将要返回的响应进行DNSSEC验证
if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(fastestResponse)
// 设置AD标志Authenticated Data
fastestResponse.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
// 发送结果,快速返回
resultChan <- struct {
response *dns.Msg
rtt time.Duration
usedServer string
usedDnssecServer string
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
}
} else {
// 非DNSSEC响应只有在还没有找到DNSSEC响应且当前响应更快时才更新
if !fastestHasDnssec && resp.rtt < fastestRtt {
fastestResponse = resp.response
fastestRtt = resp.rtt
fastestServer = resp.server
fastestDnssecServer = dnssecServerForResponse
// 检查是否包含DNSSEC记录
respContainsDNSSEC := s.hasDNSSECRecords(fastestResponse)
// 只对将要返回的响应进行DNSSEC验证
if s.config.EnableDNSSEC && respContainsDNSSEC && !noDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(fastestResponse)
// 设置AD标志Authenticated Data
fastestResponse.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
// 发送结果,快速返回
resultChan <- struct {
response *dns.Msg
rtt time.Duration
usedServer string
usedDnssecServer string
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
}
}
} else if resp.response.Rcode == dns.RcodeNameError {
// NXDOMAIN响应只有在还没有找到响应或当前响应更快时才更新
if !fastestHasDnssec && resp.rtt < fastestRtt {
fastestResponse = resp.response
fastestRtt = resp.rtt
fastestServer = resp.server
fastestDnssecServer = dnssecServerForResponse
// 检查是否包含DNSSEC记录
respContainsDNSSEC := s.hasDNSSECRecords(fastestResponse)
// 只对将要返回的响应进行DNSSEC验证
if s.config.EnableDNSSEC && respContainsDNSSEC && !noDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(fastestResponse)
// 设置AD标志Authenticated Data
fastestResponse.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
// 发送结果,快速返回
resultChan <- struct {
response *dns.Msg
rtt time.Duration
usedServer string
usedDnssecServer string
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
}
}
} else {
// 更新备选响应,确保总有一个可用的响应
if resp.response != nil {
if !hasBackup {
// 第一次保存备选响应
backupResponse = resp.response
backupRtt = resp.rtt
hasBackup = true
}
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(resp.server, false, 0)
}
case <-timer.C:
// 超时,停止等待更多响应
goto doneProcessing
}
}
doneProcessing:
// 如果还没有发送结果,发送最快的响应
if fastestResponse != nil {
resultChan <- struct {
response *dns.Msg
rtt time.Duration
usedServer string
usedDnssecServer string
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
}
close(resultChan)
}()
// 等待所有请求完成(不阻塞主流程)
go func() {
wg.Wait()
close(responses)
}()
// 等待结果或超时
select {
case result := <-resultChan:
// 快速返回结果
bestResponse = result.response
bestRtt = result.rtt
usedDNSServer = result.usedServer
usedDNSSECServer = result.usedDnssecServer
hasBestResponse = true
hasDNSSECResponse = s.hasDNSSECRecords(result.response)
logger.Debug("快速返回DNS响应", "domain", domain, "server", result.usedServer, "rtt", result.rtt, "dnssec", hasDNSSECResponse)
case <-time.After(defaultTimeout):
// 超时,使用备选响应
logger.Debug("并行请求超时", "domain", domain, "timeout", defaultTimeout)
}
}
// 2. 当启用DNSSEC且没有找到带DNSSEC的响应时向DNSSEC专用服务器发送请求
// 但如果域名匹配了domainSpecificDNS配置或NoDNSSECDomains则不使用DNSSEC专用服务器只使用指定的DNS服务器
if s.config.EnableDNSSEC && !hasDNSSECResponse && !domainMatched && !noDNSSEC {
logger.Debug("向DNSSEC专用服务器发送请求", "domain", domain)
// 增加DNSSEC查询计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
})
// 无论查询模式是什么DNSSEC验证都只使用加权随机选择一个服务器
selectedDnssecServer := s.selectWeightedRandomServer(dnssecServers)
if selectedDnssecServer != "" {
// 使用带超时的方式执行Exchange
resultChan := make(chan struct {
response *dns.Msg
rtt time.Duration
err error
}, 1)
go func() {
// 创建带有超时的resolver
client := &dns.Client{
Net: s.resolver.Net,
UDPSize: s.resolver.UDPSize,
Timeout: defaultTimeout,
}
response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(selectedDnssecServer))
resultChan <- struct {
response *dns.Msg
rtt time.Duration
err error
}{response, rtt, err}
}()
var response *dns.Msg
var rtt time.Duration
var err error
// 使用超时获取结果
select {
case result := <-resultChan:
response, rtt, err = result.response, result.rtt, result.err
case <-time.After(defaultTimeout):
// 超时,不再等待
logger.Debug("DNSSEC专用服务器请求超时", "domain", domain, "server", selectedDnssecServer, "timeout", defaultTimeout)
return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer
}
if err == nil && response != nil {
// 更新服务器统计信息
s.updateServerStats(selectedDnssecServer, true, rtt)
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(response)
if response.Rcode == dns.RcodeSuccess {
// 无论响应是否包含DNSSEC记录只要使用了DNSSEC专用服务器就设置usedDNSSECServer
usedDNSSECServer = selectedDnssecServer
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(response)
// 设置AD标志Authenticated Data
response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECFailed++
})
}
// 优先使用DNSSEC专用服务器的响应尤其是带有DNSSEC记录的
if containsDNSSEC {
bestResponse = response
bestRtt = rtt
hasBestResponse = true
hasDNSSECResponse = true
logger.Debug("DNSSEC专用服务器返回带DNSSEC的响应优先使用", "domain", domain, "server", selectedDnssecServer, "rtt", rtt)
}
// 注意如果DNSSEC专用服务器返回的响应不包含DNSSEC记录
// 我们不会覆盖之前从upstreamDNS获取的响应
// 这符合"本地解析指的是直接使用上游服务器upstreamDNS进行解析, 而不是dnssecUpstreamDNS"的要求
// 更新备选响应
if !hasBackup {
backupResponse = response
backupRtt = rtt
hasBackup = true
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(selectedDnssecServer, false, 0)
}
}
}
// 3. 返回最佳响应
if hasBestResponse {
// 检查最佳响应是否包含DNSSEC记录
bestHasDNSSEC := s.hasDNSSECRecords(bestResponse)
// 如果启用了DNSSEC且最佳响应不包含DNSSEC记录尝试使用本地解析使用upstreamDNS服务器
// 但如果域名匹配了domainSpecificDNS配置则不执行此逻辑只使用指定的DNS服务器
if s.config.EnableDNSSEC && !bestHasDNSSEC && !domainMatched {
logger.Debug("最佳响应不包含DNSSEC记录尝试使用本地解析upstreamDNS", "domain", domain)
// 选择一个upstreamDNS服务器进行解析使用加权随机算法
localServer := s.selectWeightedRandomServer(s.config.UpstreamDNS)
if localServer != "" {
// 使用带超时的方式执行Exchange
resultChan := make(chan struct {
response *dns.Msg
rtt time.Duration
err error
}, 1)
go func() {
resp, r, e := s.resolver.Exchange(r, normalizeDNSServerAddress(localServer))
resultChan <- struct {
response *dns.Msg
rtt time.Duration
err error
}{resp, r, e}
}()
var localResponse *dns.Msg
var rtt time.Duration
var err error
// 直接获取结果,不使用上下文超时
result := <-resultChan
localResponse, rtt, err = result.response, result.rtt, result.err
if err == nil && localResponse != nil {
// 更新服务器统计信息
s.updateServerStats(localServer, true, rtt)
// 检查是否包含DNSSEC记录
localHasDNSSEC := s.hasDNSSECRecords(localResponse)
// 验证DNSSEC记录如果存在但不影响最终响应
if localHasDNSSEC {
signatureValid := s.verifyDNSSEC(localResponse)
localResponse.AuthenticatedData = signatureValid
if signatureValid {
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
// 记录解析域名统计
s.updateResolvedDomainStats(domain)
// 更新域名的DNSSEC状态
s.updateDomainDNSSECStatus(domain, localHasDNSSEC)
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
logger.Debug("使用本地解析结果upstreamDNS", "domain", domain, "server", localServer, "rtt", rtt)
return localResponse, rtt, localServer, ""
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(localServer, false, 0)
}
}
}
// 记录解析域名统计
s.updateResolvedDomainStats(domain)
// 更新域名的DNSSEC状态
if bestHasDNSSEC {
s.updateDomainDNSSECStatus(domain, true)
} else {
s.updateDomainDNSSECStatus(domain, false)
}
// 检查响应是否包含CNAME记录需要确保返回完整的解析链
if bestResponse != nil && bestResponse.Rcode == dns.RcodeSuccess {
// 处理多级CNAME直到获取到最终的A/AAAA记录
maxCNAMELevels := 5 // 限制最大CNAME解析级数防止循环解析
currentLevel := 0
// 循环处理CNAME记录
for currentLevel < maxCNAMELevels {
// 检查是否包含CNAME记录
var hasCNAME bool
var cnameTarget string
// 检查Answer部分查找CNAME记录
for _, rr := range bestResponse.Answer {
if cname, ok := rr.(*dns.CNAME); ok {
hasCNAME = true
cnameTarget = cname.Target
}
}
// 如果不包含CNAME记录或者已经包含最终的A/AAAA记录退出循环
var hasFinalRecord bool
for _, rr := range bestResponse.Answer {
switch rr.Header().Rrtype {
case dns.TypeA, dns.TypeAAAA:
hasFinalRecord = true
break
}
}
if !hasCNAME || hasFinalRecord {
break // 没有CNAME记录或者已经有最终记录退出循环
}
// 如果包含CNAME记录但没有最终IP继续查询
logger.Debug("响应包含CNAME但没有最终IP继续查询", "domain", domain, "cname", cnameTarget, "level", currentLevel)
// 创建新的查询请求查询CNAME指向的域名
cnameQuery := r.Copy()
cnameQuery.Question[0].Name = cnameTarget
// 继续查询CNAME指向的域名
cnameResponse, _, cnameDnsServer, cnameDnssecServer := s.forwardDNSRequestWithCache(cnameQuery, cnameTarget)
if cnameResponse != nil && cnameResponse.Rcode == dns.RcodeSuccess {
// 合并CNAME响应的Answer部分到主响应
bestResponse.Answer = append(bestResponse.Answer, cnameResponse.Answer...)
// 合并CNAME响应的Ns部分到主响应
bestResponse.Ns = append(bestResponse.Ns, cnameResponse.Ns...)
// 合并CNAME响应的Extra部分到主响应排除OPT记录
for _, rr := range cnameResponse.Extra {
if rr.Header().Rrtype != dns.TypeOPT {
bestResponse.Extra = append(bestResponse.Extra, rr)
}
}
// 更新使用的DNS服务器信息
if cnameDnsServer != "" {
usedDNSServer = cnameDnsServer
}
if cnameDnssecServer != "" {
usedDNSSECServer = cnameDnssecServer
}
} else {
// 查询失败,退出循环
break
}
// 增加CNAME解析级数
currentLevel++
}
if currentLevel >= maxCNAMELevels {
logger.Warn("CNAME解析级数超过限制可能存在循环解析", "domain", domain, "maxLevels", maxCNAMELevels)
}
}
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer
}
// 如果有备选响应,返回该响应
if hasBackup {
logger.Debug("使用备选响应,没有找到更好的结果", "domain", domain)
// 记录解析域名统计
s.updateResolvedDomainStats(domain)
// 更新统计信息
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
return backupResponse, backupRtt, "", ""
}
// 所有上游服务器都失败,返回服务器失败错误
response := new(dns.Msg)
response.SetReply(r)
response.SetRcode(r, dns.RcodeServerFailure)
logger.Error("DNS查询失败", "domain", domain)
s.updateStats(func(stats *Stats) {
stats.Errors++
})
return response, 0, "", ""
}
// forwardDNSRequest 转发DNS请求到上游服务器
func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) {
response, _, _, _ := s.forwardDNSRequestWithCache(r, domain)
w.WriteMsg(response)
}
// updateBlockedDomainStats 更新被屏蔽域名统计
func (s *Server) updateBlockedDomainStats(domain string) {
// 先尝试读锁,检查条目是否存在
s.blockedDomainsMutex.RLock()
entry, exists := s.blockedDomains[domain]
s.blockedDomainsMutex.RUnlock()
if exists {
// 使用原子操作更新计数和时间戳
atomic.AddInt64(&entry.Count, 1)
atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano())
} else {
// 获取写锁,创建新条目
s.blockedDomainsMutex.Lock()
// 再次检查,避免竞态条件
if entry, exists := s.blockedDomains[domain]; exists {
atomic.AddInt64(&entry.Count, 1)
atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano())
} else {
s.blockedDomains[domain] = &BlockedDomain{
Domain: domain,
Count: 1,
LastSeen: time.Now().UnixNano(),
DNSSEC: false,
}
}
s.blockedDomainsMutex.Unlock()
}
// 更新统计数据
now := time.Now()
// 更新小时统计
hourKey := now.Format("2006-01-02-15")
s.hourlyStatsMutex.Lock()
s.hourlyStats[hourKey]++
s.hourlyStatsMutex.Unlock()
// 更新每日统计
dayKey := now.Format("2006-01-02")
s.dailyStatsMutex.Lock()
s.dailyStats[dayKey]++
s.dailyStatsMutex.Unlock()
// 更新每月统计
monthKey := now.Format("2006-01")
s.monthlyStatsMutex.Lock()
s.monthlyStats[monthKey]++
s.monthlyStatsMutex.Unlock()
}
// updateClientStats 更新客户端统计
func (s *Server) updateClientStats(ip string) {
// 先尝试读锁,检查条目是否存在
s.clientStatsMutex.RLock()
entry, exists := s.clientStats[ip]
s.clientStatsMutex.RUnlock()
if exists {
// 使用原子操作更新计数和时间戳
atomic.AddInt64(&entry.Count, 1)
atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano())
} else {
// 获取写锁,创建新条目
s.clientStatsMutex.Lock()
// 再次检查,避免竞态条件
if entry, exists := s.clientStats[ip]; exists {
atomic.AddInt64(&entry.Count, 1)
atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano())
} else {
s.clientStats[ip] = &ClientStats{
IP: ip,
Count: 1,
LastSeen: time.Now().UnixNano(),
}
}
s.clientStatsMutex.Unlock()
}
}
// hasDNSSECRecords 检查响应是否包含DNSSEC记录
func (s *Server) hasDNSSECRecords(response *dns.Msg) bool {
// 直接调用包内的hasDNSSECRecords函数避免重复代码
return hasDNSSECRecords(response)
}
// verifyDNSSEC 验证DNSSEC签名
func (s *Server) verifyDNSSEC(response *dns.Msg) bool {
// 提取DNSKEY和RRSIG记录并按类型和名称组织记录
dnskeys := make(map[uint16]*dns.DNSKEY) // KeyTag -> DNSKEY
rrsigs := make([]*dns.RRSIG, 0)
// 按 (名称, 类型) 组织记录集,用于快速查找
rrSets := make(map[string]map[uint16][]dns.RR) // name -> type -> records
// 定义处理单个记录的函数
processRecord := func(rr dns.RR) {
num := rr.Header().Rrtype
name := rr.Header().Name
// 组织记录集
if _, exists := rrSets[name]; !exists {
rrSets[name] = make(map[uint16][]dns.RR)
}
if _, exists := rrSets[name][num]; !exists {
rrSets[name][num] = make([]dns.RR, 0)
}
rrSets[name][num] = append(rrSets[name][num], rr)
// 特别处理DNSKEY和RRSIG
if dnskey, ok := rr.(*dns.DNSKEY); ok {
tag := dnskey.KeyTag()
dnskeys[tag] = dnskey
} else if rrsig, ok := rr.(*dns.RRSIG); ok {
rrsigs = append(rrsigs, rrsig)
}
}
// 一次遍历所有响应部分,同时完成记录收集和组织
for _, rr := range response.Answer {
processRecord(rr)
}
for _, rr := range response.Ns {
processRecord(rr)
}
for _, rr := range response.Extra {
processRecord(rr)
}
// 如果没有RRSIG记录验证失败
if len(rrsigs) == 0 {
return false
}
// 验证所有RRSIG记录
signatureValid := true
// 用于记录已经警告过的DNSKEY tag避免重复警告
warnedKeyTags := make(map[uint16]bool)
for _, rrsig := range rrsigs {
// 查找对应的DNSKEY
dnskey, exists := dnskeys[rrsig.KeyTag]
if !exists {
// 仅当该key_tag尚未警告过时才记录警告
if !warnedKeyTags[rrsig.KeyTag] {
logger.Warn("DNSSEC验证失败找不到对应的DNSKEY", "key_tag", rrsig.KeyTag)
warnedKeyTags[rrsig.KeyTag] = true
}
signatureValid = false
continue
}
// 快速查找需要验证的记录集
name := rrsig.Header().Name
typeCovered := rrsig.TypeCovered
rrset := rrSets[name][typeCovered]
// 验证签名
if len(rrset) > 0 {
err := rrsig.Verify(dnskey, rrset)
if err != nil {
logger.Warn("DNSSEC签名验证失败", "error", err, "key_tag", rrsig.KeyTag)
signatureValid = false
} else {
logger.Debug("DNSSEC签名验证成功", "key_tag", rrsig.KeyTag)
}
}
}
return signatureValid
}
// updateDomainDNSSECStatus 更新域名的DNSSEC状态
func (s *Server) updateDomainDNSSECStatus(domain string, dnssec bool) {
// 确保域名是小写
domain = strings.ToLower(domain)
// 更新域名的DNSSEC状态
s.resolvedDomainsMutex.Lock()
defer s.resolvedDomainsMutex.Unlock()
// 更新resolvedDomains中的DNSSEC状态
if entry, exists := s.resolvedDomains[domain]; exists {
entry.DNSSEC = dnssec
} else {
s.resolvedDomains[domain] = &BlockedDomain{
Domain: domain,
Count: 1,
LastSeen: time.Now().UnixNano(),
DNSSEC: dnssec,
}
}
// 更新domainDNSSECStatus映射使用单独的锁
s.domainDNSSECStatusMutex.Lock()
s.domainDNSSECStatus[domain] = dnssec
s.domainDNSSECStatusMutex.Unlock()
}
// updateResolvedDomainStats 更新解析域名统计
func (s *Server) updateResolvedDomainStats(domain string) {
// 先尝试读锁,检查条目是否存在
s.resolvedDomainsMutex.RLock()
entry, exists := s.resolvedDomains[domain]
s.resolvedDomainsMutex.RUnlock()
if exists {
// 使用原子操作更新计数和时间戳
atomic.AddInt64(&entry.Count, 1)
atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano())
} else {
// 获取写锁,创建新条目
s.resolvedDomainsMutex.Lock()
// 再次检查,避免竞态条件
if entry, exists := s.resolvedDomains[domain]; exists {
atomic.AddInt64(&entry.Count, 1)
atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano())
} else {
s.resolvedDomains[domain] = &BlockedDomain{
Domain: domain,
Count: 1,
LastSeen: time.Now().UnixNano(),
DNSSEC: false,
}
}
s.resolvedDomainsMutex.Unlock()
}
}
// getServerStats 获取服务器统计信息,如果不存在则创建
func (s *Server) getServerStats(server string) *ServerStats {
s.serverStatsMutex.RLock()
stats, exists := s.serverStats[server]
s.serverStatsMutex.RUnlock()
if exists {
return stats
}
s.serverStatsMutex.Lock()
defer s.serverStatsMutex.Unlock()
if stats, exists := s.serverStats[server]; exists {
return stats
}
stats = &ServerStats{
SuccessCount: 0,
FailureCount: 0,
LastResponse: time.Now(),
ResponseTime: 0,
ConnectionSpeed: 0,
}
s.serverStats[server] = stats
return stats
}
// updateServerStats 更新服务器统计信息
func (s *Server) updateServerStats(server string, success bool, rtt time.Duration) {
stats := s.getServerStats(server)
// 使用原子操作更新成功和失败计数
if success {
successCount := atomic.AddInt64(&stats.SuccessCount, 1)
// 只在需要更新平均响应时间时获取锁
s.serverStatsMutex.Lock()
stats.LastResponse = time.Now()
// 更新平均响应时间(简单移动平均)
if successCount == 1 {
// 第一次成功,直接使用当前响应时间
stats.ResponseTime = rtt
} else {
// 使用纳秒进行计算以避免类型不匹配
prevTotal := stats.ResponseTime.Nanoseconds() * (successCount - 1)
newTotal := prevTotal + rtt.Nanoseconds()
stats.ResponseTime = time.Duration(newTotal / successCount)
}
s.serverStatsMutex.Unlock()
} else {
atomic.AddInt64(&stats.FailureCount, 1)
// 只更新LastResponse时获取锁
s.serverStatsMutex.Lock()
stats.LastResponse = time.Now()
s.serverStatsMutex.Unlock()
}
}
// selectWeightedRandomServer 加权随机选择服务器
func (s *Server) selectWeightedRandomServer(servers []string) string {
if len(servers) == 0 {
return ""
}
if len(servers) == 1 {
return servers[0]
}
type serverWeight struct {
server string
weight int64
responseTime time.Duration
successCount int64
failureCount int64
}
var totalWeight int64
var totalResponseTime time.Duration
var validServers int
var currentWeight int64
serversInfo := make([]serverWeight, len(servers))
for i, server := range servers {
stats := s.getServerStats(server)
serversInfo[i] = serverWeight{
server: server,
responseTime: stats.ResponseTime,
successCount: atomic.LoadInt64(&stats.SuccessCount),
failureCount: atomic.LoadInt64(&stats.FailureCount),
}
if stats.ResponseTime > 0 {
totalResponseTime += stats.ResponseTime
validServers++
}
}
var avgResponseTime time.Duration
if validServers > 0 {
avgResponseTime = totalResponseTime / time.Duration(validServers)
} else {
avgResponseTime = 1 * time.Second
}
var randomGen = rand.New(rand.NewSource(time.Now().UnixNano()))
for i := range serversInfo {
baseWeight := serversInfo[i].successCount - serversInfo[i].failureCount*2
if baseWeight < 1 {
baseWeight = 1
}
var responseFactor int64 = 100
if serversInfo[i].responseTime > 0 {
if serversInfo[i].responseTime < avgResponseTime {
factor := (avgResponseTime.Nanoseconds() * 200) / serversInfo[i].responseTime.Nanoseconds()
if factor > 200 {
factor = 200
}
responseFactor = factor
} else {
factor := (avgResponseTime.Nanoseconds() * 200) / serversInfo[i].responseTime.Nanoseconds()
if factor < 50 {
factor = 50
}
responseFactor = factor
}
}
finalWeight := (baseWeight * responseFactor) / 100
if finalWeight < 1 {
finalWeight = 1
}
serversInfo[i].weight = finalWeight
totalWeight += finalWeight
}
random := randomGen.Int63n(totalWeight)
for _, sw := range serversInfo {
currentWeight += sw.weight
if random < currentWeight {
return sw.server
}
}
// 兜底返回第一个服务器
return servers[0]
}
// measureServerSpeed 测量服务器TCP连接速度
func (s *Server) measureServerSpeed(server string) time.Duration {
addr := server
if !strings.Contains(server, ":") {
addr = server + ":53"
}
startTime := time.Now()
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
if err != nil {
return 2 * time.Second
}
defer conn.Close()
connTime := time.Since(startTime)
stats := s.getServerStats(server)
s.serverStatsMutex.Lock()
stats.ConnectionSpeed = (stats.ConnectionSpeed*3 + connTime) / 4
s.serverStatsMutex.Unlock()
return connTime
}
// selectFastestServer 选择连接速度最快的服务器
func (s *Server) selectFastestServer(servers []string) string {
if len(servers) == 0 {
return ""
}
if len(servers) == 1 {
return servers[0]
}
// 并行测量所有服务器的速度
type speedResult struct {
server string
speed time.Duration
}
results := make(chan speedResult, len(servers))
var wg sync.WaitGroup
for _, server := range servers {
wg.Add(1)
go func(srv string) {
defer wg.Done()
speed := s.measureServerSpeed(srv)
results <- speedResult{srv, speed}
}(server)
}
// 等待所有测量完成
go func() {
wg.Wait()
close(results)
}()
// 找出最快的服务器
var fastestServer string
var fastestSpeed time.Duration = 2 * time.Second
for result := range results {
if result.speed < fastestSpeed {
fastestSpeed = result.speed
fastestServer = result.server
}
}
// 如果没有找到最快服务器(理论上不会发生),返回第一个服务器
if fastestServer == "" {
fastestServer = servers[0]
}
return fastestServer
}
// calculateAvgResponseTime 计算平均响应时间
func calculateAvgResponseTime(totalResponseTime int64, queries int64) float64 {
if queries <= 0 {
return 0
}
avg := float64(totalResponseTime) / float64(queries)
avg = float64(math.Round(avg))
// 限制平均响应时间的范围
if avg > 60000 {
avg = 60000
}
return avg
}
// updateStats 更新统计信息
func (s *Server) updateStats(update func(*Stats)) {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
update(s.stats)
}
// addQueryLog 添加查询日志
func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string, fromCache, dnssec, edns bool, dnsServer, dnssecServer string, answers []DNSAnswer, responseCode int) {
// 创建日志记录
log := QueryLog{
Timestamp: time.Now(),
ClientIP: clientIP,
Domain: domain,
QueryType: queryType,
ResponseTime: responseTime,
Result: result,
BlockRule: blockRule,
BlockType: blockType,
FromCache: fromCache,
DNSSEC: dnssec,
EDNS: edns,
DNSServer: dnsServer,
DNSSECServer: dnssecServer,
Answers: answers,
ResponseCode: responseCode,
}
// 发送到日志处理通道(非阻塞)
select {
case s.logChannel <- log:
// 日志发送成功
default:
// 通道已满,丢弃日志以避免阻塞请求处理
logger.Warn("日志通道已满,丢弃一条日志记录")
}
}
// GetStartTime 获取服务器启动时间
func (s *Server) GetStartTime() time.Time {
return s.startTime
}
// GetStats 获取DNS服务器统计信息
func (s *Server) GetStats() *Stats {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
// 复制查询类型统计
queryTypesCopy := make(map[string]int64)
for k, v := range s.stats.QueryTypes {
queryTypesCopy[k] = v
}
// 复制来源IP统计
sourceIPsCopy := make(map[string]bool)
for ip := range s.stats.SourceIPs {
sourceIPsCopy[ip] = true
}
// 返回统计信息的副本
return &Stats{
Queries: s.stats.Queries,
Blocked: s.stats.Blocked,
Allowed: s.stats.Allowed,
Errors: s.stats.Errors,
LastQuery: s.stats.LastQuery,
AvgResponseTime: s.stats.AvgResponseTime,
TotalResponseTime: s.stats.TotalResponseTime,
QueryTypes: queryTypesCopy,
SourceIPs: sourceIPsCopy,
CpuUsage: s.stats.CpuUsage,
DNSSECQueries: s.stats.DNSSECQueries,
DNSSECSuccess: s.stats.DNSSECSuccess,
DNSSECFailed: s.stats.DNSSECFailed,
DNSSECEnabled: s.stats.DNSSECEnabled,
}
}
// GetQueryLogs 获取查询日志
func (s *Server) GetQueryLogs(limit, offset int, sortField, sortDirection, resultFilter, searchTerm string) []QueryLog {
s.queryLogsMutex.RLock()
defer s.queryLogsMutex.RUnlock()
// 确保偏移量和限制值合理
if offset < 0 {
offset = 0
}
if limit <= 0 {
limit = 100 // 默认返回100条日志
}
// 创建日志副本用于过滤和排序
var logsCopy []QueryLog
// 先过滤日志
for _, log := range s.queryLogs {
// 应用结果过滤
if resultFilter != "" && log.Result != resultFilter {
continue
}
// 应用搜索过滤
if searchTerm != "" {
// 搜索域名或客户端IP
if !strings.Contains(log.Domain, searchTerm) && !strings.Contains(log.ClientIP, searchTerm) {
continue
}
}
logsCopy = append(logsCopy, log)
}
// 排序日志
if sortField != "" {
sort.Slice(logsCopy, func(i, j int) bool {
var a, b interface{}
switch sortField {
case "time":
a = logsCopy[i].Timestamp
b = logsCopy[j].Timestamp
case "clientIp":
a = logsCopy[i].ClientIP
b = logsCopy[j].ClientIP
case "domain":
a = logsCopy[i].Domain
b = logsCopy[j].Domain
case "responseTime":
a = logsCopy[i].ResponseTime
b = logsCopy[j].ResponseTime
case "blockRule":
a = logsCopy[i].BlockRule
b = logsCopy[j].BlockRule
default:
// 默认按时间排序
a = logsCopy[i].Timestamp
b = logsCopy[j].Timestamp
}
// 根据排序方向比较
if sortDirection == "asc" {
return compareValues(a, b) < 0
}
return compareValues(a, b) > 0
})
}
// 计算返回范围
start := offset
end := offset + limit
if end > len(logsCopy) {
end = len(logsCopy)
}
if start >= len(logsCopy) {
return []QueryLog{} // 没有数据,返回空切片
}
return logsCopy[start:end]
}
// compareValues 比较两个值
func compareValues(a, b interface{}) int {
switch v1 := a.(type) {
case time.Time:
v2 := b.(time.Time)
if v1.Before(v2) {
return -1
}
if v1.After(v2) {
return 1
}
return 0
case string:
v2 := b.(string)
if v1 < v2 {
return -1
}
if v1 > v2 {
return 1
}
return 0
case int64:
v2 := b.(int64)
if v1 < v2 {
return -1
}
if v1 > v2 {
return 1
}
return 0
default:
return 0
}
}
// GetQueryLogsCount 获取查询日志总数
func (s *Server) GetQueryLogsCount() int {
s.queryLogsMutex.RLock()
defer s.queryLogsMutex.RUnlock()
return len(s.queryLogs)
}
// GetQueryStats 获取查询统计信息
func (s *Server) GetQueryStats() map[string]interface{} {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
// 计算统计数据
return map[string]interface{}{
"totalQueries": s.stats.Queries,
"blockedQueries": s.stats.Blocked,
"allowedQueries": s.stats.Allowed,
"errorQueries": s.stats.Errors,
"avgResponseTime": s.stats.AvgResponseTime,
"activeIPs": len(s.stats.SourceIPs),
}
}
// GetTopBlockedDomains 获取TOP屏蔽域名列表
func (s *Server) GetTopBlockedDomains(limit int) []BlockedDomain {
s.blockedDomainsMutex.RLock()
defer s.blockedDomainsMutex.RUnlock()
// 计算30天前的时间戳
thirtyDaysAgo := time.Now().Add(-30 * 24 * time.Hour).Unix()
// 转换为切片并过滤最近30天的数据
domains := make([]BlockedDomain, 0, len(s.blockedDomains))
for _, entry := range s.blockedDomains {
// 只包含最近30天的数据
if entry.LastSeen >= thirtyDaysAgo {
domains = append(domains, *entry)
}
}
// 按计数排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].Count > domains[j].Count
})
// 返回限制数量
if len(domains) > limit {
return domains[:limit]
}
return domains
}
// GetTopResolvedDomains 获取TOP解析域名
func (s *Server) GetTopResolvedDomains(limit int) []BlockedDomain {
s.resolvedDomainsMutex.RLock()
defer s.resolvedDomainsMutex.RUnlock()
// 计算30天前的时间戳
thirtyDaysAgo := time.Now().Add(-30 * 24 * time.Hour).Unix()
// 转换为切片并过滤最近30天的数据
domains := make([]BlockedDomain, 0, len(s.resolvedDomains))
for _, entry := range s.resolvedDomains {
// 只包含最近30天的数据
if entry.LastSeen >= thirtyDaysAgo {
domains = append(domains, *entry)
}
}
// 按数量排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].Count > domains[j].Count
})
// 返回限制数量
if len(domains) > limit {
return domains[:limit]
}
return domains
}
// GetRecentBlockedDomains 获取最近屏蔽的域名列表
func (s *Server) GetRecentBlockedDomains(limit int) []BlockedDomain {
s.blockedDomainsMutex.RLock()
defer s.blockedDomainsMutex.RUnlock()
// 转换为切片
domains := make([]BlockedDomain, 0, len(s.blockedDomains))
for _, entry := range s.blockedDomains {
domains = append(domains, *entry)
}
// 按时间排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].LastSeen > domains[j].LastSeen
})
// 返回限制数量
if len(domains) > limit {
return domains[:limit]
}
return domains
}
// GetTopClients 获取TOP客户端列表
func (s *Server) GetTopClients(limit int) []ClientStats {
s.clientStatsMutex.RLock()
defer s.clientStatsMutex.RUnlock()
// 转换为切片
clients := make([]ClientStats, 0, len(s.clientStats))
for _, entry := range s.clientStats {
clients = append(clients, *entry)
}
// 按请求次数排序
sort.Slice(clients, func(i, j int) bool {
return clients[i].Count > clients[j].Count
})
// 返回限制数量
if len(clients) > limit {
return clients[:limit]
}
return clients
}
// GetHourlyStats 获取每小时统计数据
func (s *Server) GetHourlyStats() map[string]int64 {
s.hourlyStatsMutex.RLock()
defer s.hourlyStatsMutex.RUnlock()
// 返回副本
result := make(map[string]int64)
for k, v := range s.hourlyStats {
result[k] = v
}
return result
}
// GetDailyStats 获取每日统计数据
func (s *Server) GetDailyStats() map[string]int64 {
s.dailyStatsMutex.RLock()
defer s.dailyStatsMutex.RUnlock()
// 返回副本
result := make(map[string]int64)
for k, v := range s.dailyStats {
result[k] = v
}
return result
}
// GetMonthlyStats 获取每月统计数据
func (s *Server) GetMonthlyStats() map[string]int64 {
s.monthlyStatsMutex.RLock()
defer s.monthlyStatsMutex.RUnlock()
// 返回副本
result := make(map[string]int64)
for k, v := range s.monthlyStats {
result[k] = v
}
return result
}
// isPrivateIP 检测IP地址是否为内网IP
func isPrivateIP(ip string) bool {
// 解析IP地址
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return false
}
// 检查IPv4内网地址
if ipv4 := parsedIP.To4(); ipv4 != nil {
// 10.0.0.0/8
if ipv4[0] == 10 {
return true
}
// 172.16.0.0/12
if ipv4[0] == 172 && (ipv4[1] >= 16 && ipv4[1] <= 31) {
return true
}
// 192.168.0.0/16
if ipv4[0] == 192 && ipv4[1] == 168 {
return true
}
// 127.0.0.0/8 (localhost)
if ipv4[0] == 127 {
return true
}
// 169.254.0.0/16 (链路本地地址)
if ipv4[0] == 169 && ipv4[1] == 254 {
return true
}
return false
}
// 检查IPv6内网地址
// ::1/128 (localhost)
if parsedIP.IsLoopback() {
return true
}
// fc00::/7 (唯一本地地址)
if parsedIP[0]&0xfc == 0xfc {
return true
}
// fe80::/10 (链路本地地址)
if parsedIP[0]&0xfe == 0xfe && parsedIP[1]&0xc0 == 0x80 {
return true
}
return false
}
// loadStatsData 从文件加载统计数据
func (s *Server) loadStatsData() {
// 检查文件是否存在
data, err := ioutil.ReadFile("data/stats.json")
if err != nil {
if !os.IsNotExist(err) {
logger.Error("读取统计数据文件失败", "error", err)
}
return
}
var statsData StatsData
err = json.Unmarshal(data, &statsData)
if err != nil {
logger.Error("解析统计数据失败", "error", err)
return
}
// 恢复统计数据
s.statsMutex.Lock()
if statsData.Stats != nil {
// 只恢复有效数据,避免破坏统计关系
s.stats.Queries += statsData.Stats.Queries
s.stats.Blocked += statsData.Stats.Blocked
s.stats.Allowed += statsData.Stats.Allowed
s.stats.Errors += statsData.Stats.Errors
s.stats.TotalResponseTime += statsData.Stats.TotalResponseTime
s.stats.DNSSECQueries += statsData.Stats.DNSSECQueries
s.stats.DNSSECSuccess += statsData.Stats.DNSSECSuccess
s.stats.DNSSECFailed += statsData.Stats.DNSSECFailed
// 重新计算平均响应时间,确保一致性
if s.stats.Queries > 0 {
s.stats.AvgResponseTime = float64(s.stats.TotalResponseTime) / float64(s.stats.Queries)
// 限制平均响应时间的范围,避免显示异常大的值
if s.stats.AvgResponseTime > 60000 {
s.stats.AvgResponseTime = 60000
}
}
// 合并查询类型统计
for k, v := range statsData.Stats.QueryTypes {
s.stats.QueryTypes[k] += v
}
// 合并来源IP统计
for ip := range statsData.Stats.SourceIPs {
s.stats.SourceIPs[ip] = true
}
// 确保使用当前配置中的EnableDNSSEC值
s.stats.DNSSECEnabled = s.config.EnableDNSSEC
}
s.statsMutex.Unlock()
s.blockedDomainsMutex.Lock()
if statsData.BlockedDomains != nil {
s.blockedDomains = statsData.BlockedDomains
}
s.blockedDomainsMutex.Unlock()
s.resolvedDomainsMutex.Lock()
if statsData.ResolvedDomains != nil {
s.resolvedDomains = statsData.ResolvedDomains
}
s.resolvedDomainsMutex.Unlock()
s.hourlyStatsMutex.Lock()
if statsData.HourlyStats != nil {
s.hourlyStats = statsData.HourlyStats
}
s.hourlyStatsMutex.Unlock()
s.dailyStatsMutex.Lock()
if statsData.DailyStats != nil {
s.dailyStats = statsData.DailyStats
}
s.dailyStatsMutex.Unlock()
s.monthlyStatsMutex.Lock()
if statsData.MonthlyStats != nil {
s.monthlyStats = statsData.MonthlyStats
}
s.monthlyStatsMutex.Unlock()
// 加载客户端统计数据
s.clientStatsMutex.Lock()
if statsData.ClientStats != nil {
s.clientStats = statsData.ClientStats
}
s.clientStatsMutex.Unlock()
logger.Info("统计数据加载成功")
// 加载查询日志
s.loadQueryLogs()
}
// loadQueryLogs 从文件加载查询日志
func (s *Server) loadQueryLogs() {
// 获取绝对路径
statsFilePath, err := filepath.Abs("data/stats.json")
if err != nil {
logger.Error("获取统计文件绝对路径失败", "path", "data/stats.json", "error", err)
return
}
// 构建查询日志文件路径
queryLogPath := filepath.Join(filepath.Dir(statsFilePath), "querylog.json")
// 检查文件是否存在
if _, err := os.Stat(queryLogPath); os.IsNotExist(err) {
logger.Info("查询日志文件不存在,将使用空列表", "file", queryLogPath)
return
}
// 读取文件内容
data, err := ioutil.ReadFile(queryLogPath)
if err != nil {
logger.Error("读取查询日志文件失败", "error", err)
return
}
// 解析数据
var logs []QueryLog
err = json.Unmarshal(data, &logs)
if err != nil {
logger.Error("解析查询日志失败", "error", err)
return
}
// 更新查询日志
s.queryLogsMutex.Lock()
s.queryLogs = logs
// 确保日志数量不超过限制
if len(s.queryLogs) > s.maxQueryLogs {
s.queryLogs = s.queryLogs[:s.maxQueryLogs]
}
s.queryLogsMutex.Unlock()
logger.Info("查询日志加载成功", "count", len(logs))
}
// processLogs 异步处理日志记录
func (s *Server) processLogs() {
for {
select {
case logEntry, ok := <-s.logChannel:
if !ok {
// 通道关闭,退出循环
return
}
// 加锁保护queryLogs
s.queryLogsMutex.Lock()
// 如果日志数量超过最大限制,删除最旧的日志
if len(s.queryLogs) >= s.maxQueryLogs {
// 使用切片操作保留最新的日志,避免复制整个切片
// 保留最新的s.maxQueryLogs-1条日志然后添加新日志
s.queryLogs = s.queryLogs[len(s.queryLogs)-s.maxQueryLogs+1:]
}
// 直接添加新日志
s.queryLogs = append(s.queryLogs, logEntry)
// 解锁
s.queryLogsMutex.Unlock()
case <-s.ctx.Done():
// 上下文取消,退出循环
return
}
}
}
// saveStatsData 保存统计数据到文件
func (s *Server) saveStatsData() {
// 获取绝对路径以避免工作目录问题
statsFilePath, err := filepath.Abs("data/stats.json")
if err != nil {
logger.Error("获取统计文件绝对路径失败", "path", "data/stats.json", "error", err)
return
}
// 创建数据目录
statsDir := filepath.Dir(statsFilePath)
err = os.MkdirAll(statsDir, 0755)
if err != nil {
logger.Error("创建统计数据目录失败", "dir", statsDir, "error", err)
return
}
// 收集所有统计数据
statsData := &StatsData{
Stats: s.GetStats(),
LastSaved: time.Now(),
}
// 复制域名数据
s.blockedDomainsMutex.RLock()
statsData.BlockedDomains = make(map[string]*BlockedDomain)
for k, v := range s.blockedDomains {
statsData.BlockedDomains[k] = v
}
s.blockedDomainsMutex.RUnlock()
s.resolvedDomainsMutex.RLock()
statsData.ResolvedDomains = make(map[string]*BlockedDomain)
for k, v := range s.resolvedDomains {
statsData.ResolvedDomains[k] = v
}
s.resolvedDomainsMutex.RUnlock()
s.hourlyStatsMutex.RLock()
statsData.HourlyStats = make(map[string]int64)
for k, v := range s.hourlyStats {
statsData.HourlyStats[k] = v
}
s.hourlyStatsMutex.RUnlock()
s.dailyStatsMutex.RLock()
statsData.DailyStats = make(map[string]int64)
for k, v := range s.dailyStats {
statsData.DailyStats[k] = v
}
s.dailyStatsMutex.RUnlock()
s.monthlyStatsMutex.RLock()
statsData.MonthlyStats = make(map[string]int64)
for k, v := range s.monthlyStats {
statsData.MonthlyStats[k] = v
}
s.monthlyStatsMutex.RUnlock()
// 复制客户端统计数据
s.clientStatsMutex.RLock()
statsData.ClientStats = make(map[string]*ClientStats)
for k, v := range s.clientStats {
statsData.ClientStats[k] = v
}
s.clientStatsMutex.RUnlock()
// 序列化数据
jsonData, err := json.MarshalIndent(statsData, "", " ")
if err != nil {
logger.Error("序列化统计数据失败", "error", err)
return
}
// 写入文件
err = os.WriteFile(statsFilePath, jsonData, 0644)
if err != nil {
logger.Error("保存统计数据到文件失败", "file", statsFilePath, "error", err)
return
}
logger.Info("统计数据保存成功", "file", statsFilePath)
// 保存查询日志到文件
s.saveQueryLogs(statsDir)
}
// saveQueryLogs 保存查询日志到文件
func (s *Server) saveQueryLogs(dataDir string) {
// 构建查询日志文件路径
queryLogPath := filepath.Join(dataDir, "querylog.json")
// 获取查询日志数据
s.queryLogsMutex.RLock()
logsCopy := make([]QueryLog, len(s.queryLogs))
copy(logsCopy, s.queryLogs)
s.queryLogsMutex.RUnlock()
// 序列化数据
jsonData, err := json.MarshalIndent(logsCopy, "", " ")
if err != nil {
logger.Error("序列化查询日志失败", "error", err)
return
}
// 写入文件
err = os.WriteFile(queryLogPath, jsonData, 0644)
if err != nil {
logger.Error("保存查询日志到文件失败", "file", queryLogPath, "error", err)
return
}
logger.Info("查询日志保存成功", "file", queryLogPath)
}
// startCpuUsageMonitor 启动CPU使用率监控
func (s *Server) startCpuUsageMonitor() {
ticker := time.NewTicker(time.Second * 5) // 每5秒更新一次CPU使用率
defer ticker.Stop()
// 初始化
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
// 存储上一次的CPU时间统计
var prevIdle, prevTotal uint64
for {
select {
case <-ticker.C:
// 获取真实的系统级CPU使用率
cpuUsage, err := getSystemCpuUsage(&prevIdle, &prevTotal)
if err != nil {
// 如果获取失败,使用默认值
cpuUsage = 0.0
logger.Error("获取系统CPU使用率失败", "error", err)
}
s.updateStats(func(stats *Stats) {
stats.CpuUsage = cpuUsage
})
case <-s.ctx.Done():
return
}
}
}
// getSystemCpuUsage 获取系统CPU使用率
func getSystemCpuUsage(prevIdle, prevTotal *uint64) (float64, error) {
// 读取/proc/stat文件获取CPU统计信息
file, err := os.Open("/proc/stat")
if err != nil {
return 0, err
}
defer file.Close()
var cpuUser, cpuNice, cpuSystem, cpuIdle, cpuIowait, cpuIrq, cpuSoftirq, cpuSteal uint64
_, err = fmt.Fscanf(file, "cpu %d %d %d %d %d %d %d %d",
&cpuUser, &cpuNice, &cpuSystem, &cpuIdle, &cpuIowait, &cpuIrq, &cpuSoftirq, &cpuSteal)
if err != nil {
return 0, err
}
// 计算总的CPU时间
total := cpuUser + cpuNice + cpuSystem + cpuIdle + cpuIowait + cpuIrq + cpuSoftirq + cpuSteal
idle := cpuIdle + cpuIowait
// 第一次调用时,只初始化值,不计算使用率
if *prevTotal == 0 || *prevIdle == 0 {
*prevIdle = idle
*prevTotal = total
return 0, nil
}
// 计算CPU使用率
idleDelta := idle - *prevIdle
totalDelta := total - *prevTotal
utilization := float64(totalDelta-idleDelta) / float64(totalDelta) * 100
// 更新上一次的值
*prevIdle = idle
*prevTotal = total
return utilization, nil
}
// startAutoSave 启动自动保存功能
func (s *Server) startAutoSave() {
if s.config.SaveInterval <= 0 {
return
}
// 初始化定时器
s.saveTicker = time.NewTicker(time.Duration(s.config.SaveInterval) * time.Second)
defer s.saveTicker.Stop()
logger.Info("启动统计数据自动保存功能", "interval", s.config.SaveInterval, "file", "data/stats.json")
// 定期保存数据
for {
select {
case <-s.saveTicker.C:
s.saveStatsData()
case <-s.saveDone:
return
}
}
}