3214 lines
91 KiB
Go
3214 lines
91 KiB
Go
package dns
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io/ioutil"
|
||
"math"
|
||
"math/rand"
|
||
"net"
|
||
"os"
|
||
"path/filepath"
|
||
"runtime"
|
||
"sort"
|
||
"strings"
|
||
"sync"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"dns-server/config"
|
||
"dns-server/gfw"
|
||
"dns-server/logger"
|
||
"dns-server/shield"
|
||
|
||
"github.com/miekg/dns"
|
||
)
|
||
|
||
// 确保DNS服务器地址包含端口号,默认添加53端口
|
||
func normalizeDNSServerAddress(address string) string {
|
||
// 检查地址是否已经包含端口号
|
||
if _, _, err := net.SplitHostPort(address); err != nil {
|
||
// 如果没有端口号,添加默认的53端口
|
||
return net.JoinHostPort(address, "53")
|
||
}
|
||
// 已经有端口号,直接返回
|
||
return address
|
||
}
|
||
|
||
// BlockedDomain 屏蔽域名统计
|
||
type BlockedDomain struct {
|
||
Domain string
|
||
Count int64
|
||
LastSeen int64
|
||
DNSSEC bool // 是否使用了DNSSEC
|
||
}
|
||
|
||
// ClientStats 客户端统计
|
||
|
||
type ClientStats struct {
|
||
IP string
|
||
Count int64
|
||
LastSeen int64
|
||
}
|
||
|
||
// DNSAnswer DNS解析记录
|
||
type DNSAnswer struct {
|
||
Type string `json:"type"` // 记录类型
|
||
Value string `json:"value"` // 记录值
|
||
TTL uint32 `json:"ttl"` // 生存时间
|
||
}
|
||
|
||
// QueryLog 查询日志记录
|
||
type QueryLog struct {
|
||
Timestamp time.Time `json:"timestamp"` // 查询时间
|
||
ClientIP string `json:"clientIP"` // 客户端IP
|
||
Location string `json:"location"` // IP地理位置(国家 城市)
|
||
Domain string `json:"domain"` // 查询域名
|
||
QueryType string `json:"queryType"` // 查询类型
|
||
ResponseTime int64 `json:"responseTime"` // 响应时间(ms)
|
||
Result string `json:"result"` // 查询结果(allowed, blocked, error)
|
||
BlockRule string `json:"blockRule"` // 屏蔽规则(如果被屏蔽)
|
||
BlockType string `json:"blockType"` // 屏蔽类型(如果被屏蔽)
|
||
FromCache bool `json:"fromCache"` // 是否来自缓存
|
||
DNSSEC bool `json:"dnssec"` // 是否使用了DNSSEC
|
||
EDNS bool `json:"edns"` // 是否使用了EDNS
|
||
DNSServer string `json:"dnsServer"` // 使用的DNS服务器
|
||
DNSSECServer string `json:"dnssecServer"` // 使用的DNSSEC专用服务器
|
||
Answers []DNSAnswer `json:"answers"` // 解析记录
|
||
ResponseCode int `json:"responseCode"` // DNS响应代码
|
||
}
|
||
|
||
// StatsData 用于持久化的统计数据结构
|
||
type StatsData struct {
|
||
Stats *Stats `json:"stats"`
|
||
BlockedDomains map[string]*BlockedDomain `json:"blockedDomains"`
|
||
ResolvedDomains map[string]*BlockedDomain `json:"resolvedDomains"`
|
||
ClientStats map[string]*ClientStats `json:"clientStats"`
|
||
HourlyStats map[string]int64 `json:"hourlyStats"`
|
||
DailyStats map[string]int64 `json:"dailyStats"`
|
||
MonthlyStats map[string]int64 `json:"monthlyStats"`
|
||
LastSaved time.Time `json:"lastSaved"`
|
||
}
|
||
|
||
// ServerStats 服务器统计信息
|
||
type ServerStats struct {
|
||
SuccessCount int64 // 成功查询次数
|
||
FailureCount int64 // 失败查询次数
|
||
LastResponse time.Time // 最后响应时间
|
||
ResponseTime time.Duration // 平均响应时间
|
||
ConnectionSpeed time.Duration // TCP连接速度
|
||
}
|
||
|
||
// Server DNS服务器
|
||
type Server struct {
|
||
config *config.DNSConfig
|
||
shieldConfig *config.ShieldConfig
|
||
shieldManager *shield.ShieldManager
|
||
gfwConfig *config.GFWListConfig
|
||
gfwManager *gfw.GFWListManager
|
||
server *dns.Server
|
||
tcpServer *dns.Server
|
||
resolver *dns.Client
|
||
ctx context.Context
|
||
cancel context.CancelFunc
|
||
statsMutex sync.Mutex
|
||
stats *Stats
|
||
blockedDomainsMutex sync.RWMutex
|
||
blockedDomains map[string]*BlockedDomain
|
||
resolvedDomainsMutex sync.RWMutex
|
||
resolvedDomains map[string]*BlockedDomain // 用于记录解析的域名
|
||
clientStatsMutex sync.RWMutex
|
||
clientStats map[string]*ClientStats // 用于记录客户端统计
|
||
hourlyStatsMutex sync.RWMutex
|
||
hourlyStats map[string]int64 // 按小时统计屏蔽数量
|
||
dailyStatsMutex sync.RWMutex
|
||
dailyStats map[string]int64 // 按天统计屏蔽数量
|
||
monthlyStatsMutex sync.RWMutex
|
||
monthlyStats map[string]int64 // 按月统计屏蔽数量
|
||
queryLogsMutex sync.RWMutex
|
||
queryLogs []QueryLog // 查询日志列表
|
||
maxQueryLogs int // 最大保存日志数量
|
||
logChannel chan QueryLog // 日志处理通道
|
||
saveTicker *time.Ticker // 用于定时保存数据
|
||
startTime time.Time // 服务器启动时间
|
||
saveDone chan struct{} // 用于通知保存协程停止
|
||
stopped bool // 服务器是否已经停止
|
||
stoppedMutex sync.Mutex // 保护stopped标志的互斥锁
|
||
|
||
// DNS查询缓存
|
||
DnsCache *DNSCache // DNS响应缓存
|
||
|
||
// 域名DNSSEC状态映射表
|
||
domainDNSSECStatus map[string]bool // 域名到DNSSEC状态的映射
|
||
domainDNSSECStatusMutex sync.RWMutex // 保护域名DNSSEC状态映射的互斥锁
|
||
|
||
// 上游服务器状态跟踪
|
||
serverStats map[string]*ServerStats // 服务器地址到状态的映射
|
||
serverStatsMutex sync.RWMutex // 保护服务器状态的互斥锁
|
||
|
||
// DNSSEC专用服务器映射,用于快速查找
|
||
dnssecServerMap map[string]bool // DNSSEC专用服务器地址到布尔值的映射
|
||
|
||
// DNS客户端实例池,用于并行查询
|
||
clientPool sync.Pool // 存储*dns.Client实例
|
||
}
|
||
|
||
// Stats DNS服务器统计信息
|
||
type Stats struct {
|
||
Queries int64
|
||
Blocked int64
|
||
Allowed int64
|
||
Errors int64
|
||
LastQuery time.Time
|
||
AvgResponseTime float64 // 平均响应时间(ms)
|
||
TotalResponseTime int64 // 总响应时间
|
||
QueryTypes map[string]int64 // 查询类型统计
|
||
SourceIPs map[string]bool // 活跃来源IP
|
||
CpuUsage float64 // CPU使用率(%)
|
||
DNSSECQueries int64 // DNSSEC查询总数
|
||
DNSSECSuccess int64 // DNSSEC验证成功数
|
||
DNSSECFailed int64 // DNSSEC验证失败数
|
||
DNSSECEnabled bool // 是否启用了DNSSEC
|
||
}
|
||
|
||
// NewServer 创建DNS服务器实例
|
||
func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager, gfwConfig *config.GFWListConfig, gfwManager *gfw.GFWListManager) *Server {
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
|
||
// 从配置中读取DNS缓存TTL值(分钟)
|
||
cacheTTL := time.Duration(config.CacheTTL) * time.Minute
|
||
|
||
server := &Server{
|
||
config: config,
|
||
shieldConfig: shieldConfig,
|
||
shieldManager: shieldManager,
|
||
gfwConfig: gfwConfig,
|
||
gfwManager: gfwManager,
|
||
resolver: &dns.Client{
|
||
Net: "udp",
|
||
UDPSize: 4096, // 增加UDP缓冲区大小,支持更大的DNSSEC响应
|
||
},
|
||
ctx: ctx,
|
||
cancel: cancel,
|
||
startTime: time.Now(), // 记录服务器启动时间
|
||
stats: &Stats{
|
||
Queries: 0,
|
||
Blocked: 0,
|
||
Allowed: 0,
|
||
Errors: 0,
|
||
AvgResponseTime: 0,
|
||
TotalResponseTime: 0,
|
||
QueryTypes: make(map[string]int64),
|
||
SourceIPs: make(map[string]bool),
|
||
CpuUsage: 0,
|
||
DNSSECQueries: 0,
|
||
DNSSECSuccess: 0,
|
||
DNSSECFailed: 0,
|
||
DNSSECEnabled: config.EnableDNSSEC,
|
||
},
|
||
blockedDomains: make(map[string]*BlockedDomain),
|
||
resolvedDomains: make(map[string]*BlockedDomain),
|
||
clientStats: make(map[string]*ClientStats),
|
||
hourlyStats: make(map[string]int64),
|
||
dailyStats: make(map[string]int64),
|
||
monthlyStats: make(map[string]int64),
|
||
queryLogs: make([]QueryLog, 0, 1000), // 初始化查询日志切片,容量1000
|
||
maxQueryLogs: 10000, // 最大保存10000条日志
|
||
logChannel: make(chan QueryLog, 1000), // 日志处理通道,缓冲区大小1000
|
||
saveDone: make(chan struct{}),
|
||
stopped: false, // 初始化为未停止状态
|
||
|
||
// DNS查询缓存初始化
|
||
DnsCache: NewDNSCache(cacheTTL),
|
||
// 初始化域名DNSSEC状态映射表
|
||
domainDNSSECStatus: make(map[string]bool),
|
||
// 初始化服务器状态跟踪
|
||
serverStats: make(map[string]*ServerStats),
|
||
// 初始化DNSSEC专用服务器映射
|
||
dnssecServerMap: make(map[string]bool),
|
||
// 初始化DNS客户端实例池
|
||
clientPool: sync.Pool{
|
||
New: func() interface{} {
|
||
return &dns.Client{
|
||
Net: "udp",
|
||
UDPSize: 4096,
|
||
Timeout: 5 * time.Second, // 默认超时时间,会在使用时覆盖
|
||
}
|
||
},
|
||
},
|
||
}
|
||
|
||
// 加载已保存的统计数据
|
||
server.loadStatsData()
|
||
|
||
return server
|
||
|
||
}
|
||
|
||
// Start 启动DNS服务器
|
||
func (s *Server) Start() error {
|
||
// 重新初始化上下文和取消函数
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
s.ctx = ctx
|
||
s.cancel = cancel
|
||
|
||
// 重新初始化saveDone通道
|
||
s.saveDone = make(chan struct{})
|
||
|
||
// 重置stopped标志
|
||
s.stoppedMutex.Lock()
|
||
s.stopped = false
|
||
s.stoppedMutex.Unlock()
|
||
|
||
// 更新服务器启动时间
|
||
s.startTime = time.Now()
|
||
|
||
s.server = &dns.Server{
|
||
Addr: fmt.Sprintf("0.0.0.0:%d", s.config.Port),
|
||
Net: "udp",
|
||
Handler: dns.HandlerFunc(s.handleDNSRequest),
|
||
}
|
||
|
||
// 保存TCP服务器实例,以便在Stop方法中关闭
|
||
s.tcpServer = &dns.Server{
|
||
Addr: fmt.Sprintf("0.0.0.0:%d", s.config.Port),
|
||
Net: "tcp",
|
||
Handler: dns.HandlerFunc(s.handleDNSRequest),
|
||
}
|
||
|
||
// 启动CPU使用率监控
|
||
go s.startCpuUsageMonitor()
|
||
|
||
// 启动自动保存功能
|
||
go s.startAutoSave()
|
||
|
||
// 更新DNSSEC专用服务器映射
|
||
s.updateDNSSECServerMap()
|
||
|
||
// 启动日志处理协程
|
||
go s.processLogs()
|
||
|
||
// 启动统计数据定期重置功能(每24小时)
|
||
go func() {
|
||
ticker := time.NewTicker(24 * time.Hour)
|
||
defer ticker.Stop()
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
s.resetStats()
|
||
case <-s.ctx.Done():
|
||
return
|
||
}
|
||
}
|
||
}()
|
||
|
||
// 启动UDP服务
|
||
go func() {
|
||
logger.Info(fmt.Sprintf("DNS UDP服务器启动,监听端口: %d", s.config.Port))
|
||
if err := s.server.ListenAndServe(); err != nil {
|
||
logger.Error("DNS UDP服务器启动失败", "error", err)
|
||
s.cancel()
|
||
}
|
||
}()
|
||
|
||
// 启动TCP服务
|
||
go func() {
|
||
logger.Info(fmt.Sprintf("DNS TCP服务器启动,监听端口: %d", s.config.Port))
|
||
if err := s.tcpServer.ListenAndServe(); err != nil {
|
||
logger.Error("DNS TCP服务器启动失败", "error", err)
|
||
s.cancel()
|
||
}
|
||
}()
|
||
|
||
// 等待停止信号
|
||
<-s.ctx.Done()
|
||
return nil
|
||
}
|
||
|
||
// resetStats 重置统计数据
|
||
func (s *Server) resetStats() {
|
||
s.statsMutex.Lock()
|
||
defer s.statsMutex.Unlock()
|
||
|
||
// 只重置累计值,保留配置相关值
|
||
s.stats.TotalResponseTime = 0
|
||
s.stats.AvgResponseTime = 0
|
||
s.stats.Queries = 0
|
||
s.stats.Blocked = 0
|
||
s.stats.Allowed = 0
|
||
s.stats.Errors = 0
|
||
s.stats.DNSSECQueries = 0
|
||
s.stats.DNSSECSuccess = 0
|
||
s.stats.DNSSECFailed = 0
|
||
s.stats.QueryTypes = make(map[string]int64)
|
||
s.stats.SourceIPs = make(map[string]bool)
|
||
|
||
logger.Info("统计数据已重置")
|
||
}
|
||
|
||
// Stop 停止DNS服务器
|
||
func (s *Server) Stop() {
|
||
// 检查服务器是否已经停止
|
||
s.stoppedMutex.Lock()
|
||
if s.stopped {
|
||
s.stoppedMutex.Unlock()
|
||
return // 服务器已经停止,直接返回
|
||
}
|
||
// 标记服务器为已停止状态
|
||
s.stopped = true
|
||
s.stoppedMutex.Unlock()
|
||
|
||
// 发送停止信号给保存协程
|
||
close(s.saveDone)
|
||
|
||
// 最后保存一次数据
|
||
s.saveStatsData()
|
||
|
||
// 停止服务器
|
||
s.cancel()
|
||
if s.server != nil {
|
||
s.server.Shutdown()
|
||
}
|
||
if s.tcpServer != nil {
|
||
s.tcpServer.Shutdown()
|
||
}
|
||
logger.Info("DNS服务器已停止")
|
||
}
|
||
|
||
// handleDNSRequest 处理DNS请求
|
||
func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||
startTime := time.Now()
|
||
|
||
// 获取来源IP
|
||
sourceIP := w.RemoteAddr().String()
|
||
// 提取IP地址部分,去掉端口
|
||
if strings.HasPrefix(sourceIP, "[") {
|
||
// IPv6地址格式: [::1]:53
|
||
if idx := strings.Index(sourceIP, "]"); idx >= 0 {
|
||
sourceIP = sourceIP[1:idx] // 去掉方括号
|
||
}
|
||
} else {
|
||
// IPv4地址格式: 127.0.0.1:53
|
||
if idx := strings.LastIndex(sourceIP, ":"); idx >= 0 {
|
||
sourceIP = sourceIP[:idx]
|
||
}
|
||
}
|
||
|
||
// 更新来源IP统计和Queries计数器
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.Queries++
|
||
stats.LastQuery = time.Now()
|
||
stats.SourceIPs[sourceIP] = true
|
||
})
|
||
|
||
// 更新客户端统计
|
||
s.updateClientStats(sourceIP)
|
||
|
||
// 获取查询域名和类型
|
||
var domain string
|
||
var queryType string
|
||
var qType uint16
|
||
if len(r.Question) > 0 {
|
||
domain = r.Question[0].Name
|
||
// 移除末尾的点
|
||
if len(domain) > 0 && domain[len(domain)-1] == '.' {
|
||
domain = domain[:len(domain)-1]
|
||
}
|
||
// 获取查询类型
|
||
queryType = dns.TypeToString[r.Question[0].Qtype]
|
||
qType = r.Question[0].Qtype
|
||
// 更新查询类型统计
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.QueryTypes[queryType]++
|
||
})
|
||
|
||
// 检查是否是AAAA记录查询且IPv6解析已禁用
|
||
if qType == dns.TypeAAAA && !s.config.EnableIPv6 {
|
||
// 返回NXDOMAIN响应(域名不存在)
|
||
response := new(dns.Msg)
|
||
response.SetReply(r)
|
||
response.SetRcode(r, dns.RcodeNameError)
|
||
w.WriteMsg(response)
|
||
|
||
// 更新统计信息
|
||
responseTime := time.Since(startTime).Milliseconds()
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.TotalResponseTime += responseTime
|
||
// 添加防御性编程,确保Queries大于0
|
||
if stats.Queries > 0 {
|
||
// 平均响应时间 = 总响应时间 / 总解析数量,四舍五入取整
|
||
avg := float64(stats.TotalResponseTime) / float64(stats.Queries)
|
||
stats.AvgResponseTime = float64(math.Round(avg))
|
||
// 限制平均响应时间的范围,避免显示异常大的值
|
||
if stats.AvgResponseTime > 60000 {
|
||
stats.AvgResponseTime = 60000
|
||
}
|
||
}
|
||
})
|
||
|
||
// 添加查询日志
|
||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "", false, false, true, "", "", nil, dns.RcodeNameError)
|
||
logger.Debug("IPv6解析已禁用,拒绝AAAA记录查询", "domain", domain)
|
||
return
|
||
}
|
||
}
|
||
|
||
logger.Debug("接收到DNS查询", "domain", domain, "type", queryType, "client", w.RemoteAddr())
|
||
|
||
// 只处理递归查询
|
||
if r.RecursionDesired == false {
|
||
response := new(dns.Msg)
|
||
response.SetReply(r)
|
||
// 不再硬编码RecursionAvailable,使用默认值或上游返回的值
|
||
response.SetRcode(r, dns.RcodeRefused)
|
||
w.WriteMsg(response)
|
||
|
||
// 计算实际响应时间
|
||
responseTime := time.Since(startTime).Milliseconds()
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.TotalResponseTime += responseTime
|
||
if stats.Queries > 0 {
|
||
// 平均响应时间 = 总响应时间 / 总解析数量,四舍五入取整
|
||
avg := float64(stats.TotalResponseTime) / float64(stats.Queries)
|
||
stats.AvgResponseTime = float64(math.Round(avg))
|
||
}
|
||
})
|
||
|
||
// 添加查询日志
|
||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "", false, false, true, "", "", nil, dns.RcodeRefused)
|
||
return
|
||
}
|
||
|
||
// 检查hosts文件是否有匹配
|
||
if ip, exists := s.shieldManager.GetHostsIP(domain); exists {
|
||
s.handleHostsResponse(w, r, ip)
|
||
// 计算实际响应时间
|
||
responseTime := time.Since(startTime).Milliseconds()
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.TotalResponseTime += responseTime
|
||
if stats.Queries > 0 {
|
||
// 平均响应时间 = 总响应时间 / 总解析数量,四舍五入取整
|
||
avg := float64(stats.TotalResponseTime) / float64(stats.Queries)
|
||
stats.AvgResponseTime = float64(math.Round(avg))
|
||
}
|
||
})
|
||
|
||
// 该方法内部未直接调用addQueryLog,而是在handleDNSRequest中处理
|
||
return
|
||
}
|
||
|
||
// 检查是否为GFWList域名(仅当GFWList功能启用时)
|
||
if s.gfwConfig.Enabled && s.gfwManager != nil && s.gfwManager.IsMatch(domain) {
|
||
s.handleGFWListResponse(w, r, domain)
|
||
// 计算响应时间
|
||
responseTime := time.Since(startTime).Milliseconds()
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.TotalResponseTime += responseTime
|
||
if stats.Queries > 0 {
|
||
// 平均响应时间 = 总响应时间 / 总解析数量,四舍五入取整
|
||
avg := float64(stats.TotalResponseTime) / float64(stats.Queries)
|
||
stats.AvgResponseTime = float64(math.Round(avg))
|
||
}
|
||
})
|
||
|
||
// 添加查询日志 - GFWList域名
|
||
gfwAnswers := []DNSAnswer{}
|
||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "gfwlist", "", "", false, false, true, "GFWList", "无", gfwAnswers, dns.RcodeSuccess)
|
||
return
|
||
}
|
||
|
||
// 检查是否被屏蔽
|
||
if s.shieldManager.IsBlocked(domain) {
|
||
// 获取屏蔽详情
|
||
blockDetails := s.shieldManager.CheckDomainBlockDetails(domain)
|
||
blockRule, _ := blockDetails["blockRule"].(string)
|
||
blockType, _ := blockDetails["blockRuleType"].(string)
|
||
|
||
s.handleBlockedResponse(w, r, domain)
|
||
// 计算响应时间
|
||
responseTime := time.Since(startTime).Milliseconds()
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.TotalResponseTime += responseTime
|
||
if stats.Queries > 0 {
|
||
// 平均响应时间 = 总响应时间 / 总解析数量,四舍五入取整
|
||
avg := float64(stats.TotalResponseTime) / float64(stats.Queries)
|
||
stats.AvgResponseTime = float64(math.Round(avg))
|
||
}
|
||
})
|
||
|
||
// 添加查询日志 - 被屏蔽域名
|
||
blockedAnswers := []DNSAnswer{}
|
||
// 根据屏蔽方法确定响应代码
|
||
blockedRcode := dns.RcodeNameError // 默认NXDOMAIN
|
||
if blockMethod := s.shieldConfig.BlockMethod; blockMethod == "refused" {
|
||
blockedRcode = dns.RcodeRefused
|
||
} else if blockMethod == "emptyIP" || blockMethod == "customIP" {
|
||
blockedRcode = dns.RcodeSuccess
|
||
}
|
||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "blocked", blockRule, blockType, false, false, true, "无", "无", blockedAnswers, blockedRcode)
|
||
return
|
||
}
|
||
|
||
// 检查缓存中是否有响应(优先查找带DNSSEC的缓存项)
|
||
var cachedResponse *dns.Msg
|
||
var found bool
|
||
var cachedDNSSEC bool
|
||
|
||
// 1. 首先检查是否有普通缓存项
|
||
if tempResponse, tempFound := s.DnsCache.Get(r.Question[0].Name, qType); tempFound {
|
||
cachedResponse = tempResponse
|
||
found = tempFound
|
||
cachedDNSSEC = s.hasDNSSECRecords(tempResponse)
|
||
}
|
||
|
||
// 2. 如果启用了DNSSEC且没有找到带DNSSEC的缓存项,
|
||
// 尝试从所有缓存中查找是否有其他响应包含DNSSEC记录
|
||
// (这里可以进一步优化,比如在缓存中标记DNSSEC状态,快速查找)
|
||
if s.config.EnableDNSSEC && !cachedDNSSEC {
|
||
// 目前的缓存实现不支持按DNSSEC状态查找,所以这里暂时跳过
|
||
// 后续可以考虑改进缓存实现,添加DNSSEC状态标记
|
||
}
|
||
|
||
if found {
|
||
// 缓存命中,直接返回缓存的响应
|
||
cachedResponseCopy := cachedResponse.Copy() // 创建响应副本避免并发修改问题
|
||
cachedResponseCopy.Id = r.Id // 更新ID以匹配请求
|
||
cachedResponseCopy.Compress = true
|
||
|
||
// 如果客户端请求包含EDNS记录,确保响应也包含EDNS
|
||
if opt := r.IsEdns0(); opt != nil {
|
||
// 检查响应是否已经包含EDNS记录
|
||
if respOpt := cachedResponseCopy.IsEdns0(); respOpt == nil {
|
||
// 添加EDNS记录,使用客户端的UDP缓冲区大小
|
||
cachedResponseCopy.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
|
||
} else {
|
||
// 确保响应的UDP缓冲区大小不超过客户端请求的大小
|
||
if respOpt.UDPSize() > opt.UDPSize() {
|
||
// 移除现有的EDNS记录
|
||
for i := range cachedResponseCopy.Extra {
|
||
if cachedResponseCopy.Extra[i] == respOpt {
|
||
cachedResponseCopy.Extra = append(cachedResponseCopy.Extra[:i], cachedResponseCopy.Extra[i+1:]...)
|
||
break
|
||
}
|
||
}
|
||
// 添加新的EDNS记录,使用客户端的UDP缓冲区大小
|
||
cachedResponseCopy.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
|
||
}
|
||
}
|
||
}
|
||
|
||
w.WriteMsg(cachedResponseCopy)
|
||
|
||
// 计算响应时间
|
||
responseTime := time.Since(startTime).Milliseconds()
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.TotalResponseTime += responseTime
|
||
if stats.Queries > 0 {
|
||
// 平均响应时间 = 总响应时间 / 总解析数量,四舍五入取整
|
||
avg := float64(stats.TotalResponseTime) / float64(stats.Queries)
|
||
stats.AvgResponseTime = float64(math.Round(avg))
|
||
}
|
||
})
|
||
|
||
// 如果缓存响应包含DNSSEC记录,更新DNSSEC查询计数
|
||
if cachedDNSSEC {
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.DNSSECQueries++
|
||
// 缓存响应视为DNSSEC成功
|
||
stats.DNSSECSuccess++
|
||
})
|
||
}
|
||
|
||
// 从缓存响应中提取解析记录
|
||
cachedAnswers := []DNSAnswer{}
|
||
if cachedResponse != nil {
|
||
for _, rr := range cachedResponse.Answer {
|
||
cachedAnswers = append(cachedAnswers, DNSAnswer{
|
||
Type: dns.TypeToString[rr.Header().Rrtype],
|
||
Value: rr.String(),
|
||
TTL: rr.Header().Ttl,
|
||
})
|
||
}
|
||
}
|
||
|
||
// 添加查询日志 - 标记为缓存
|
||
// 从缓存响应中获取响应代码
|
||
cacheRcode := dns.RcodeSuccess // 默认成功
|
||
if cachedResponse != nil {
|
||
cacheRcode = cachedResponse.Rcode
|
||
}
|
||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", true, cachedDNSSEC, true, "缓存", "无", cachedAnswers, cacheRcode)
|
||
logger.Debug("从缓存返回DNS响应", "domain", domain, "type", queryType, "dnssec", cachedDNSSEC)
|
||
return
|
||
}
|
||
|
||
// 缓存未命中,处理DNS请求
|
||
var response *dns.Msg
|
||
var rtt time.Duration
|
||
var queryAttempts []string
|
||
var dnsServer string
|
||
var dnssecServer string
|
||
|
||
// 直接查询原始域名
|
||
queryAttempts = append(queryAttempts, domain)
|
||
response, rtt, dnsServer, dnssecServer = s.forwardDNSRequestWithCache(r, domain)
|
||
|
||
if response != nil {
|
||
// 如果客户端请求包含EDNS记录,确保响应也包含EDNS
|
||
if opt := r.IsEdns0(); opt != nil {
|
||
// 检查响应是否已经包含EDNS记录
|
||
if respOpt := response.IsEdns0(); respOpt == nil {
|
||
// 添加EDNS记录,使用客户端的UDP缓冲区大小
|
||
response.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
|
||
} else {
|
||
// 确保响应的UDP缓冲区大小不超过客户端请求的大小
|
||
if respOpt.UDPSize() > opt.UDPSize() {
|
||
// 移除现有的EDNS记录
|
||
for i := range response.Extra {
|
||
if response.Extra[i] == respOpt {
|
||
response.Extra = append(response.Extra[:i], response.Extra[i+1:]...)
|
||
break
|
||
}
|
||
}
|
||
// 添加新的EDNS记录,使用客户端的UDP缓冲区大小
|
||
response.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 写入响应给客户端
|
||
w.WriteMsg(response)
|
||
}
|
||
|
||
// 使用上游服务器的实际响应时间(转换为毫秒)
|
||
responseTime := int64(rtt.Milliseconds())
|
||
// 如果rtt为0(查询失败),则使用本地计算的时间
|
||
if responseTime == 0 {
|
||
responseTime = time.Since(startTime).Milliseconds()
|
||
}
|
||
|
||
// 添加合理性检查,避免异常大的响应时间影响统计
|
||
if responseTime > 60000 { // 超过60秒的响应时间视为异常
|
||
responseTime = 60000
|
||
}
|
||
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.TotalResponseTime += responseTime
|
||
// 添加防御性编程,确保Queries大于0
|
||
if stats.Queries > 0 {
|
||
// 平均响应时间 = 总响应时间 / 总解析数量,四舍五入取整
|
||
avg := float64(stats.TotalResponseTime) / float64(stats.Queries)
|
||
stats.AvgResponseTime = float64(math.Round(avg))
|
||
// 限制平均响应时间的范围,避免显示异常大的值
|
||
if stats.AvgResponseTime > 60000 {
|
||
stats.AvgResponseTime = 60000
|
||
}
|
||
}
|
||
})
|
||
|
||
// 检查响应是否包含DNSSEC记录并验证结果
|
||
responseDNSSEC := false
|
||
if response != nil {
|
||
// 使用hasDNSSECRecords函数检查是否包含DNSSEC记录
|
||
responseDNSSEC = s.hasDNSSECRecords(response)
|
||
|
||
// 检查AD标志,确认DNSSEC验证是否成功
|
||
if response.AuthenticatedData {
|
||
responseDNSSEC = true
|
||
}
|
||
|
||
// 更新域名的DNSSEC状态
|
||
if responseDNSSEC {
|
||
s.updateDomainDNSSECStatus(domain, true)
|
||
}
|
||
}
|
||
|
||
// 如果响应成功,缓存结果(增强版缓存存储)
|
||
if response != nil && response.Rcode == dns.RcodeSuccess {
|
||
// 创建响应副本以避免后续修改影响缓存
|
||
responseCopy := response.Copy()
|
||
// 设置合理的TTL,不超过默认的30分钟
|
||
defaultCacheTTL := 30 * time.Minute
|
||
s.DnsCache.Set(r.Question[0].Name, qType, responseCopy, defaultCacheTTL)
|
||
logger.Debug("DNS响应已缓存", "domain", domain, "type", queryType, "ttl", defaultCacheTTL, "dnssec", responseDNSSEC)
|
||
}
|
||
|
||
// 从响应中提取解析记录
|
||
responseAnswers := []DNSAnswer{}
|
||
if response != nil {
|
||
for _, rr := range response.Answer {
|
||
responseAnswers = append(responseAnswers, DNSAnswer{
|
||
Type: dns.TypeToString[rr.Header().Rrtype],
|
||
Value: rr.String(),
|
||
TTL: rr.Header().Ttl,
|
||
})
|
||
}
|
||
}
|
||
|
||
// 添加查询日志 - 标记为实时
|
||
// 从响应中获取响应代码
|
||
realRcode := dns.RcodeSuccess // 默认成功
|
||
if response != nil {
|
||
realRcode = response.Rcode
|
||
}
|
||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", false, responseDNSSEC, true, dnsServer, dnssecServer, responseAnswers, realRcode)
|
||
}
|
||
|
||
// handleHostsResponse 处理hosts文件匹配的响应
|
||
func (s *Server) handleHostsResponse(w dns.ResponseWriter, r *dns.Msg, ip string) {
|
||
response := new(dns.Msg)
|
||
response.SetReply(r)
|
||
// 不再硬编码RecursionAvailable,使用默认值或上游返回的值
|
||
|
||
if len(r.Question) > 0 {
|
||
q := r.Question[0]
|
||
answer := new(dns.A)
|
||
answer.Hdr = dns.RR_Header{
|
||
Name: q.Name,
|
||
Rrtype: dns.TypeA,
|
||
Class: dns.ClassINET,
|
||
Ttl: 300,
|
||
}
|
||
answer.A = net.ParseIP(ip)
|
||
response.Answer = append(response.Answer, answer)
|
||
}
|
||
|
||
// 记录解析域名统计
|
||
domain := ""
|
||
if len(r.Question) > 0 {
|
||
domain = r.Question[0].Name
|
||
if len(domain) > 0 && domain[len(domain)-1] == '.' {
|
||
domain = domain[:len(domain)-1]
|
||
}
|
||
s.updateResolvedDomainStats(domain)
|
||
}
|
||
|
||
w.WriteMsg(response)
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.Allowed++
|
||
})
|
||
}
|
||
|
||
// handleGFWListResponse 处理GFWList域名响应
|
||
func (s *Server) handleGFWListResponse(w dns.ResponseWriter, r *dns.Msg, domain string) {
|
||
logger.Info("GFWList域名解析", "domain", domain, "client", w.RemoteAddr(), "ip", s.gfwConfig.IP)
|
||
|
||
// 更新解析域名统计
|
||
s.updateResolvedDomainStats(domain)
|
||
|
||
response := new(dns.Msg)
|
||
response.SetReply(r)
|
||
|
||
if len(r.Question) > 0 {
|
||
q := r.Question[0]
|
||
answer := new(dns.A)
|
||
answer.Hdr = dns.RR_Header{
|
||
Name: q.Name,
|
||
Rrtype: dns.TypeA,
|
||
Class: dns.ClassINET,
|
||
Ttl: 300,
|
||
}
|
||
answer.A = net.ParseIP(s.gfwConfig.IP)
|
||
response.Answer = append(response.Answer, answer)
|
||
}
|
||
|
||
w.WriteMsg(response)
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.Allowed++
|
||
})
|
||
}
|
||
|
||
// handleBlockedResponse 处理被屏蔽的域名响应
|
||
func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain string) {
|
||
logger.Info("域名被屏蔽", "domain", domain, "client", w.RemoteAddr())
|
||
|
||
// 更新被屏蔽域名统计
|
||
s.updateBlockedDomainStats(domain)
|
||
|
||
response := new(dns.Msg)
|
||
response.SetReply(r)
|
||
// 不再硬编码RecursionAvailable,使用默认值或上游返回的值
|
||
|
||
// 获取屏蔽方法配置
|
||
blockMethod := "NXDOMAIN" // 默认值
|
||
customBlockIP := "" // 默认值
|
||
|
||
// 从Server结构体的shieldConfig字段获取配置
|
||
if s.shieldConfig != nil {
|
||
blockMethod = s.shieldConfig.BlockMethod
|
||
customBlockIP = s.shieldConfig.CustomBlockIP
|
||
}
|
||
|
||
// 根据屏蔽方法返回不同的响应
|
||
switch blockMethod {
|
||
case "refused":
|
||
// 返回拒绝查询响应
|
||
response.SetRcode(r, dns.RcodeRefused)
|
||
case "emptyIP":
|
||
// 返回空IP响应
|
||
if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA {
|
||
answer := new(dns.A)
|
||
answer.Hdr = dns.RR_Header{
|
||
Name: r.Question[0].Name,
|
||
Rrtype: dns.TypeA,
|
||
Class: dns.ClassINET,
|
||
Ttl: 300,
|
||
}
|
||
answer.A = net.ParseIP("0.0.0.0") // 空IP
|
||
response.Answer = append(response.Answer, answer)
|
||
}
|
||
case "customIP":
|
||
// 返回自定义IP响应
|
||
if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA {
|
||
answer := new(dns.A)
|
||
answer.Hdr = dns.RR_Header{
|
||
Name: r.Question[0].Name,
|
||
Rrtype: dns.TypeA,
|
||
Class: dns.ClassINET,
|
||
Ttl: 300,
|
||
}
|
||
// 使用自定义屏蔽IP
|
||
if customBlockIP != "" {
|
||
answer.A = net.ParseIP(customBlockIP)
|
||
} else {
|
||
// 如果没有配置,使用0.0.0.0
|
||
answer.A = net.ParseIP("0.0.0.0")
|
||
}
|
||
response.Answer = append(response.Answer, answer)
|
||
}
|
||
case "NXDOMAIN", "":
|
||
fallthrough // 默认使用NXDOMAIN
|
||
default:
|
||
// 返回NXDOMAIN响应(域名不存在)
|
||
response.SetRcode(r, dns.RcodeNameError)
|
||
}
|
||
|
||
w.WriteMsg(response)
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.Blocked++
|
||
})
|
||
}
|
||
|
||
// forwardDNSRequest 转发DNS请求到上游服务器
|
||
// serverResponse 用于存储服务器响应的结构体
|
||
type serverResponse struct {
|
||
response *dns.Msg
|
||
rtt time.Duration
|
||
server string
|
||
error error
|
||
}
|
||
|
||
// recordKey 用于唯一标识DNS记录的结构体
|
||
type recordKey struct {
|
||
name string
|
||
rtype uint16
|
||
class uint16
|
||
data string
|
||
}
|
||
|
||
// getRecordKey 获取DNS记录的唯一标识
|
||
func getRecordKey(rr dns.RR) recordKey {
|
||
// 对于同一域名的同一类型记录,只保留一个,选择最长TTL
|
||
// 所以对于A、AAAA、CNAME等记录,只使用name、rtype、class作为键
|
||
// 对于MX记录,还需要考虑Preference字段
|
||
// 对于TXT记录,需要考虑实际文本内容
|
||
// 对于NS记录,需要考虑目标服务器
|
||
|
||
switch rr.Header().Rrtype {
|
||
case dns.TypeA, dns.TypeAAAA, dns.TypeCNAME, dns.TypePTR:
|
||
// 对于A、AAAA、CNAME、PTR记录,同一域名只保留一个
|
||
return recordKey{
|
||
name: rr.Header().Name,
|
||
rtype: rr.Header().Rrtype,
|
||
class: rr.Header().Class,
|
||
data: "",
|
||
}
|
||
case dns.TypeMX:
|
||
// 对于MX记录,同一域名的同一Preference只保留一个
|
||
if mx, ok := rr.(*dns.MX); ok {
|
||
return recordKey{
|
||
name: rr.Header().Name,
|
||
rtype: rr.Header().Rrtype,
|
||
class: rr.Header().Class,
|
||
data: fmt.Sprintf("%d", mx.Preference),
|
||
}
|
||
}
|
||
case dns.TypeTXT:
|
||
// 对于TXT记录,需要考虑实际文本内容
|
||
if txt, ok := rr.(*dns.TXT); ok {
|
||
return recordKey{
|
||
name: rr.Header().Name,
|
||
rtype: rr.Header().Rrtype,
|
||
class: rr.Header().Class,
|
||
data: strings.Join(txt.Txt, " "),
|
||
}
|
||
}
|
||
case dns.TypeNS:
|
||
// 对于NS记录,需要考虑目标服务器
|
||
if ns, ok := rr.(*dns.NS); ok {
|
||
return recordKey{
|
||
name: rr.Header().Name,
|
||
rtype: rr.Header().Rrtype,
|
||
class: rr.Header().Class,
|
||
data: ns.Ns,
|
||
}
|
||
}
|
||
case dns.TypeSOA:
|
||
// 对于SOA记录,同一域名只保留一个
|
||
return recordKey{
|
||
name: rr.Header().Name,
|
||
rtype: rr.Header().Rrtype,
|
||
class: rr.Header().Class,
|
||
data: "",
|
||
}
|
||
}
|
||
|
||
// 对于其他类型,使用原始rr.String(),但移除TTL部分
|
||
parts := strings.Split(rr.String(), " ")
|
||
if len(parts) >= 5 {
|
||
// 跳过TTL字段(第3个字段)
|
||
data := strings.Join(append(parts[:2], parts[3:]...), " ")
|
||
return recordKey{
|
||
name: rr.Header().Name,
|
||
rtype: rr.Header().Rrtype,
|
||
class: rr.Header().Class,
|
||
data: data,
|
||
}
|
||
}
|
||
|
||
return recordKey{
|
||
name: rr.Header().Name,
|
||
rtype: rr.Header().Rrtype,
|
||
class: rr.Header().Class,
|
||
data: rr.String(),
|
||
}
|
||
}
|
||
|
||
// mergeResponses 合并多个DNS响应
|
||
func mergeResponses(responses []*dns.Msg) *dns.Msg {
|
||
if len(responses) == 0 {
|
||
return nil
|
||
}
|
||
|
||
// 如果只有一个响应,直接返回,避免不必要的合并操作
|
||
if len(responses) == 1 {
|
||
return responses[0].Copy()
|
||
}
|
||
|
||
// 使用第一个响应作为基础
|
||
mergedResponse := responses[0].Copy()
|
||
mergedResponse.Answer = []dns.RR{}
|
||
mergedResponse.Ns = []dns.RR{}
|
||
mergedResponse.Extra = []dns.RR{}
|
||
|
||
// 重置Rcode为成功,除非所有响应都是NXDOMAIN
|
||
mergedResponse.Rcode = dns.RcodeSuccess
|
||
|
||
// 检查是否所有响应都是NXDOMAIN
|
||
allNXDOMAIN := true
|
||
|
||
// 收集所有成功响应的记录
|
||
for _, resp := range responses {
|
||
if resp == nil {
|
||
continue
|
||
}
|
||
|
||
// 如果有任何响应是成功的,就不是allNXDOMAIN
|
||
if resp.Rcode == dns.RcodeSuccess {
|
||
allNXDOMAIN = false
|
||
}
|
||
}
|
||
|
||
// 如果所有响应都是NXDOMAIN,设置合并响应为NXDOMAIN
|
||
if allNXDOMAIN {
|
||
mergedResponse.Rcode = dns.RcodeNameError
|
||
}
|
||
|
||
// 使用map存储唯一记录,选择最长TTL
|
||
// 预分配map容量,减少扩容开销
|
||
answerMap := make(map[recordKey]dns.RR, len(responses[0].Answer)*len(responses))
|
||
nsMap := make(map[recordKey]dns.RR, len(responses[0].Ns)*len(responses))
|
||
extraMap := make(map[recordKey]dns.RR, len(responses[0].Extra)*len(responses))
|
||
|
||
for _, resp := range responses {
|
||
if resp == nil {
|
||
continue
|
||
}
|
||
|
||
// 只合并与最终Rcode匹配的响应记录
|
||
if (mergedResponse.Rcode == dns.RcodeSuccess && resp.Rcode == dns.RcodeSuccess) ||
|
||
(mergedResponse.Rcode == dns.RcodeNameError && resp.Rcode == dns.RcodeNameError) {
|
||
|
||
// 合并Answer部分
|
||
for _, rr := range resp.Answer {
|
||
key := getRecordKey(rr)
|
||
if existing, exists := answerMap[key]; exists {
|
||
// 如果存在相同记录,选择TTL更长的
|
||
if rr.Header().Ttl > existing.Header().Ttl {
|
||
answerMap[key] = rr
|
||
}
|
||
} else {
|
||
answerMap[key] = rr
|
||
}
|
||
}
|
||
|
||
// 合并Ns部分
|
||
for _, rr := range resp.Ns {
|
||
key := getRecordKey(rr)
|
||
if existing, exists := nsMap[key]; exists {
|
||
// 如果存在相同记录,选择TTL更长的
|
||
if rr.Header().Ttl > existing.Header().Ttl {
|
||
nsMap[key] = rr
|
||
}
|
||
} else {
|
||
nsMap[key] = rr
|
||
}
|
||
}
|
||
|
||
// 合并Extra部分
|
||
for _, rr := range resp.Extra {
|
||
// 跳过OPT记录,避免重复
|
||
if rr.Header().Rrtype == dns.TypeOPT {
|
||
continue
|
||
}
|
||
key := getRecordKey(rr)
|
||
if existing, exists := extraMap[key]; exists {
|
||
// 如果存在相同记录,选择TTL更长的
|
||
if rr.Header().Ttl > existing.Header().Ttl {
|
||
extraMap[key] = rr
|
||
}
|
||
} else {
|
||
extraMap[key] = rr
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 预分配切片容量,减少扩容开销
|
||
mergedResponse.Answer = make([]dns.RR, 0, len(answerMap))
|
||
mergedResponse.Ns = make([]dns.RR, 0, len(nsMap))
|
||
mergedResponse.Extra = make([]dns.RR, 0, len(extraMap))
|
||
|
||
// 将map转换回切片
|
||
for _, rr := range answerMap {
|
||
mergedResponse.Answer = append(mergedResponse.Answer, rr)
|
||
}
|
||
|
||
for _, rr := range nsMap {
|
||
mergedResponse.Ns = append(mergedResponse.Ns, rr)
|
||
}
|
||
|
||
for _, rr := range extraMap {
|
||
mergedResponse.Extra = append(mergedResponse.Extra, rr)
|
||
}
|
||
|
||
return mergedResponse
|
||
}
|
||
|
||
// updateDNSSECServerMap 更新DNSSEC专用服务器映射,用于快速查找
|
||
func (s *Server) updateDNSSECServerMap() {
|
||
// 清空现有映射
|
||
for k := range s.dnssecServerMap {
|
||
delete(s.dnssecServerMap, k)
|
||
}
|
||
|
||
// 添加所有DNSSEC专用服务器到映射
|
||
for _, server := range s.config.DNSSECUpstreamDNS {
|
||
s.dnssecServerMap[server] = true
|
||
}
|
||
}
|
||
|
||
// forwardDNSRequestWithCache 转发DNS请求到上游服务器并返回响应
|
||
func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg, time.Duration, string, string) {
|
||
// 始终支持EDNS
|
||
var udpSize uint16 = 4096
|
||
var doFlag bool = s.config.EnableDNSSEC
|
||
|
||
// 检查域名是否匹配不验证DNSSEC的模式
|
||
noDNSSEC := false
|
||
for _, pattern := range s.config.NoDNSSECDomains {
|
||
if strings.Contains(domain, pattern) {
|
||
noDNSSEC = true
|
||
doFlag = false
|
||
logger.Debug("域名匹配到不验证DNSSEC的模式", "domain", domain, "pattern", pattern)
|
||
break
|
||
}
|
||
}
|
||
|
||
// 检查客户端请求是否包含EDNS记录
|
||
if opt := r.IsEdns0(); opt != nil {
|
||
// 保留客户端的UDP缓冲区大小
|
||
udpSize = opt.UDPSize()
|
||
// 移除现有的EDNS记录,以便重新添加
|
||
for i := range r.Extra {
|
||
if r.Extra[i] == opt {
|
||
r.Extra = append(r.Extra[:i], r.Extra[i+1:]...)
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// 添加EDNS记录,设置适当的UDPSize和DO标志
|
||
r.SetEdns0(udpSize, doFlag)
|
||
|
||
// DNSSEC专用服务器列表,从配置中获取
|
||
dnssecServers := s.config.DNSSECUpstreamDNS
|
||
|
||
// 选择合适的上游DNS服务器列表
|
||
// 1. 首先检查是否有域名特定的DNS服务器配置
|
||
var selectedUpstreamDNS []string
|
||
var domainMatched bool
|
||
|
||
for matchStr, dnsServers := range s.config.DomainSpecificDNS {
|
||
if strings.Contains(domain, matchStr) {
|
||
selectedUpstreamDNS = dnsServers
|
||
domainMatched = true
|
||
logger.Debug("域名匹配到特定DNS服务器配置", "domain", domain, "matchStr", matchStr, "dnsServers", dnsServers)
|
||
break
|
||
}
|
||
}
|
||
|
||
// 2. 如果没有匹配的域名特定配置
|
||
if !domainMatched {
|
||
// 创建一个新的切片来存储最终的上游服务器列表
|
||
var finalUpstreamDNS []string
|
||
|
||
// 首先添加用户配置的上游DNS服务器
|
||
finalUpstreamDNS = append(finalUpstreamDNS, s.config.UpstreamDNS...)
|
||
logger.Debug("使用用户配置的上游DNS服务器", "servers", finalUpstreamDNS)
|
||
|
||
// 如果启用了DNSSEC且有配置DNSSEC专用服务器,并且域名不匹配NoDNSSECDomains,则将DNSSEC专用服务器添加到列表中
|
||
if s.config.EnableDNSSEC && len(s.config.DNSSECUpstreamDNS) > 0 && !noDNSSEC {
|
||
// 合并DNSSEC专用服务器到上游服务器列表,避免重复,并确保包含端口号
|
||
for _, dnssecServer := range s.config.DNSSECUpstreamDNS {
|
||
hasDuplicate := false
|
||
// 确保DNSSEC服务器地址包含端口号
|
||
normalizedDnssecServer := normalizeDNSServerAddress(dnssecServer)
|
||
for _, upstream := range finalUpstreamDNS {
|
||
if upstream == normalizedDnssecServer {
|
||
hasDuplicate = true
|
||
break
|
||
}
|
||
}
|
||
if !hasDuplicate {
|
||
finalUpstreamDNS = append(finalUpstreamDNS, normalizedDnssecServer)
|
||
}
|
||
}
|
||
logger.Debug("合并DNSSEC专用服务器到上游服务器列表", "servers", finalUpstreamDNS)
|
||
}
|
||
|
||
// 使用最终合并后的服务器列表
|
||
selectedUpstreamDNS = finalUpstreamDNS
|
||
}
|
||
|
||
// 1. 首先尝试所有配置的上游DNS服务器
|
||
var bestResponse *dns.Msg
|
||
var bestRtt time.Duration
|
||
var hasBestResponse bool
|
||
var hasDNSSECResponse bool
|
||
var backupResponse *dns.Msg
|
||
var backupRtt time.Duration
|
||
var hasBackup bool
|
||
var usedDNSServer string
|
||
var usedDNSSECServer string
|
||
|
||
// 使用配置中的超时时间
|
||
defaultTimeout := time.Duration(s.config.QueryTimeout) * time.Millisecond
|
||
|
||
// 根据查询模式处理请求
|
||
switch s.config.QueryMode {
|
||
case "parallel":
|
||
// 并行请求模式 - 收集所有响应并合并
|
||
responses := make(chan serverResponse, len(selectedUpstreamDNS))
|
||
var wg sync.WaitGroup
|
||
|
||
// 向所有上游服务器并行发送请求,每个请求带有超时
|
||
for _, upstream := range selectedUpstreamDNS {
|
||
wg.Add(1)
|
||
go func(server string) {
|
||
defer wg.Done()
|
||
|
||
// 从池中获取客户端实例
|
||
client := s.clientPool.Get().(*dns.Client)
|
||
// 设置客户端参数
|
||
client.Net = s.resolver.Net
|
||
client.UDPSize = s.resolver.UDPSize
|
||
client.Timeout = defaultTimeout
|
||
|
||
// 发送请求并获取响应,确保服务器地址包含端口号
|
||
response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(server))
|
||
responses <- serverResponse{response, rtt, server, err}
|
||
|
||
// 将客户端实例放回池中
|
||
s.clientPool.Put(client)
|
||
}(upstream)
|
||
}
|
||
|
||
// 等待所有请求完成或超时
|
||
go func() {
|
||
wg.Wait()
|
||
close(responses)
|
||
}()
|
||
|
||
// 收集成功响应和NXDOMAIN响应分开
|
||
var successResponses []*dns.Msg
|
||
var nxdomainResponses []*dns.Msg
|
||
var totalRtt time.Duration
|
||
var responseCount int
|
||
|
||
// 处理所有响应
|
||
for resp := range responses {
|
||
if resp.error == nil && resp.response != nil {
|
||
// 更新服务器统计信息
|
||
s.updateServerStats(resp.server, true, resp.rtt)
|
||
|
||
// 检查是否包含DNSSEC记录
|
||
containsDNSSEC := s.hasDNSSECRecords(resp.response)
|
||
|
||
// 对于不验证DNSSEC的域名,始终设置AD标志为false
|
||
if noDNSSEC {
|
||
resp.response.AuthenticatedData = false
|
||
}
|
||
|
||
// 只对将要返回的响应进行DNSSEC验证,减少开销
|
||
// 这里只设置containsDNSSEC标志,实际验证在确定返回响应后进行
|
||
if containsDNSSEC && s.config.EnableDNSSEC && !noDNSSEC {
|
||
// 暂时不验证,只标记
|
||
}
|
||
|
||
// 检查当前服务器是否是DNSSEC专用服务器(O(1)查找)
|
||
if _, isDNSSECServer := s.dnssecServerMap[resp.server]; isDNSSECServer {
|
||
usedDNSSECServer = resp.server
|
||
}
|
||
|
||
// 收集响应,按Rcode分类
|
||
if resp.response.Rcode == dns.RcodeSuccess {
|
||
successResponses = append(successResponses, resp.response)
|
||
totalRtt += resp.rtt
|
||
responseCount++
|
||
|
||
// 记录使用的服务器
|
||
if usedDNSServer == "" {
|
||
usedDNSServer = resp.server
|
||
}
|
||
} else if resp.response.Rcode == dns.RcodeNameError {
|
||
nxdomainResponses = append(nxdomainResponses, resp.response)
|
||
} else {
|
||
// 更新备选响应,确保总有一个可用的响应
|
||
if resp.response != nil {
|
||
if !hasBackup {
|
||
// 第一次保存备选响应
|
||
backupResponse = resp.response
|
||
backupRtt = resp.rtt
|
||
hasBackup = true
|
||
}
|
||
}
|
||
}
|
||
} else {
|
||
// 更新服务器统计信息(失败)
|
||
s.updateServerStats(resp.server, false, 0)
|
||
}
|
||
}
|
||
|
||
// 合并响应:优先使用成功响应,只有当没有成功响应时才使用NXDOMAIN响应
|
||
var validResponses []*dns.Msg
|
||
if len(successResponses) > 0 {
|
||
validResponses = successResponses
|
||
} else {
|
||
validResponses = nxdomainResponses
|
||
}
|
||
|
||
// 合并所有有效响应
|
||
if len(validResponses) > 0 {
|
||
bestResponse = mergeResponses(validResponses)
|
||
if responseCount > 0 {
|
||
bestRtt = totalRtt / time.Duration(responseCount)
|
||
}
|
||
hasBestResponse = true
|
||
// 设置日志的type字段
|
||
logType := "success"
|
||
if len(successResponses) == 0 {
|
||
logType = "nxdomain"
|
||
}
|
||
logger.Debug("合并所有响应返回", "domain", domain, "responseCount", len(validResponses), "type", logType)
|
||
}
|
||
|
||
case "fastest-ip":
|
||
// 最快的IP地址模式 - 使用TCP连接速度测量选择最快服务器
|
||
// 1. 选择最快的服务器
|
||
fastestServer := s.selectFastestServer(selectedUpstreamDNS)
|
||
if fastestServer != "" {
|
||
// 使用带超时的方式执行Exchange
|
||
resultChan := make(chan struct {
|
||
response *dns.Msg
|
||
rtt time.Duration
|
||
err error
|
||
}, 1)
|
||
|
||
go func() {
|
||
resp, r, e := s.resolver.Exchange(r, normalizeDNSServerAddress(fastestServer))
|
||
resultChan <- struct {
|
||
response *dns.Msg
|
||
rtt time.Duration
|
||
err error
|
||
}{resp, r, e}
|
||
}()
|
||
|
||
var response *dns.Msg
|
||
var rtt time.Duration
|
||
var err error
|
||
|
||
// 直接获取结果,不使用上下文超时
|
||
result := <-resultChan
|
||
response, rtt, err = result.response, result.rtt, result.err
|
||
if err == nil && response != nil {
|
||
// 更新服务器统计信息
|
||
s.updateServerStats(fastestServer, true, rtt)
|
||
|
||
// 检查是否包含DNSSEC记录
|
||
containsDNSSEC := s.hasDNSSECRecords(response)
|
||
|
||
// 如果启用了DNSSEC且响应包含DNSSEC记录,验证DNSSEC签名
|
||
// 但如果域名匹配不验证DNSSEC的模式,则跳过验证
|
||
if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC {
|
||
// 验证DNSSEC记录
|
||
signatureValid := s.verifyDNSSEC(response)
|
||
|
||
// 设置AD标志(Authenticated Data)
|
||
response.AuthenticatedData = signatureValid
|
||
|
||
if signatureValid {
|
||
// 更新DNSSEC验证成功计数
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.DNSSECQueries++
|
||
stats.DNSSECSuccess++
|
||
})
|
||
} else {
|
||
// 更新DNSSEC验证失败计数
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.DNSSECQueries++
|
||
stats.DNSSECFailed++
|
||
})
|
||
}
|
||
} else if noDNSSEC {
|
||
// 对于不验证DNSSEC的域名,始终设置AD标志为false
|
||
response.AuthenticatedData = false
|
||
}
|
||
|
||
// 如果响应成功或为NXDOMAIN,根据DNSSEC状态选择最佳响应
|
||
if response.Rcode == dns.RcodeSuccess || response.Rcode == dns.RcodeNameError {
|
||
if response.Rcode == dns.RcodeSuccess {
|
||
// 优先选择带有DNSSEC记录的响应
|
||
if containsDNSSEC {
|
||
bestResponse = response
|
||
bestRtt = rtt
|
||
hasBestResponse = true
|
||
hasDNSSECResponse = true
|
||
usedDNSServer = fastestServer
|
||
if _, isDNSSECServer := s.dnssecServerMap[normalizeDNSServerAddress(fastestServer)]; isDNSSECServer {
|
||
usedDNSSECServer = fastestServer
|
||
}
|
||
logger.Debug("找到带DNSSEC的最佳响应", "domain", domain, "server", fastestServer, "rtt", rtt)
|
||
} else {
|
||
// 没有带DNSSEC的响应时,保存成功响应
|
||
bestResponse = response
|
||
bestRtt = rtt
|
||
hasBestResponse = true
|
||
usedDNSServer = fastestServer
|
||
if _, isDNSSECServer := s.dnssecServerMap[normalizeDNSServerAddress(fastestServer)]; isDNSSECServer {
|
||
usedDNSSECServer = fastestServer
|
||
}
|
||
logger.Debug("找到最佳响应", "domain", domain, "server", fastestServer, "rtt", rtt)
|
||
}
|
||
} else if response.Rcode == dns.RcodeNameError {
|
||
// 处理NXDOMAIN响应
|
||
bestResponse = response
|
||
bestRtt = rtt
|
||
hasBestResponse = true
|
||
usedDNSServer = fastestServer
|
||
logger.Debug("找到NXDOMAIN响应", "domain", domain, "server", fastestServer, "rtt", rtt)
|
||
}
|
||
// 保存为备选响应
|
||
if !hasBackup {
|
||
backupResponse = response
|
||
backupRtt = rtt
|
||
hasBackup = true
|
||
}
|
||
}
|
||
} else {
|
||
// 更新服务器统计信息(失败)
|
||
s.updateServerStats(fastestServer, false, 0)
|
||
}
|
||
}
|
||
|
||
default:
|
||
// 默认使用并行请求模式 - 实现快速返回和超时机制
|
||
responses := make(chan serverResponse, len(selectedUpstreamDNS))
|
||
resultChan := make(chan struct {
|
||
response *dns.Msg
|
||
rtt time.Duration
|
||
usedServer string
|
||
usedDnssecServer string
|
||
}, 1)
|
||
var wg sync.WaitGroup
|
||
|
||
// 向所有上游服务器并行发送请求
|
||
for _, upstream := range selectedUpstreamDNS {
|
||
wg.Add(1)
|
||
go func(server string) {
|
||
defer wg.Done()
|
||
|
||
// 创建带有超时的resolver
|
||
client := &dns.Client{
|
||
Net: s.resolver.Net,
|
||
UDPSize: s.resolver.UDPSize,
|
||
Timeout: defaultTimeout,
|
||
}
|
||
|
||
// 发送请求并获取响应,确保服务器地址包含端口号
|
||
response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(server))
|
||
responses <- serverResponse{response, rtt, server, err}
|
||
}(upstream)
|
||
}
|
||
|
||
// 处理响应的协程
|
||
go func() {
|
||
var fastestResponse *dns.Msg
|
||
var fastestRtt time.Duration = defaultTimeout
|
||
var fastestServer string
|
||
var fastestDnssecServer string
|
||
var fastestHasDnssec bool
|
||
var successResponses []*dns.Msg
|
||
var nxdomainResponses []*dns.Msg
|
||
|
||
// 等待所有请求完成或超时
|
||
timer := time.NewTimer(defaultTimeout)
|
||
defer timer.Stop()
|
||
|
||
// 处理所有响应
|
||
for {
|
||
select {
|
||
case resp, ok := <-responses:
|
||
if !ok {
|
||
// 所有响应都已处理
|
||
goto doneProcessing
|
||
}
|
||
|
||
if resp.error == nil && resp.response != nil {
|
||
// 更新服务器统计信息
|
||
s.updateServerStats(resp.server, true, resp.rtt)
|
||
|
||
// 检查是否包含DNSSEC记录
|
||
containsDNSSEC := s.hasDNSSECRecords(resp.response)
|
||
|
||
// 对于不验证DNSSEC的域名,始终设置AD标志为false
|
||
if noDNSSEC {
|
||
resp.response.AuthenticatedData = false
|
||
}
|
||
|
||
dnssecServerForResponse := ""
|
||
if _, isDNSSECServer := s.dnssecServerMap[normalizeDNSServerAddress(resp.server)]; isDNSSECServer {
|
||
dnssecServerForResponse = resp.server
|
||
}
|
||
|
||
// 如果响应成功或为NXDOMAIN
|
||
if resp.response.Rcode == dns.RcodeSuccess || resp.response.Rcode == dns.RcodeNameError {
|
||
// 按Rcode分类添加到不同列表
|
||
if resp.response.Rcode == dns.RcodeSuccess {
|
||
successResponses = append(successResponses, resp.response)
|
||
} else {
|
||
nxdomainResponses = append(nxdomainResponses, resp.response)
|
||
}
|
||
|
||
// 快速返回逻辑:找到第一个有效响应或更快的响应
|
||
if resp.response.Rcode == dns.RcodeSuccess {
|
||
// 优先选择带有DNSSEC的响应
|
||
if containsDNSSEC {
|
||
// 如果这是第一个DNSSEC响应,或者比当前最快的DNSSEC响应更快
|
||
if !fastestHasDnssec || resp.rtt < fastestRtt {
|
||
fastestResponse = resp.response
|
||
fastestRtt = resp.rtt
|
||
fastestServer = resp.server
|
||
fastestDnssecServer = dnssecServerForResponse
|
||
fastestHasDnssec = true
|
||
|
||
// 只对将要返回的响应进行DNSSEC验证
|
||
if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC {
|
||
// 验证DNSSEC记录
|
||
signatureValid := s.verifyDNSSEC(fastestResponse)
|
||
|
||
// 设置AD标志(Authenticated Data)
|
||
fastestResponse.AuthenticatedData = signatureValid
|
||
|
||
if signatureValid {
|
||
// 更新DNSSEC验证成功计数
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.DNSSECQueries++
|
||
stats.DNSSECSuccess++
|
||
})
|
||
} else {
|
||
// 更新DNSSEC验证失败计数
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.DNSSECQueries++
|
||
stats.DNSSECFailed++
|
||
})
|
||
}
|
||
}
|
||
|
||
// 发送结果,快速返回
|
||
resultChan <- struct {
|
||
response *dns.Msg
|
||
rtt time.Duration
|
||
usedServer string
|
||
usedDnssecServer string
|
||
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
|
||
}
|
||
} else {
|
||
// 非DNSSEC响应,只有在还没有找到DNSSEC响应且当前响应更快时才更新
|
||
if !fastestHasDnssec && resp.rtt < fastestRtt {
|
||
fastestResponse = resp.response
|
||
fastestRtt = resp.rtt
|
||
fastestServer = resp.server
|
||
fastestDnssecServer = dnssecServerForResponse
|
||
|
||
// 检查是否包含DNSSEC记录
|
||
respContainsDNSSEC := s.hasDNSSECRecords(fastestResponse)
|
||
|
||
// 只对将要返回的响应进行DNSSEC验证
|
||
if s.config.EnableDNSSEC && respContainsDNSSEC && !noDNSSEC {
|
||
// 验证DNSSEC记录
|
||
signatureValid := s.verifyDNSSEC(fastestResponse)
|
||
|
||
// 设置AD标志(Authenticated Data)
|
||
fastestResponse.AuthenticatedData = signatureValid
|
||
|
||
if signatureValid {
|
||
// 更新DNSSEC验证成功计数
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.DNSSECQueries++
|
||
stats.DNSSECSuccess++
|
||
})
|
||
} else {
|
||
// 更新DNSSEC验证失败计数
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.DNSSECQueries++
|
||
stats.DNSSECFailed++
|
||
})
|
||
}
|
||
}
|
||
|
||
// 发送结果,快速返回
|
||
resultChan <- struct {
|
||
response *dns.Msg
|
||
rtt time.Duration
|
||
usedServer string
|
||
usedDnssecServer string
|
||
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
|
||
}
|
||
}
|
||
} else if resp.response.Rcode == dns.RcodeNameError {
|
||
// NXDOMAIN响应,只有在还没有找到响应或当前响应更快时才更新
|
||
if !fastestHasDnssec && resp.rtt < fastestRtt {
|
||
fastestResponse = resp.response
|
||
fastestRtt = resp.rtt
|
||
fastestServer = resp.server
|
||
fastestDnssecServer = dnssecServerForResponse
|
||
|
||
// 检查是否包含DNSSEC记录
|
||
respContainsDNSSEC := s.hasDNSSECRecords(fastestResponse)
|
||
|
||
// 只对将要返回的响应进行DNSSEC验证
|
||
if s.config.EnableDNSSEC && respContainsDNSSEC && !noDNSSEC {
|
||
// 验证DNSSEC记录
|
||
signatureValid := s.verifyDNSSEC(fastestResponse)
|
||
|
||
// 设置AD标志(Authenticated Data)
|
||
fastestResponse.AuthenticatedData = signatureValid
|
||
|
||
if signatureValid {
|
||
// 更新DNSSEC验证成功计数
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.DNSSECQueries++
|
||
stats.DNSSECSuccess++
|
||
})
|
||
} else {
|
||
// 更新DNSSEC验证失败计数
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.DNSSECQueries++
|
||
stats.DNSSECFailed++
|
||
})
|
||
}
|
||
}
|
||
|
||
// 发送结果,快速返回
|
||
resultChan <- struct {
|
||
response *dns.Msg
|
||
rtt time.Duration
|
||
usedServer string
|
||
usedDnssecServer string
|
||
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
|
||
}
|
||
}
|
||
} else {
|
||
// 更新备选响应,确保总有一个可用的响应
|
||
if resp.response != nil {
|
||
if !hasBackup {
|
||
// 第一次保存备选响应
|
||
backupResponse = resp.response
|
||
backupRtt = resp.rtt
|
||
hasBackup = true
|
||
}
|
||
}
|
||
}
|
||
} else {
|
||
// 更新服务器统计信息(失败)
|
||
s.updateServerStats(resp.server, false, 0)
|
||
}
|
||
case <-timer.C:
|
||
// 超时,停止等待更多响应
|
||
goto doneProcessing
|
||
}
|
||
}
|
||
|
||
doneProcessing:
|
||
// 合并响应,优先使用成功响应,只有当没有成功响应时才使用NXDOMAIN响应
|
||
var validResponses []*dns.Msg
|
||
if len(successResponses) > 0 {
|
||
validResponses = successResponses
|
||
} else {
|
||
validResponses = nxdomainResponses
|
||
}
|
||
|
||
// 合并所有有效响应,用于缓存
|
||
if len(validResponses) > 1 {
|
||
mergedResponse := mergeResponses(validResponses)
|
||
if mergedResponse != nil {
|
||
// 只在合并后的响应比最快响应更好时才使用
|
||
mergedHasDnssec := s.hasDNSSECRecords(mergedResponse)
|
||
if mergedHasDnssec && !fastestHasDnssec {
|
||
// 合并后的响应有DNSSEC,而最快响应没有,使用合并后的响应
|
||
fastestResponse = mergedResponse
|
||
// 使用最快的Rtt作为合并响应的Rtt
|
||
fastestHasDnssec = true
|
||
}
|
||
}
|
||
}
|
||
|
||
// 如果还没有发送结果,发送最快的响应
|
||
if fastestResponse != nil {
|
||
resultChan <- struct {
|
||
response *dns.Msg
|
||
rtt time.Duration
|
||
usedServer string
|
||
usedDnssecServer string
|
||
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
|
||
}
|
||
close(resultChan)
|
||
}()
|
||
|
||
// 等待所有请求完成(不阻塞主流程)
|
||
go func() {
|
||
wg.Wait()
|
||
close(responses)
|
||
}()
|
||
|
||
// 等待结果或超时
|
||
select {
|
||
case result := <-resultChan:
|
||
// 快速返回结果
|
||
bestResponse = result.response
|
||
bestRtt = result.rtt
|
||
usedDNSServer = result.usedServer
|
||
usedDNSSECServer = result.usedDnssecServer
|
||
hasBestResponse = true
|
||
hasDNSSECResponse = s.hasDNSSECRecords(result.response)
|
||
logger.Debug("快速返回DNS响应", "domain", domain, "server", result.usedServer, "rtt", result.rtt, "dnssec", hasDNSSECResponse)
|
||
case <-time.After(defaultTimeout):
|
||
// 超时,使用备选响应
|
||
logger.Debug("并行请求超时", "domain", domain, "timeout", defaultTimeout)
|
||
}
|
||
}
|
||
|
||
// 2. 当启用DNSSEC且没有找到带DNSSEC的响应时,向DNSSEC专用服务器发送请求
|
||
// 但如果域名匹配了domainSpecificDNS配置或NoDNSSECDomains,则不使用DNSSEC专用服务器,只使用指定的DNS服务器
|
||
if s.config.EnableDNSSEC && !hasDNSSECResponse && !domainMatched && !noDNSSEC {
|
||
logger.Debug("向DNSSEC专用服务器发送请求", "domain", domain)
|
||
|
||
// 增加DNSSEC查询计数
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.DNSSECQueries++
|
||
})
|
||
|
||
// 无论查询模式是什么,DNSSEC验证都只使用加权随机选择一个服务器
|
||
selectedDnssecServer := s.selectWeightedRandomServer(dnssecServers)
|
||
if selectedDnssecServer != "" {
|
||
// 使用带超时的方式执行Exchange
|
||
resultChan := make(chan struct {
|
||
response *dns.Msg
|
||
rtt time.Duration
|
||
err error
|
||
}, 1)
|
||
|
||
go func() {
|
||
// 创建带有超时的resolver
|
||
client := &dns.Client{
|
||
Net: s.resolver.Net,
|
||
UDPSize: s.resolver.UDPSize,
|
||
Timeout: defaultTimeout,
|
||
}
|
||
response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(selectedDnssecServer))
|
||
resultChan <- struct {
|
||
response *dns.Msg
|
||
rtt time.Duration
|
||
err error
|
||
}{response, rtt, err}
|
||
}()
|
||
|
||
var response *dns.Msg
|
||
var rtt time.Duration
|
||
var err error
|
||
|
||
// 使用超时获取结果
|
||
select {
|
||
case result := <-resultChan:
|
||
response, rtt, err = result.response, result.rtt, result.err
|
||
case <-time.After(defaultTimeout):
|
||
// 超时,不再等待
|
||
logger.Debug("DNSSEC专用服务器请求超时", "domain", domain, "server", selectedDnssecServer, "timeout", defaultTimeout)
|
||
return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer
|
||
}
|
||
|
||
if err == nil && response != nil {
|
||
// 更新服务器统计信息
|
||
s.updateServerStats(selectedDnssecServer, true, rtt)
|
||
|
||
// 检查是否包含DNSSEC记录
|
||
containsDNSSEC := s.hasDNSSECRecords(response)
|
||
|
||
if response.Rcode == dns.RcodeSuccess {
|
||
// 无论响应是否包含DNSSEC记录,只要使用了DNSSEC专用服务器,就设置usedDNSSECServer
|
||
usedDNSSECServer = selectedDnssecServer
|
||
|
||
// 验证DNSSEC记录
|
||
signatureValid := s.verifyDNSSEC(response)
|
||
|
||
// 设置AD标志(Authenticated Data)
|
||
response.AuthenticatedData = signatureValid
|
||
|
||
if signatureValid {
|
||
// 更新DNSSEC验证成功计数
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.DNSSECSuccess++
|
||
})
|
||
} else {
|
||
// 更新DNSSEC验证失败计数
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.DNSSECFailed++
|
||
})
|
||
}
|
||
|
||
// 优先使用DNSSEC专用服务器的响应,尤其是带有DNSSEC记录的
|
||
if containsDNSSEC {
|
||
bestResponse = response
|
||
bestRtt = rtt
|
||
hasBestResponse = true
|
||
hasDNSSECResponse = true
|
||
logger.Debug("DNSSEC专用服务器返回带DNSSEC的响应,优先使用", "domain", domain, "server", selectedDnssecServer, "rtt", rtt)
|
||
}
|
||
// 注意:如果DNSSEC专用服务器返回的响应不包含DNSSEC记录,
|
||
// 我们不会覆盖之前从upstreamDNS获取的响应,
|
||
// 这符合"本地解析指的是直接使用上游服务器upstreamDNS进行解析, 而不是dnssecUpstreamDNS"的要求
|
||
|
||
// 更新备选响应
|
||
if !hasBackup {
|
||
backupResponse = response
|
||
backupRtt = rtt
|
||
hasBackup = true
|
||
}
|
||
}
|
||
} else {
|
||
// 更新服务器统计信息(失败)
|
||
s.updateServerStats(selectedDnssecServer, false, 0)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 3. 返回最佳响应
|
||
if hasBestResponse {
|
||
// 检查最佳响应是否包含DNSSEC记录
|
||
bestHasDNSSEC := s.hasDNSSECRecords(bestResponse)
|
||
|
||
// 如果启用了DNSSEC且最佳响应不包含DNSSEC记录,尝试使用本地解析(使用upstreamDNS服务器)
|
||
// 但如果域名匹配了domainSpecificDNS配置,则不执行此逻辑,只使用指定的DNS服务器
|
||
if s.config.EnableDNSSEC && !bestHasDNSSEC && !domainMatched {
|
||
logger.Debug("最佳响应不包含DNSSEC记录,尝试使用本地解析(upstreamDNS)", "domain", domain)
|
||
// 选择一个upstreamDNS服务器进行解析(使用加权随机算法)
|
||
localServer := s.selectWeightedRandomServer(s.config.UpstreamDNS)
|
||
if localServer != "" {
|
||
// 使用带超时的方式执行Exchange
|
||
resultChan := make(chan struct {
|
||
response *dns.Msg
|
||
rtt time.Duration
|
||
err error
|
||
}, 1)
|
||
|
||
go func() {
|
||
resp, r, e := s.resolver.Exchange(r, normalizeDNSServerAddress(localServer))
|
||
resultChan <- struct {
|
||
response *dns.Msg
|
||
rtt time.Duration
|
||
err error
|
||
}{resp, r, e}
|
||
}()
|
||
|
||
var localResponse *dns.Msg
|
||
var rtt time.Duration
|
||
var err error
|
||
|
||
// 直接获取结果,不使用上下文超时
|
||
result := <-resultChan
|
||
localResponse, rtt, err = result.response, result.rtt, result.err
|
||
|
||
if err == nil && localResponse != nil {
|
||
// 更新服务器统计信息
|
||
s.updateServerStats(localServer, true, rtt)
|
||
|
||
// 检查是否包含DNSSEC记录
|
||
localHasDNSSEC := s.hasDNSSECRecords(localResponse)
|
||
|
||
// 验证DNSSEC记录(如果存在),但不影响最终响应
|
||
if localHasDNSSEC {
|
||
signatureValid := s.verifyDNSSEC(localResponse)
|
||
localResponse.AuthenticatedData = signatureValid
|
||
|
||
if signatureValid {
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.DNSSECQueries++
|
||
stats.DNSSECSuccess++
|
||
})
|
||
} else {
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.DNSSECQueries++
|
||
stats.DNSSECFailed++
|
||
})
|
||
}
|
||
}
|
||
|
||
// 记录解析域名统计
|
||
s.updateResolvedDomainStats(domain)
|
||
|
||
// 更新域名的DNSSEC状态
|
||
s.updateDomainDNSSECStatus(domain, localHasDNSSEC)
|
||
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.Allowed++
|
||
})
|
||
|
||
logger.Debug("使用本地解析结果(upstreamDNS)", "domain", domain, "server", localServer, "rtt", rtt)
|
||
return localResponse, rtt, localServer, ""
|
||
} else {
|
||
// 更新服务器统计信息(失败)
|
||
s.updateServerStats(localServer, false, 0)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 记录解析域名统计
|
||
s.updateResolvedDomainStats(domain)
|
||
|
||
// 更新域名的DNSSEC状态
|
||
if bestHasDNSSEC {
|
||
s.updateDomainDNSSECStatus(domain, true)
|
||
} else {
|
||
s.updateDomainDNSSECStatus(domain, false)
|
||
}
|
||
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.Allowed++
|
||
})
|
||
return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer
|
||
}
|
||
|
||
// 如果有备选响应,返回该响应
|
||
if hasBackup {
|
||
logger.Debug("使用备选响应,没有找到更好的结果", "domain", domain)
|
||
// 记录解析域名统计
|
||
s.updateResolvedDomainStats(domain)
|
||
// 更新统计信息
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.Allowed++
|
||
})
|
||
return backupResponse, backupRtt, "", ""
|
||
}
|
||
|
||
// 所有上游服务器都失败,返回服务器失败错误
|
||
response := new(dns.Msg)
|
||
response.SetReply(r)
|
||
|
||
response.SetRcode(r, dns.RcodeServerFailure)
|
||
|
||
logger.Error("DNS查询失败", "domain", domain)
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.Errors++
|
||
})
|
||
return response, 0, "", ""
|
||
}
|
||
|
||
// forwardDNSRequest 转发DNS请求到上游服务器
|
||
func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) {
|
||
response, _, _, _ := s.forwardDNSRequestWithCache(r, domain)
|
||
w.WriteMsg(response)
|
||
}
|
||
|
||
// updateBlockedDomainStats 更新被屏蔽域名统计
|
||
func (s *Server) updateBlockedDomainStats(domain string) {
|
||
// 先尝试读锁,检查条目是否存在
|
||
s.blockedDomainsMutex.RLock()
|
||
entry, exists := s.blockedDomains[domain]
|
||
s.blockedDomainsMutex.RUnlock()
|
||
|
||
if exists {
|
||
// 使用原子操作更新计数和时间戳
|
||
atomic.AddInt64(&entry.Count, 1)
|
||
atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano())
|
||
} else {
|
||
// 获取写锁,创建新条目
|
||
s.blockedDomainsMutex.Lock()
|
||
// 再次检查,避免竞态条件
|
||
if entry, exists := s.blockedDomains[domain]; exists {
|
||
atomic.AddInt64(&entry.Count, 1)
|
||
atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano())
|
||
} else {
|
||
s.blockedDomains[domain] = &BlockedDomain{
|
||
Domain: domain,
|
||
Count: 1,
|
||
LastSeen: time.Now().UnixNano(),
|
||
DNSSEC: false,
|
||
}
|
||
}
|
||
s.blockedDomainsMutex.Unlock()
|
||
}
|
||
|
||
// 更新统计数据
|
||
now := time.Now()
|
||
|
||
// 更新小时统计
|
||
hourKey := now.Format("2006-01-02-15")
|
||
s.hourlyStatsMutex.Lock()
|
||
s.hourlyStats[hourKey]++
|
||
s.hourlyStatsMutex.Unlock()
|
||
|
||
// 更新每日统计
|
||
dayKey := now.Format("2006-01-02")
|
||
s.dailyStatsMutex.Lock()
|
||
s.dailyStats[dayKey]++
|
||
s.dailyStatsMutex.Unlock()
|
||
|
||
// 更新每月统计
|
||
monthKey := now.Format("2006-01")
|
||
s.monthlyStatsMutex.Lock()
|
||
s.monthlyStats[monthKey]++
|
||
s.monthlyStatsMutex.Unlock()
|
||
}
|
||
|
||
// updateClientStats 更新客户端统计
|
||
func (s *Server) updateClientStats(ip string) {
|
||
// 先尝试读锁,检查条目是否存在
|
||
s.clientStatsMutex.RLock()
|
||
entry, exists := s.clientStats[ip]
|
||
s.clientStatsMutex.RUnlock()
|
||
|
||
if exists {
|
||
// 使用原子操作更新计数和时间戳
|
||
atomic.AddInt64(&entry.Count, 1)
|
||
atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano())
|
||
} else {
|
||
// 获取写锁,创建新条目
|
||
s.clientStatsMutex.Lock()
|
||
// 再次检查,避免竞态条件
|
||
if entry, exists := s.clientStats[ip]; exists {
|
||
atomic.AddInt64(&entry.Count, 1)
|
||
atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano())
|
||
} else {
|
||
s.clientStats[ip] = &ClientStats{
|
||
IP: ip,
|
||
Count: 1,
|
||
LastSeen: time.Now().UnixNano(),
|
||
}
|
||
}
|
||
s.clientStatsMutex.Unlock()
|
||
}
|
||
}
|
||
|
||
// hasDNSSECRecords 检查响应是否包含DNSSEC记录
|
||
func (s *Server) hasDNSSECRecords(response *dns.Msg) bool {
|
||
// 直接调用包内的hasDNSSECRecords函数,避免重复代码
|
||
return hasDNSSECRecords(response)
|
||
}
|
||
|
||
// verifyDNSSEC 验证DNSSEC签名
|
||
func (s *Server) verifyDNSSEC(response *dns.Msg) bool {
|
||
// 提取DNSKEY和RRSIG记录,并按类型和名称组织记录
|
||
dnskeys := make(map[uint16]*dns.DNSKEY) // KeyTag -> DNSKEY
|
||
rrsigs := make([]*dns.RRSIG, 0)
|
||
// 按 (名称, 类型) 组织记录集,用于快速查找
|
||
rrSets := make(map[string]map[uint16][]dns.RR) // name -> type -> records
|
||
|
||
// 定义处理单个记录的函数
|
||
processRecord := func(rr dns.RR) {
|
||
num := rr.Header().Rrtype
|
||
name := rr.Header().Name
|
||
|
||
// 组织记录集
|
||
if _, exists := rrSets[name]; !exists {
|
||
rrSets[name] = make(map[uint16][]dns.RR)
|
||
}
|
||
if _, exists := rrSets[name][num]; !exists {
|
||
rrSets[name][num] = make([]dns.RR, 0)
|
||
}
|
||
rrSets[name][num] = append(rrSets[name][num], rr)
|
||
|
||
// 特别处理DNSKEY和RRSIG
|
||
if dnskey, ok := rr.(*dns.DNSKEY); ok {
|
||
tag := dnskey.KeyTag()
|
||
dnskeys[tag] = dnskey
|
||
} else if rrsig, ok := rr.(*dns.RRSIG); ok {
|
||
rrsigs = append(rrsigs, rrsig)
|
||
}
|
||
}
|
||
|
||
// 一次遍历所有响应部分,同时完成记录收集和组织
|
||
for _, rr := range response.Answer {
|
||
processRecord(rr)
|
||
}
|
||
for _, rr := range response.Ns {
|
||
processRecord(rr)
|
||
}
|
||
for _, rr := range response.Extra {
|
||
processRecord(rr)
|
||
}
|
||
|
||
// 如果没有RRSIG记录,验证失败
|
||
if len(rrsigs) == 0 {
|
||
return false
|
||
}
|
||
|
||
// 验证所有RRSIG记录
|
||
signatureValid := true
|
||
// 用于记录已经警告过的DNSKEY tag,避免重复警告
|
||
warnedKeyTags := make(map[uint16]bool)
|
||
for _, rrsig := range rrsigs {
|
||
// 查找对应的DNSKEY
|
||
dnskey, exists := dnskeys[rrsig.KeyTag]
|
||
if !exists {
|
||
// 仅当该key_tag尚未警告过时才记录警告
|
||
if !warnedKeyTags[rrsig.KeyTag] {
|
||
logger.Warn("DNSSEC验证失败:找不到对应的DNSKEY", "key_tag", rrsig.KeyTag)
|
||
warnedKeyTags[rrsig.KeyTag] = true
|
||
}
|
||
signatureValid = false
|
||
continue
|
||
}
|
||
|
||
// 快速查找需要验证的记录集
|
||
name := rrsig.Header().Name
|
||
typeCovered := rrsig.TypeCovered
|
||
rrset := rrSets[name][typeCovered]
|
||
|
||
// 验证签名
|
||
if len(rrset) > 0 {
|
||
err := rrsig.Verify(dnskey, rrset)
|
||
if err != nil {
|
||
logger.Warn("DNSSEC签名验证失败", "error", err, "key_tag", rrsig.KeyTag)
|
||
signatureValid = false
|
||
} else {
|
||
logger.Debug("DNSSEC签名验证成功", "key_tag", rrsig.KeyTag)
|
||
}
|
||
}
|
||
}
|
||
|
||
return signatureValid
|
||
}
|
||
|
||
// updateDomainDNSSECStatus 更新域名的DNSSEC状态
|
||
func (s *Server) updateDomainDNSSECStatus(domain string, dnssec bool) {
|
||
// 确保域名是小写
|
||
domain = strings.ToLower(domain)
|
||
|
||
// 更新域名的DNSSEC状态
|
||
s.resolvedDomainsMutex.Lock()
|
||
defer s.resolvedDomainsMutex.Unlock()
|
||
|
||
// 更新resolvedDomains中的DNSSEC状态
|
||
if entry, exists := s.resolvedDomains[domain]; exists {
|
||
entry.DNSSEC = dnssec
|
||
} else {
|
||
s.resolvedDomains[domain] = &BlockedDomain{
|
||
Domain: domain,
|
||
Count: 1,
|
||
LastSeen: time.Now().UnixNano(),
|
||
DNSSEC: dnssec,
|
||
}
|
||
}
|
||
|
||
// 更新domainDNSSECStatus映射(使用单独的锁)
|
||
s.domainDNSSECStatusMutex.Lock()
|
||
s.domainDNSSECStatus[domain] = dnssec
|
||
s.domainDNSSECStatusMutex.Unlock()
|
||
}
|
||
|
||
// updateResolvedDomainStats 更新解析域名统计
|
||
func (s *Server) updateResolvedDomainStats(domain string) {
|
||
// 先尝试读锁,检查条目是否存在
|
||
s.resolvedDomainsMutex.RLock()
|
||
entry, exists := s.resolvedDomains[domain]
|
||
s.resolvedDomainsMutex.RUnlock()
|
||
|
||
if exists {
|
||
// 使用原子操作更新计数和时间戳
|
||
atomic.AddInt64(&entry.Count, 1)
|
||
atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano())
|
||
} else {
|
||
// 获取写锁,创建新条目
|
||
s.resolvedDomainsMutex.Lock()
|
||
// 再次检查,避免竞态条件
|
||
if entry, exists := s.resolvedDomains[domain]; exists {
|
||
atomic.AddInt64(&entry.Count, 1)
|
||
atomic.StoreInt64(&entry.LastSeen, time.Now().UnixNano())
|
||
} else {
|
||
s.resolvedDomains[domain] = &BlockedDomain{
|
||
Domain: domain,
|
||
Count: 1,
|
||
LastSeen: time.Now().UnixNano(),
|
||
DNSSEC: false,
|
||
}
|
||
}
|
||
s.resolvedDomainsMutex.Unlock()
|
||
}
|
||
}
|
||
|
||
// getServerStats 获取服务器统计信息,如果不存在则创建
|
||
func (s *Server) getServerStats(server string) *ServerStats {
|
||
s.serverStatsMutex.RLock()
|
||
stats, exists := s.serverStats[server]
|
||
s.serverStatsMutex.RUnlock()
|
||
|
||
if exists {
|
||
return stats
|
||
}
|
||
|
||
s.serverStatsMutex.Lock()
|
||
defer s.serverStatsMutex.Unlock()
|
||
|
||
if stats, exists := s.serverStats[server]; exists {
|
||
return stats
|
||
}
|
||
|
||
stats = &ServerStats{
|
||
SuccessCount: 0,
|
||
FailureCount: 0,
|
||
LastResponse: time.Now(),
|
||
ResponseTime: 0,
|
||
ConnectionSpeed: 0,
|
||
}
|
||
|
||
s.serverStats[server] = stats
|
||
return stats
|
||
}
|
||
|
||
// updateServerStats 更新服务器统计信息
|
||
func (s *Server) updateServerStats(server string, success bool, rtt time.Duration) {
|
||
stats := s.getServerStats(server)
|
||
|
||
// 使用原子操作更新成功和失败计数
|
||
if success {
|
||
successCount := atomic.AddInt64(&stats.SuccessCount, 1)
|
||
|
||
// 只在需要更新平均响应时间时获取锁
|
||
s.serverStatsMutex.Lock()
|
||
stats.LastResponse = time.Now()
|
||
|
||
// 更新平均响应时间(简单移动平均)
|
||
if successCount == 1 {
|
||
// 第一次成功,直接使用当前响应时间
|
||
stats.ResponseTime = rtt
|
||
} else {
|
||
// 使用纳秒进行计算以避免类型不匹配
|
||
prevTotal := stats.ResponseTime.Nanoseconds() * (successCount - 1)
|
||
newTotal := prevTotal + rtt.Nanoseconds()
|
||
stats.ResponseTime = time.Duration(newTotal / successCount)
|
||
}
|
||
s.serverStatsMutex.Unlock()
|
||
} else {
|
||
atomic.AddInt64(&stats.FailureCount, 1)
|
||
|
||
// 只更新LastResponse时获取锁
|
||
s.serverStatsMutex.Lock()
|
||
stats.LastResponse = time.Now()
|
||
s.serverStatsMutex.Unlock()
|
||
}
|
||
}
|
||
|
||
// selectWeightedRandomServer 加权随机选择服务器
|
||
func (s *Server) selectWeightedRandomServer(servers []string) string {
|
||
if len(servers) == 0 {
|
||
return ""
|
||
}
|
||
|
||
if len(servers) == 1 {
|
||
return servers[0]
|
||
}
|
||
|
||
type serverWeight struct {
|
||
server string
|
||
weight int64
|
||
responseTime time.Duration
|
||
successCount int64
|
||
failureCount int64
|
||
}
|
||
|
||
var totalWeight int64
|
||
var totalResponseTime time.Duration
|
||
var validServers int
|
||
var currentWeight int64
|
||
|
||
serversInfo := make([]serverWeight, len(servers))
|
||
|
||
for i, server := range servers {
|
||
stats := s.getServerStats(server)
|
||
|
||
serversInfo[i] = serverWeight{
|
||
server: server,
|
||
responseTime: stats.ResponseTime,
|
||
successCount: atomic.LoadInt64(&stats.SuccessCount),
|
||
failureCount: atomic.LoadInt64(&stats.FailureCount),
|
||
}
|
||
|
||
if stats.ResponseTime > 0 {
|
||
totalResponseTime += stats.ResponseTime
|
||
validServers++
|
||
}
|
||
}
|
||
|
||
var avgResponseTime time.Duration
|
||
if validServers > 0 {
|
||
avgResponseTime = totalResponseTime / time.Duration(validServers)
|
||
} else {
|
||
avgResponseTime = 1 * time.Second
|
||
}
|
||
|
||
var randomGen = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||
|
||
for i := range serversInfo {
|
||
baseWeight := serversInfo[i].successCount - serversInfo[i].failureCount*2
|
||
if baseWeight < 1 {
|
||
baseWeight = 1
|
||
}
|
||
|
||
var responseFactor int64 = 100
|
||
if serversInfo[i].responseTime > 0 {
|
||
if serversInfo[i].responseTime < avgResponseTime {
|
||
factor := (avgResponseTime.Nanoseconds() * 200) / serversInfo[i].responseTime.Nanoseconds()
|
||
if factor > 200 {
|
||
factor = 200
|
||
}
|
||
responseFactor = factor
|
||
} else {
|
||
factor := (avgResponseTime.Nanoseconds() * 200) / serversInfo[i].responseTime.Nanoseconds()
|
||
if factor < 50 {
|
||
factor = 50
|
||
}
|
||
responseFactor = factor
|
||
}
|
||
}
|
||
|
||
finalWeight := (baseWeight * responseFactor) / 100
|
||
if finalWeight < 1 {
|
||
finalWeight = 1
|
||
}
|
||
|
||
serversInfo[i].weight = finalWeight
|
||
totalWeight += finalWeight
|
||
}
|
||
|
||
random := randomGen.Int63n(totalWeight)
|
||
|
||
for _, sw := range serversInfo {
|
||
currentWeight += sw.weight
|
||
if random < currentWeight {
|
||
return sw.server
|
||
}
|
||
}
|
||
|
||
// 兜底返回第一个服务器
|
||
return servers[0]
|
||
}
|
||
|
||
// measureServerSpeed 测量服务器TCP连接速度
|
||
func (s *Server) measureServerSpeed(server string) time.Duration {
|
||
addr := server
|
||
if !strings.Contains(server, ":") {
|
||
addr = server + ":53"
|
||
}
|
||
|
||
startTime := time.Now()
|
||
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
|
||
if err != nil {
|
||
return 2 * time.Second
|
||
}
|
||
defer conn.Close()
|
||
|
||
connTime := time.Since(startTime)
|
||
|
||
stats := s.getServerStats(server)
|
||
s.serverStatsMutex.Lock()
|
||
stats.ConnectionSpeed = (stats.ConnectionSpeed*3 + connTime) / 4
|
||
s.serverStatsMutex.Unlock()
|
||
|
||
return connTime
|
||
}
|
||
|
||
// selectFastestServer 选择连接速度最快的服务器
|
||
func (s *Server) selectFastestServer(servers []string) string {
|
||
if len(servers) == 0 {
|
||
return ""
|
||
}
|
||
|
||
if len(servers) == 1 {
|
||
return servers[0]
|
||
}
|
||
|
||
// 并行测量所有服务器的速度
|
||
type speedResult struct {
|
||
server string
|
||
speed time.Duration
|
||
}
|
||
|
||
results := make(chan speedResult, len(servers))
|
||
var wg sync.WaitGroup
|
||
|
||
for _, server := range servers {
|
||
wg.Add(1)
|
||
go func(srv string) {
|
||
defer wg.Done()
|
||
speed := s.measureServerSpeed(srv)
|
||
results <- speedResult{srv, speed}
|
||
}(server)
|
||
}
|
||
|
||
// 等待所有测量完成
|
||
go func() {
|
||
wg.Wait()
|
||
close(results)
|
||
}()
|
||
|
||
// 找出最快的服务器
|
||
var fastestServer string
|
||
var fastestSpeed time.Duration = 2 * time.Second
|
||
|
||
for result := range results {
|
||
if result.speed < fastestSpeed {
|
||
fastestSpeed = result.speed
|
||
fastestServer = result.server
|
||
}
|
||
}
|
||
|
||
// 如果没有找到最快服务器(理论上不会发生),返回第一个服务器
|
||
if fastestServer == "" {
|
||
fastestServer = servers[0]
|
||
}
|
||
|
||
return fastestServer
|
||
}
|
||
|
||
// updateStats 更新统计信息
|
||
func (s *Server) updateStats(update func(*Stats)) {
|
||
s.statsMutex.Lock()
|
||
defer s.statsMutex.Unlock()
|
||
update(s.stats)
|
||
}
|
||
|
||
// addQueryLog 添加查询日志
|
||
func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string, fromCache, dnssec, edns bool, dnsServer, dnssecServer string, answers []DNSAnswer, responseCode int) {
|
||
// 创建日志记录
|
||
log := QueryLog{
|
||
Timestamp: time.Now(),
|
||
ClientIP: clientIP,
|
||
Domain: domain,
|
||
QueryType: queryType,
|
||
ResponseTime: responseTime,
|
||
Result: result,
|
||
BlockRule: blockRule,
|
||
BlockType: blockType,
|
||
FromCache: fromCache,
|
||
DNSSEC: dnssec,
|
||
EDNS: edns,
|
||
DNSServer: dnsServer,
|
||
DNSSECServer: dnssecServer,
|
||
Answers: answers,
|
||
ResponseCode: responseCode,
|
||
}
|
||
|
||
// 发送到日志处理通道(非阻塞)
|
||
select {
|
||
case s.logChannel <- log:
|
||
// 日志发送成功
|
||
default:
|
||
// 通道已满,丢弃日志以避免阻塞请求处理
|
||
logger.Warn("日志通道已满,丢弃一条日志记录")
|
||
}
|
||
}
|
||
|
||
// GetStartTime 获取服务器启动时间
|
||
func (s *Server) GetStartTime() time.Time {
|
||
return s.startTime
|
||
}
|
||
|
||
// GetStats 获取DNS服务器统计信息
|
||
func (s *Server) GetStats() *Stats {
|
||
s.statsMutex.Lock()
|
||
defer s.statsMutex.Unlock()
|
||
|
||
// 复制查询类型统计
|
||
queryTypesCopy := make(map[string]int64)
|
||
for k, v := range s.stats.QueryTypes {
|
||
queryTypesCopy[k] = v
|
||
}
|
||
|
||
// 复制来源IP统计
|
||
sourceIPsCopy := make(map[string]bool)
|
||
for ip := range s.stats.SourceIPs {
|
||
sourceIPsCopy[ip] = true
|
||
}
|
||
|
||
// 返回统计信息的副本
|
||
return &Stats{
|
||
Queries: s.stats.Queries,
|
||
Blocked: s.stats.Blocked,
|
||
Allowed: s.stats.Allowed,
|
||
Errors: s.stats.Errors,
|
||
LastQuery: s.stats.LastQuery,
|
||
AvgResponseTime: s.stats.AvgResponseTime,
|
||
TotalResponseTime: s.stats.TotalResponseTime,
|
||
QueryTypes: queryTypesCopy,
|
||
SourceIPs: sourceIPsCopy,
|
||
CpuUsage: s.stats.CpuUsage,
|
||
DNSSECQueries: s.stats.DNSSECQueries,
|
||
DNSSECSuccess: s.stats.DNSSECSuccess,
|
||
DNSSECFailed: s.stats.DNSSECFailed,
|
||
DNSSECEnabled: s.stats.DNSSECEnabled,
|
||
}
|
||
}
|
||
|
||
// GetQueryLogs 获取查询日志
|
||
func (s *Server) GetQueryLogs(limit, offset int, sortField, sortDirection, resultFilter, searchTerm string) []QueryLog {
|
||
s.queryLogsMutex.RLock()
|
||
defer s.queryLogsMutex.RUnlock()
|
||
|
||
// 确保偏移量和限制值合理
|
||
if offset < 0 {
|
||
offset = 0
|
||
}
|
||
if limit <= 0 {
|
||
limit = 100 // 默认返回100条日志
|
||
}
|
||
|
||
// 创建日志副本用于过滤和排序
|
||
var logsCopy []QueryLog
|
||
|
||
// 先过滤日志
|
||
for _, log := range s.queryLogs {
|
||
// 应用结果过滤
|
||
if resultFilter != "" && log.Result != resultFilter {
|
||
continue
|
||
}
|
||
|
||
// 应用搜索过滤
|
||
if searchTerm != "" {
|
||
// 搜索域名或客户端IP
|
||
if !strings.Contains(log.Domain, searchTerm) && !strings.Contains(log.ClientIP, searchTerm) {
|
||
continue
|
||
}
|
||
}
|
||
|
||
logsCopy = append(logsCopy, log)
|
||
}
|
||
|
||
// 排序日志
|
||
if sortField != "" {
|
||
sort.Slice(logsCopy, func(i, j int) bool {
|
||
var a, b interface{}
|
||
switch sortField {
|
||
case "time":
|
||
a = logsCopy[i].Timestamp
|
||
b = logsCopy[j].Timestamp
|
||
case "clientIp":
|
||
a = logsCopy[i].ClientIP
|
||
b = logsCopy[j].ClientIP
|
||
case "domain":
|
||
a = logsCopy[i].Domain
|
||
b = logsCopy[j].Domain
|
||
case "responseTime":
|
||
a = logsCopy[i].ResponseTime
|
||
b = logsCopy[j].ResponseTime
|
||
case "blockRule":
|
||
a = logsCopy[i].BlockRule
|
||
b = logsCopy[j].BlockRule
|
||
default:
|
||
// 默认按时间排序
|
||
a = logsCopy[i].Timestamp
|
||
b = logsCopy[j].Timestamp
|
||
}
|
||
|
||
// 根据排序方向比较
|
||
if sortDirection == "asc" {
|
||
return compareValues(a, b) < 0
|
||
}
|
||
return compareValues(a, b) > 0
|
||
})
|
||
}
|
||
|
||
// 计算返回范围
|
||
start := offset
|
||
end := offset + limit
|
||
if end > len(logsCopy) {
|
||
end = len(logsCopy)
|
||
}
|
||
if start >= len(logsCopy) {
|
||
return []QueryLog{} // 没有数据,返回空切片
|
||
}
|
||
|
||
return logsCopy[start:end]
|
||
}
|
||
|
||
// compareValues 比较两个值
|
||
func compareValues(a, b interface{}) int {
|
||
switch v1 := a.(type) {
|
||
case time.Time:
|
||
v2 := b.(time.Time)
|
||
if v1.Before(v2) {
|
||
return -1
|
||
}
|
||
if v1.After(v2) {
|
||
return 1
|
||
}
|
||
return 0
|
||
case string:
|
||
v2 := b.(string)
|
||
if v1 < v2 {
|
||
return -1
|
||
}
|
||
if v1 > v2 {
|
||
return 1
|
||
}
|
||
return 0
|
||
case int64:
|
||
v2 := b.(int64)
|
||
if v1 < v2 {
|
||
return -1
|
||
}
|
||
if v1 > v2 {
|
||
return 1
|
||
}
|
||
return 0
|
||
default:
|
||
return 0
|
||
}
|
||
}
|
||
|
||
// GetQueryLogsCount 获取查询日志总数
|
||
func (s *Server) GetQueryLogsCount() int {
|
||
s.queryLogsMutex.RLock()
|
||
defer s.queryLogsMutex.RUnlock()
|
||
return len(s.queryLogs)
|
||
}
|
||
|
||
// GetQueryStats 获取查询统计信息
|
||
func (s *Server) GetQueryStats() map[string]interface{} {
|
||
s.statsMutex.Lock()
|
||
defer s.statsMutex.Unlock()
|
||
|
||
// 计算统计数据
|
||
return map[string]interface{}{
|
||
"totalQueries": s.stats.Queries,
|
||
"blockedQueries": s.stats.Blocked,
|
||
"allowedQueries": s.stats.Allowed,
|
||
"errorQueries": s.stats.Errors,
|
||
"avgResponseTime": s.stats.AvgResponseTime,
|
||
"activeIPs": len(s.stats.SourceIPs),
|
||
}
|
||
}
|
||
|
||
// GetTopBlockedDomains 获取TOP屏蔽域名列表
|
||
func (s *Server) GetTopBlockedDomains(limit int) []BlockedDomain {
|
||
s.blockedDomainsMutex.RLock()
|
||
defer s.blockedDomainsMutex.RUnlock()
|
||
|
||
// 转换为切片
|
||
domains := make([]BlockedDomain, 0, len(s.blockedDomains))
|
||
for _, entry := range s.blockedDomains {
|
||
domains = append(domains, *entry)
|
||
}
|
||
|
||
// 按计数排序
|
||
sort.Slice(domains, func(i, j int) bool {
|
||
return domains[i].Count > domains[j].Count
|
||
})
|
||
|
||
// 返回限制数量
|
||
if len(domains) > limit {
|
||
return domains[:limit]
|
||
}
|
||
return domains
|
||
}
|
||
|
||
// GetTopResolvedDomains 获取TOP解析域名
|
||
func (s *Server) GetTopResolvedDomains(limit int) []BlockedDomain {
|
||
s.resolvedDomainsMutex.RLock()
|
||
defer s.resolvedDomainsMutex.RUnlock()
|
||
|
||
// 转换为切片
|
||
domains := make([]BlockedDomain, 0, len(s.resolvedDomains))
|
||
for _, entry := range s.resolvedDomains {
|
||
domains = append(domains, *entry)
|
||
}
|
||
|
||
// 按数量排序
|
||
sort.Slice(domains, func(i, j int) bool {
|
||
return domains[i].Count > domains[j].Count
|
||
})
|
||
|
||
// 返回限制数量
|
||
if len(domains) > limit {
|
||
return domains[:limit]
|
||
}
|
||
return domains
|
||
}
|
||
|
||
// GetRecentBlockedDomains 获取最近屏蔽的域名列表
|
||
func (s *Server) GetRecentBlockedDomains(limit int) []BlockedDomain {
|
||
s.blockedDomainsMutex.RLock()
|
||
defer s.blockedDomainsMutex.RUnlock()
|
||
|
||
// 转换为切片
|
||
domains := make([]BlockedDomain, 0, len(s.blockedDomains))
|
||
for _, entry := range s.blockedDomains {
|
||
domains = append(domains, *entry)
|
||
}
|
||
|
||
// 按时间排序
|
||
sort.Slice(domains, func(i, j int) bool {
|
||
return domains[i].LastSeen > domains[j].LastSeen
|
||
})
|
||
|
||
// 返回限制数量
|
||
if len(domains) > limit {
|
||
return domains[:limit]
|
||
}
|
||
return domains
|
||
}
|
||
|
||
// GetTopClients 获取TOP客户端列表
|
||
func (s *Server) GetTopClients(limit int) []ClientStats {
|
||
s.clientStatsMutex.RLock()
|
||
defer s.clientStatsMutex.RUnlock()
|
||
|
||
// 转换为切片
|
||
clients := make([]ClientStats, 0, len(s.clientStats))
|
||
for _, entry := range s.clientStats {
|
||
clients = append(clients, *entry)
|
||
}
|
||
|
||
// 按请求次数排序
|
||
sort.Slice(clients, func(i, j int) bool {
|
||
return clients[i].Count > clients[j].Count
|
||
})
|
||
|
||
// 返回限制数量
|
||
if len(clients) > limit {
|
||
return clients[:limit]
|
||
}
|
||
return clients
|
||
}
|
||
|
||
// GetHourlyStats 获取每小时统计数据
|
||
func (s *Server) GetHourlyStats() map[string]int64 {
|
||
s.hourlyStatsMutex.RLock()
|
||
defer s.hourlyStatsMutex.RUnlock()
|
||
|
||
// 返回副本
|
||
result := make(map[string]int64)
|
||
for k, v := range s.hourlyStats {
|
||
result[k] = v
|
||
}
|
||
return result
|
||
}
|
||
|
||
// GetDailyStats 获取每日统计数据
|
||
func (s *Server) GetDailyStats() map[string]int64 {
|
||
s.dailyStatsMutex.RLock()
|
||
defer s.dailyStatsMutex.RUnlock()
|
||
|
||
// 返回副本
|
||
result := make(map[string]int64)
|
||
for k, v := range s.dailyStats {
|
||
result[k] = v
|
||
}
|
||
return result
|
||
}
|
||
|
||
// GetMonthlyStats 获取每月统计数据
|
||
func (s *Server) GetMonthlyStats() map[string]int64 {
|
||
s.monthlyStatsMutex.RLock()
|
||
defer s.monthlyStatsMutex.RUnlock()
|
||
|
||
// 返回副本
|
||
result := make(map[string]int64)
|
||
for k, v := range s.monthlyStats {
|
||
result[k] = v
|
||
}
|
||
return result
|
||
}
|
||
|
||
// isPrivateIP 检测IP地址是否为内网IP
|
||
func isPrivateIP(ip string) bool {
|
||
// 解析IP地址
|
||
parsedIP := net.ParseIP(ip)
|
||
if parsedIP == nil {
|
||
return false
|
||
}
|
||
|
||
// 检查IPv4内网地址
|
||
if ipv4 := parsedIP.To4(); ipv4 != nil {
|
||
// 10.0.0.0/8
|
||
if ipv4[0] == 10 {
|
||
return true
|
||
}
|
||
// 172.16.0.0/12
|
||
if ipv4[0] == 172 && (ipv4[1] >= 16 && ipv4[1] <= 31) {
|
||
return true
|
||
}
|
||
// 192.168.0.0/16
|
||
if ipv4[0] == 192 && ipv4[1] == 168 {
|
||
return true
|
||
}
|
||
// 127.0.0.0/8 (localhost)
|
||
if ipv4[0] == 127 {
|
||
return true
|
||
}
|
||
// 169.254.0.0/16 (链路本地地址)
|
||
if ipv4[0] == 169 && ipv4[1] == 254 {
|
||
return true
|
||
}
|
||
return false
|
||
}
|
||
|
||
// 检查IPv6内网地址
|
||
// ::1/128 (localhost)
|
||
if parsedIP.IsLoopback() {
|
||
return true
|
||
}
|
||
// fc00::/7 (唯一本地地址)
|
||
if parsedIP[0]&0xfc == 0xfc {
|
||
return true
|
||
}
|
||
// fe80::/10 (链路本地地址)
|
||
if parsedIP[0]&0xfe == 0xfe && parsedIP[1]&0xc0 == 0x80 {
|
||
return true
|
||
}
|
||
return false
|
||
}
|
||
|
||
// loadStatsData 从文件加载统计数据
|
||
func (s *Server) loadStatsData() {
|
||
// 检查文件是否存在
|
||
data, err := ioutil.ReadFile("data/stats.json")
|
||
if err != nil {
|
||
if !os.IsNotExist(err) {
|
||
logger.Error("读取统计数据文件失败", "error", err)
|
||
}
|
||
return
|
||
}
|
||
|
||
var statsData StatsData
|
||
err = json.Unmarshal(data, &statsData)
|
||
if err != nil {
|
||
logger.Error("解析统计数据失败", "error", err)
|
||
return
|
||
}
|
||
|
||
// 恢复统计数据
|
||
s.statsMutex.Lock()
|
||
if statsData.Stats != nil {
|
||
// 只恢复有效数据,避免破坏统计关系
|
||
s.stats.Queries += statsData.Stats.Queries
|
||
s.stats.Blocked += statsData.Stats.Blocked
|
||
s.stats.Allowed += statsData.Stats.Allowed
|
||
s.stats.Errors += statsData.Stats.Errors
|
||
s.stats.TotalResponseTime += statsData.Stats.TotalResponseTime
|
||
s.stats.DNSSECQueries += statsData.Stats.DNSSECQueries
|
||
s.stats.DNSSECSuccess += statsData.Stats.DNSSECSuccess
|
||
s.stats.DNSSECFailed += statsData.Stats.DNSSECFailed
|
||
|
||
// 重新计算平均响应时间,确保一致性
|
||
if s.stats.Queries > 0 {
|
||
s.stats.AvgResponseTime = float64(s.stats.TotalResponseTime) / float64(s.stats.Queries)
|
||
// 限制平均响应时间的范围,避免显示异常大的值
|
||
if s.stats.AvgResponseTime > 60000 {
|
||
s.stats.AvgResponseTime = 60000
|
||
}
|
||
}
|
||
|
||
// 合并查询类型统计
|
||
for k, v := range statsData.Stats.QueryTypes {
|
||
s.stats.QueryTypes[k] += v
|
||
}
|
||
|
||
// 合并来源IP统计
|
||
for ip := range statsData.Stats.SourceIPs {
|
||
s.stats.SourceIPs[ip] = true
|
||
}
|
||
|
||
// 确保使用当前配置中的EnableDNSSEC值
|
||
s.stats.DNSSECEnabled = s.config.EnableDNSSEC
|
||
}
|
||
s.statsMutex.Unlock()
|
||
|
||
s.blockedDomainsMutex.Lock()
|
||
if statsData.BlockedDomains != nil {
|
||
s.blockedDomains = statsData.BlockedDomains
|
||
}
|
||
s.blockedDomainsMutex.Unlock()
|
||
|
||
s.resolvedDomainsMutex.Lock()
|
||
if statsData.ResolvedDomains != nil {
|
||
s.resolvedDomains = statsData.ResolvedDomains
|
||
}
|
||
s.resolvedDomainsMutex.Unlock()
|
||
|
||
s.hourlyStatsMutex.Lock()
|
||
if statsData.HourlyStats != nil {
|
||
s.hourlyStats = statsData.HourlyStats
|
||
}
|
||
s.hourlyStatsMutex.Unlock()
|
||
|
||
s.dailyStatsMutex.Lock()
|
||
if statsData.DailyStats != nil {
|
||
s.dailyStats = statsData.DailyStats
|
||
}
|
||
s.dailyStatsMutex.Unlock()
|
||
|
||
s.monthlyStatsMutex.Lock()
|
||
if statsData.MonthlyStats != nil {
|
||
s.monthlyStats = statsData.MonthlyStats
|
||
}
|
||
s.monthlyStatsMutex.Unlock()
|
||
|
||
// 加载客户端统计数据
|
||
s.clientStatsMutex.Lock()
|
||
if statsData.ClientStats != nil {
|
||
s.clientStats = statsData.ClientStats
|
||
}
|
||
s.clientStatsMutex.Unlock()
|
||
|
||
logger.Info("统计数据加载成功")
|
||
|
||
// 加载查询日志
|
||
s.loadQueryLogs()
|
||
}
|
||
|
||
// loadQueryLogs 从文件加载查询日志
|
||
func (s *Server) loadQueryLogs() {
|
||
// 获取绝对路径
|
||
statsFilePath, err := filepath.Abs("data/stats.json")
|
||
if err != nil {
|
||
logger.Error("获取统计文件绝对路径失败", "path", "data/stats.json", "error", err)
|
||
return
|
||
}
|
||
|
||
// 构建查询日志文件路径
|
||
queryLogPath := filepath.Join(filepath.Dir(statsFilePath), "querylog.json")
|
||
|
||
// 检查文件是否存在
|
||
if _, err := os.Stat(queryLogPath); os.IsNotExist(err) {
|
||
logger.Info("查询日志文件不存在,将使用空列表", "file", queryLogPath)
|
||
return
|
||
}
|
||
|
||
// 读取文件内容
|
||
data, err := ioutil.ReadFile(queryLogPath)
|
||
if err != nil {
|
||
logger.Error("读取查询日志文件失败", "error", err)
|
||
return
|
||
}
|
||
|
||
// 解析数据
|
||
var logs []QueryLog
|
||
err = json.Unmarshal(data, &logs)
|
||
if err != nil {
|
||
logger.Error("解析查询日志失败", "error", err)
|
||
return
|
||
}
|
||
|
||
// 更新查询日志
|
||
s.queryLogsMutex.Lock()
|
||
s.queryLogs = logs
|
||
// 确保日志数量不超过限制
|
||
if len(s.queryLogs) > s.maxQueryLogs {
|
||
s.queryLogs = s.queryLogs[:s.maxQueryLogs]
|
||
}
|
||
s.queryLogsMutex.Unlock()
|
||
|
||
logger.Info("查询日志加载成功", "count", len(logs))
|
||
}
|
||
|
||
// processLogs 异步处理日志记录
|
||
func (s *Server) processLogs() {
|
||
for {
|
||
select {
|
||
case logEntry, ok := <-s.logChannel:
|
||
if !ok {
|
||
// 通道关闭,退出循环
|
||
return
|
||
}
|
||
|
||
// 加锁保护queryLogs
|
||
s.queryLogsMutex.Lock()
|
||
|
||
// 如果日志数量超过最大限制,删除最旧的日志
|
||
if len(s.queryLogs) >= s.maxQueryLogs {
|
||
// 保留最新的s.maxQueryLogs条日志
|
||
newLogs := make([]QueryLog, 0, s.maxQueryLogs)
|
||
// 复制最新的日志到新切片
|
||
for i := len(s.queryLogs) - s.maxQueryLogs + 1; i < len(s.queryLogs); i++ {
|
||
newLogs = append(newLogs, s.queryLogs[i])
|
||
}
|
||
// 添加新日志
|
||
newLogs = append(newLogs, logEntry)
|
||
// 替换原有日志
|
||
s.queryLogs = newLogs
|
||
} else {
|
||
// 直接添加新日志
|
||
s.queryLogs = append(s.queryLogs, logEntry)
|
||
}
|
||
|
||
// 解锁
|
||
s.queryLogsMutex.Unlock()
|
||
|
||
case <-s.ctx.Done():
|
||
// 上下文取消,退出循环
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// saveStatsData 保存统计数据到文件
|
||
func (s *Server) saveStatsData() {
|
||
// 获取绝对路径以避免工作目录问题
|
||
statsFilePath, err := filepath.Abs("data/stats.json")
|
||
if err != nil {
|
||
logger.Error("获取统计文件绝对路径失败", "path", "data/stats.json", "error", err)
|
||
return
|
||
}
|
||
|
||
// 创建数据目录
|
||
statsDir := filepath.Dir(statsFilePath)
|
||
err = os.MkdirAll(statsDir, 0755)
|
||
if err != nil {
|
||
logger.Error("创建统计数据目录失败", "dir", statsDir, "error", err)
|
||
return
|
||
}
|
||
|
||
// 收集所有统计数据
|
||
statsData := &StatsData{
|
||
Stats: s.GetStats(),
|
||
LastSaved: time.Now(),
|
||
}
|
||
|
||
// 复制域名数据
|
||
s.blockedDomainsMutex.RLock()
|
||
statsData.BlockedDomains = make(map[string]*BlockedDomain)
|
||
for k, v := range s.blockedDomains {
|
||
statsData.BlockedDomains[k] = v
|
||
}
|
||
s.blockedDomainsMutex.RUnlock()
|
||
|
||
s.resolvedDomainsMutex.RLock()
|
||
statsData.ResolvedDomains = make(map[string]*BlockedDomain)
|
||
for k, v := range s.resolvedDomains {
|
||
statsData.ResolvedDomains[k] = v
|
||
}
|
||
s.resolvedDomainsMutex.RUnlock()
|
||
|
||
s.hourlyStatsMutex.RLock()
|
||
statsData.HourlyStats = make(map[string]int64)
|
||
for k, v := range s.hourlyStats {
|
||
statsData.HourlyStats[k] = v
|
||
}
|
||
s.hourlyStatsMutex.RUnlock()
|
||
|
||
s.dailyStatsMutex.RLock()
|
||
statsData.DailyStats = make(map[string]int64)
|
||
for k, v := range s.dailyStats {
|
||
statsData.DailyStats[k] = v
|
||
}
|
||
s.dailyStatsMutex.RUnlock()
|
||
|
||
s.monthlyStatsMutex.RLock()
|
||
statsData.MonthlyStats = make(map[string]int64)
|
||
for k, v := range s.monthlyStats {
|
||
statsData.MonthlyStats[k] = v
|
||
}
|
||
s.monthlyStatsMutex.RUnlock()
|
||
|
||
// 复制客户端统计数据
|
||
s.clientStatsMutex.RLock()
|
||
statsData.ClientStats = make(map[string]*ClientStats)
|
||
for k, v := range s.clientStats {
|
||
statsData.ClientStats[k] = v
|
||
}
|
||
s.clientStatsMutex.RUnlock()
|
||
|
||
// 序列化数据
|
||
jsonData, err := json.MarshalIndent(statsData, "", " ")
|
||
if err != nil {
|
||
logger.Error("序列化统计数据失败", "error", err)
|
||
return
|
||
}
|
||
|
||
// 写入文件
|
||
err = os.WriteFile(statsFilePath, jsonData, 0644)
|
||
if err != nil {
|
||
logger.Error("保存统计数据到文件失败", "file", statsFilePath, "error", err)
|
||
return
|
||
}
|
||
|
||
logger.Info("统计数据保存成功", "file", statsFilePath)
|
||
|
||
// 保存查询日志到文件
|
||
s.saveQueryLogs(statsDir)
|
||
}
|
||
|
||
// saveQueryLogs 保存查询日志到文件
|
||
func (s *Server) saveQueryLogs(dataDir string) {
|
||
// 构建查询日志文件路径
|
||
queryLogPath := filepath.Join(dataDir, "querylog.json")
|
||
|
||
// 获取查询日志数据
|
||
s.queryLogsMutex.RLock()
|
||
logsCopy := make([]QueryLog, len(s.queryLogs))
|
||
copy(logsCopy, s.queryLogs)
|
||
s.queryLogsMutex.RUnlock()
|
||
|
||
// 序列化数据
|
||
jsonData, err := json.MarshalIndent(logsCopy, "", " ")
|
||
if err != nil {
|
||
logger.Error("序列化查询日志失败", "error", err)
|
||
return
|
||
}
|
||
|
||
// 写入文件
|
||
err = os.WriteFile(queryLogPath, jsonData, 0644)
|
||
if err != nil {
|
||
logger.Error("保存查询日志到文件失败", "file", queryLogPath, "error", err)
|
||
return
|
||
}
|
||
|
||
logger.Info("查询日志保存成功", "file", queryLogPath)
|
||
}
|
||
|
||
// startCpuUsageMonitor 启动CPU使用率监控
|
||
func (s *Server) startCpuUsageMonitor() {
|
||
ticker := time.NewTicker(time.Second * 5) // 每5秒更新一次CPU使用率
|
||
defer ticker.Stop()
|
||
|
||
// 初始化
|
||
var memStats runtime.MemStats
|
||
runtime.ReadMemStats(&memStats)
|
||
|
||
// 存储上一次的CPU时间统计
|
||
var prevIdle, prevTotal uint64
|
||
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
// 获取真实的系统级CPU使用率
|
||
cpuUsage, err := getSystemCpuUsage(&prevIdle, &prevTotal)
|
||
if err != nil {
|
||
// 如果获取失败,使用默认值
|
||
cpuUsage = 0.0
|
||
logger.Error("获取系统CPU使用率失败", "error", err)
|
||
}
|
||
|
||
s.updateStats(func(stats *Stats) {
|
||
stats.CpuUsage = cpuUsage
|
||
})
|
||
case <-s.ctx.Done():
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// getSystemCpuUsage 获取系统CPU使用率
|
||
func getSystemCpuUsage(prevIdle, prevTotal *uint64) (float64, error) {
|
||
// 读取/proc/stat文件获取CPU统计信息
|
||
file, err := os.Open("/proc/stat")
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
defer file.Close()
|
||
|
||
var cpuUser, cpuNice, cpuSystem, cpuIdle, cpuIowait, cpuIrq, cpuSoftirq, cpuSteal uint64
|
||
_, err = fmt.Fscanf(file, "cpu %d %d %d %d %d %d %d %d",
|
||
&cpuUser, &cpuNice, &cpuSystem, &cpuIdle, &cpuIowait, &cpuIrq, &cpuSoftirq, &cpuSteal)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
// 计算总的CPU时间
|
||
total := cpuUser + cpuNice + cpuSystem + cpuIdle + cpuIowait + cpuIrq + cpuSoftirq + cpuSteal
|
||
idle := cpuIdle + cpuIowait
|
||
|
||
// 第一次调用时,只初始化值,不计算使用率
|
||
if *prevTotal == 0 || *prevIdle == 0 {
|
||
*prevIdle = idle
|
||
*prevTotal = total
|
||
return 0, nil
|
||
}
|
||
|
||
// 计算CPU使用率
|
||
idleDelta := idle - *prevIdle
|
||
totalDelta := total - *prevTotal
|
||
utilization := float64(totalDelta-idleDelta) / float64(totalDelta) * 100
|
||
|
||
// 更新上一次的值
|
||
*prevIdle = idle
|
||
*prevTotal = total
|
||
|
||
return utilization, nil
|
||
}
|
||
|
||
// startAutoSave 启动自动保存功能
|
||
func (s *Server) startAutoSave() {
|
||
if s.config.SaveInterval <= 0 {
|
||
return
|
||
}
|
||
|
||
// 初始化定时器
|
||
s.saveTicker = time.NewTicker(time.Duration(s.config.SaveInterval) * time.Second)
|
||
defer s.saveTicker.Stop()
|
||
|
||
logger.Info("启动统计数据自动保存功能", "interval", s.config.SaveInterval, "file", "data/stats.json")
|
||
|
||
// 定期保存数据
|
||
for {
|
||
select {
|
||
case <-s.saveTicker.C:
|
||
s.saveStatsData()
|
||
case <-s.saveDone:
|
||
return
|
||
}
|
||
}
|
||
}
|