Files
dns-server/dns/server.go
2025-12-17 22:43:31 +08:00

2597 lines
70 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package dns
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"path/filepath"
"runtime"
"sort"
"strings"
"sync"
"time"
"dns-server/config"
"dns-server/logger"
"dns-server/shield"
"github.com/miekg/dns"
)
// BlockedDomain 屏蔽域名统计
type BlockedDomain struct {
Domain string
Count int64
LastSeen time.Time
DNSSEC bool // 是否使用了DNSSEC
}
// ClientStats 客户端统计
type ClientStats struct {
IP string
Count int64
LastSeen time.Time
}
// IPGeolocation IP地理位置信息
type IPGeolocation struct {
Country string `json:"country"` // 国家
City string `json:"city"` // 城市
Expiry time.Time `json:"expiry"` // 缓存过期时间
}
// QueryLog 查询日志记录
type QueryLog struct {
Timestamp time.Time // 查询时间
ClientIP string // 客户端IP
Location string // IP地理位置国家 城市)
Domain string // 查询域名
QueryType string // 查询类型
ResponseTime int64 // 响应时间(ms)
Result string // 查询结果allowed, blocked, error
BlockRule string // 屏蔽规则(如果被屏蔽)
BlockType string // 屏蔽类型(如果被屏蔽)
FromCache bool // 是否来自缓存
DNSSEC bool // 是否使用了DNSSEC
EDNS bool // 是否使用了EDNS
}
// StatsData 用于持久化的统计数据结构
type StatsData struct {
Stats *Stats `json:"stats"`
BlockedDomains map[string]*BlockedDomain `json:"blockedDomains"`
ResolvedDomains map[string]*BlockedDomain `json:"resolvedDomains"`
ClientStats map[string]*ClientStats `json:"clientStats"`
HourlyStats map[string]int64 `json:"hourlyStats"`
DailyStats map[string]int64 `json:"dailyStats"`
MonthlyStats map[string]int64 `json:"monthlyStats"`
LastSaved time.Time `json:"lastSaved"`
}
// ServerStats 服务器统计信息
type ServerStats struct {
SuccessCount int64 // 成功查询次数
FailureCount int64 // 失败查询次数
LastResponse time.Time // 最后响应时间
ResponseTime time.Duration // 平均响应时间
ConnectionSpeed time.Duration // TCP连接速度
}
// Server DNS服务器
type Server struct {
config *config.DNSConfig
shieldConfig *config.ShieldConfig
shieldManager *shield.ShieldManager
server *dns.Server
tcpServer *dns.Server
resolver *dns.Client
ctx context.Context
cancel context.CancelFunc
statsMutex sync.Mutex
stats *Stats
blockedDomainsMutex sync.RWMutex
blockedDomains map[string]*BlockedDomain
resolvedDomainsMutex sync.RWMutex
resolvedDomains map[string]*BlockedDomain // 用于记录解析的域名
clientStatsMutex sync.RWMutex
clientStats map[string]*ClientStats // 用于记录客户端统计
hourlyStatsMutex sync.RWMutex
hourlyStats map[string]int64 // 按小时统计屏蔽数量
dailyStatsMutex sync.RWMutex
dailyStats map[string]int64 // 按天统计屏蔽数量
monthlyStatsMutex sync.RWMutex
monthlyStats map[string]int64 // 按月统计屏蔽数量
queryLogsMutex sync.RWMutex
queryLogs []QueryLog // 查询日志列表
maxQueryLogs int // 最大保存日志数量
saveTicker *time.Ticker // 用于定时保存数据
startTime time.Time // 服务器启动时间
saveDone chan struct{} // 用于通知保存协程停止
stopped bool // 服务器是否已经停止
stoppedMutex sync.Mutex // 保护stopped标志的互斥锁
// IP地理位置缓存
ipGeolocationCache map[string]*IPGeolocation // IP地址到地理位置的映射
ipGeolocationCacheMutex sync.RWMutex // 保护IP地理位置缓存的互斥锁
ipGeolocationCacheTTL time.Duration // 缓存有效期
// DNS查询缓存
dnsCache *DNSCache // DNS响应缓存
// 域名DNSSEC状态映射表
domainDNSSECStatus map[string]bool // 域名到DNSSEC状态的映射
// 上游服务器状态跟踪
serverStats map[string]*ServerStats // 服务器地址到状态的映射
serverStatsMutex sync.RWMutex // 保护服务器状态的互斥锁
}
// Stats DNS服务器统计信息
type Stats struct {
Queries int64
Blocked int64
Allowed int64
Errors int64
LastQuery time.Time
AvgResponseTime float64 // 平均响应时间(ms)
TotalResponseTime int64 // 总响应时间
QueryTypes map[string]int64 // 查询类型统计
SourceIPs map[string]bool // 活跃来源IP
CpuUsage float64 // CPU使用率(%)
DNSSECQueries int64 // DNSSEC查询总数
DNSSECSuccess int64 // DNSSEC验证成功数
DNSSECFailed int64 // DNSSEC验证失败数
DNSSECEnabled bool // 是否启用了DNSSEC
}
// NewServer 创建DNS服务器实例
func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager) *Server {
ctx, cancel := context.WithCancel(context.Background())
// 从配置中读取DNS缓存TTL值分钟
cacheTTL := time.Duration(config.CacheTTL) * time.Minute
server := &Server{
config: config,
shieldConfig: shieldConfig,
shieldManager: shieldManager,
resolver: &dns.Client{
Net: "udp",
Timeout: time.Duration(config.Timeout) * time.Millisecond,
UDPSize: 4096, // 增加UDP缓冲区大小支持更大的DNSSEC响应
},
ctx: ctx,
cancel: cancel,
startTime: time.Now(), // 记录服务器启动时间
stats: &Stats{
Queries: 0,
Blocked: 0,
Allowed: 0,
Errors: 0,
AvgResponseTime: 0,
TotalResponseTime: 0,
QueryTypes: make(map[string]int64),
SourceIPs: make(map[string]bool),
CpuUsage: 0,
DNSSECQueries: 0,
DNSSECSuccess: 0,
DNSSECFailed: 0,
DNSSECEnabled: config.EnableDNSSEC,
},
blockedDomains: make(map[string]*BlockedDomain),
resolvedDomains: make(map[string]*BlockedDomain),
clientStats: make(map[string]*ClientStats),
hourlyStats: make(map[string]int64),
dailyStats: make(map[string]int64),
monthlyStats: make(map[string]int64),
queryLogs: make([]QueryLog, 0, 1000), // 初始化查询日志切片容量1000
maxQueryLogs: 10000, // 最大保存10000条日志
saveDone: make(chan struct{}),
stopped: false, // 初始化为未停止状态
// IP地理位置缓存初始化
ipGeolocationCache: make(map[string]*IPGeolocation),
ipGeolocationCacheTTL: 24 * time.Hour, // 缓存有效期24小时
// DNS查询缓存初始化
dnsCache: NewDNSCache(cacheTTL),
// 初始化域名DNSSEC状态映射表
domainDNSSECStatus: make(map[string]bool),
// 初始化服务器状态跟踪
serverStats: make(map[string]*ServerStats),
}
// 加载已保存的统计数据
server.loadStatsData()
return server
}
// Start 启动DNS服务器
func (s *Server) Start() error {
// 重新初始化上下文和取消函数
ctx, cancel := context.WithCancel(context.Background())
s.ctx = ctx
s.cancel = cancel
// 重新初始化saveDone通道
s.saveDone = make(chan struct{})
// 重置stopped标志
s.stoppedMutex.Lock()
s.stopped = false
s.stoppedMutex.Unlock()
// 更新服务器启动时间
s.startTime = time.Now()
s.server = &dns.Server{
Addr: fmt.Sprintf(":%d", s.config.Port),
Net: "udp",
Handler: dns.HandlerFunc(s.handleDNSRequest),
}
// 保存TCP服务器实例以便在Stop方法中关闭
s.tcpServer = &dns.Server{
Addr: fmt.Sprintf(":%d", s.config.Port),
Net: "tcp",
Handler: dns.HandlerFunc(s.handleDNSRequest),
}
// 启动CPU使用率监控
go s.startCpuUsageMonitor()
// 启动自动保存功能
go s.startAutoSave()
// 启动UDP服务
go func() {
logger.Info(fmt.Sprintf("DNS UDP服务器启动监听端口: %d", s.config.Port))
if err := s.server.ListenAndServe(); err != nil {
logger.Error("DNS UDP服务器启动失败", "error", err)
s.cancel()
}
}()
// 启动TCP服务
go func() {
logger.Info(fmt.Sprintf("DNS TCP服务器启动监听端口: %d", s.config.Port))
if err := s.tcpServer.ListenAndServe(); err != nil {
logger.Error("DNS TCP服务器启动失败", "error", err)
s.cancel()
}
}()
// 等待停止信号
<-s.ctx.Done()
return nil
}
// Stop 停止DNS服务器
func (s *Server) Stop() {
// 检查服务器是否已经停止
s.stoppedMutex.Lock()
if s.stopped {
s.stoppedMutex.Unlock()
return // 服务器已经停止,直接返回
}
// 标记服务器为已停止状态
s.stopped = true
s.stoppedMutex.Unlock()
// 发送停止信号给保存协程
close(s.saveDone)
// 最后保存一次数据
s.saveStatsData()
// 停止服务器
s.cancel()
if s.server != nil {
s.server.Shutdown()
}
if s.tcpServer != nil {
s.tcpServer.Shutdown()
}
logger.Info("DNS服务器已停止")
}
// handleDNSRequest 处理DNS请求
func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
startTime := time.Now()
// 获取来源IP
sourceIP := w.RemoteAddr().String()
// 提取IP地址部分去掉端口
if strings.HasPrefix(sourceIP, "[") {
// IPv6地址格式: [::1]:53
if idx := strings.Index(sourceIP, "]"); idx >= 0 {
sourceIP = sourceIP[1:idx] // 去掉方括号
}
} else {
// IPv4地址格式: 127.0.0.1:53
if idx := strings.LastIndex(sourceIP, ":"); idx >= 0 {
sourceIP = sourceIP[:idx]
}
}
// 更新来源IP统计
s.updateStats(func(stats *Stats) {
stats.Queries++
stats.LastQuery = time.Now()
stats.SourceIPs[sourceIP] = true
})
// 更新客户端统计
s.updateClientStats(sourceIP)
// 获取查询域名和类型
var domain string
var queryType string
var qType uint16
if len(r.Question) > 0 {
domain = r.Question[0].Name
// 移除末尾的点
if len(domain) > 0 && domain[len(domain)-1] == '.' {
domain = domain[:len(domain)-1]
}
// 获取查询类型
queryType = dns.TypeToString[r.Question[0].Qtype]
qType = r.Question[0].Qtype
// 更新查询类型统计
s.updateStats(func(stats *Stats) {
stats.QueryTypes[queryType]++
})
}
logger.Debug("接收到DNS查询", "domain", domain, "type", queryType, "client", w.RemoteAddr())
// 只处理递归查询
if r.RecursionDesired == false {
response := new(dns.Msg)
response.SetReply(r)
response.RecursionAvailable = true
response.SetRcode(r, dns.RcodeRefused)
w.WriteMsg(response)
// 缓存命中响应时间设为0ms
responseTime := int64(0)
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
if stats.Queries > 0 {
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
}
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "", false, false, true)
return
}
// 检查hosts文件是否有匹配
if ip, exists := s.shieldManager.GetHostsIP(domain); exists {
s.handleHostsResponse(w, r, ip)
// 缓存命中响应时间设为0ms
responseTime := int64(0)
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
if stats.Queries > 0 {
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
}
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", false, false, true)
return
}
// 检查是否被屏蔽
if s.shieldManager.IsBlocked(domain) {
// 获取屏蔽详情
blockDetails := s.shieldManager.CheckDomainBlockDetails(domain)
blockRule, _ := blockDetails["blockRule"].(string)
blockType, _ := blockDetails["blockRuleType"].(string)
s.handleBlockedResponse(w, r, domain)
// 计算响应时间
responseTime := time.Since(startTime).Milliseconds()
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
if stats.Queries > 0 {
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
}
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "blocked", blockRule, blockType, false, false, true)
return
}
// 检查缓存中是否有响应优先查找带DNSSEC的缓存项
var cachedResponse *dns.Msg
var found bool
var cachedDNSSEC bool
// 1. 首先检查是否有普通缓存项
if tempResponse, tempFound := s.dnsCache.Get(r.Question[0].Name, qType); tempFound {
cachedResponse = tempResponse
found = tempFound
cachedDNSSEC = s.hasDNSSECRecords(tempResponse)
}
// 2. 如果启用了DNSSEC且没有找到带DNSSEC的缓存项
// 尝试从所有缓存中查找是否有其他响应包含DNSSEC记录
// 这里可以进一步优化比如在缓存中标记DNSSEC状态快速查找
if s.config.EnableDNSSEC && !cachedDNSSEC {
// 目前的缓存实现不支持按DNSSEC状态查找所以这里暂时跳过
// 后续可以考虑改进缓存实现添加DNSSEC状态标记
}
if found {
// 缓存命中,直接返回缓存的响应
cachedResponseCopy := cachedResponse.Copy() // 创建响应副本避免并发修改问题
cachedResponseCopy.Id = r.Id // 更新ID以匹配请求
cachedResponseCopy.Compress = true
// 如果客户端请求包含EDNS记录确保响应也包含EDNS
if opt := r.IsEdns0(); opt != nil {
// 检查响应是否已经包含EDNS记录
if respOpt := cachedResponseCopy.IsEdns0(); respOpt == nil {
// 添加EDNS记录使用客户端的UDP缓冲区大小
cachedResponseCopy.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
} else {
// 确保响应的UDP缓冲区大小不超过客户端请求的大小
if respOpt.UDPSize() > opt.UDPSize() {
// 移除现有的EDNS记录
for i := range cachedResponseCopy.Extra {
if cachedResponseCopy.Extra[i] == respOpt {
cachedResponseCopy.Extra = append(cachedResponseCopy.Extra[:i], cachedResponseCopy.Extra[i+1:]...)
break
}
}
// 添加新的EDNS记录使用客户端的UDP缓冲区大小
cachedResponseCopy.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
}
}
}
w.WriteMsg(cachedResponseCopy)
// 计算响应时间
responseTime := time.Since(startTime).Milliseconds()
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
if stats.Queries > 0 {
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
}
})
// 如果缓存响应包含DNSSEC记录更新DNSSEC查询计数
if cachedDNSSEC {
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
// 缓存响应视为DNSSEC成功
stats.DNSSECSuccess++
})
}
// 添加查询日志 - 标记为缓存
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", true, cachedDNSSEC, true)
logger.Debug("从缓存返回DNS响应", "domain", domain, "type", queryType, "dnssec", cachedDNSSEC)
return
}
// 缓存未命中转发到上游DNS服务器
response, rtt := s.forwardDNSRequestWithCache(r, domain)
if response != nil {
// 如果客户端请求包含EDNS记录确保响应也包含EDNS
if opt := r.IsEdns0(); opt != nil {
// 检查响应是否已经包含EDNS记录
if respOpt := response.IsEdns0(); respOpt == nil {
// 添加EDNS记录使用客户端的UDP缓冲区大小
response.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
} else {
// 确保响应的UDP缓冲区大小不超过客户端请求的大小
if respOpt.UDPSize() > opt.UDPSize() {
// 移除现有的EDNS记录
for i := range response.Extra {
if response.Extra[i] == respOpt {
response.Extra = append(response.Extra[:i], response.Extra[i+1:]...)
break
}
}
// 添加新的EDNS记录使用客户端的UDP缓冲区大小
response.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
}
}
}
// 写入响应给客户端
w.WriteMsg(response)
}
// 使用上游服务器的实际响应时间(转换为毫秒)
responseTime := int64(rtt.Milliseconds())
// 如果rtt为0查询失败则使用本地计算的时间
if responseTime == 0 {
responseTime = time.Since(startTime).Milliseconds()
}
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
if stats.Queries > 0 {
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
}
})
// 检查响应是否包含DNSSEC记录并验证结果
responseDNSSEC := false
if response != nil {
// 使用hasDNSSECRecords函数检查是否包含DNSSEC记录
responseDNSSEC = s.hasDNSSECRecords(response)
// 检查AD标志确认DNSSEC验证是否成功
if response.AuthenticatedData {
responseDNSSEC = true
}
// 更新域名的DNSSEC状态
if responseDNSSEC {
s.updateDomainDNSSECStatus(domain, true)
}
}
// 如果响应成功,缓存结果(增强版缓存存储)
if response != nil && response.Rcode == dns.RcodeSuccess {
// 创建响应副本以避免后续修改影响缓存
responseCopy := response.Copy()
// 设置合理的TTL不超过默认的30分钟
defaultCacheTTL := 30 * time.Minute
s.dnsCache.Set(r.Question[0].Name, qType, responseCopy, defaultCacheTTL)
logger.Debug("DNS响应已缓存", "domain", domain, "type", queryType, "ttl", defaultCacheTTL, "dnssec", responseDNSSEC)
}
// 添加查询日志 - 标记为实时
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", false, responseDNSSEC, true)
}
// handleHostsResponse 处理hosts文件匹配的响应
func (s *Server) handleHostsResponse(w dns.ResponseWriter, r *dns.Msg, ip string) {
response := new(dns.Msg)
response.SetReply(r)
response.RecursionAvailable = true
if len(r.Question) > 0 {
q := r.Question[0]
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
answer.A = net.ParseIP(ip)
response.Answer = append(response.Answer, answer)
}
// 记录解析域名统计
domain := ""
if len(r.Question) > 0 {
domain = r.Question[0].Name
if len(domain) > 0 && domain[len(domain)-1] == '.' {
domain = domain[:len(domain)-1]
}
s.updateResolvedDomainStats(domain)
}
w.WriteMsg(response)
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
}
// handleBlockedResponse 处理被屏蔽的域名响应
func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain string) {
logger.Info("域名被屏蔽", "domain", domain, "client", w.RemoteAddr())
// 更新被屏蔽域名统计
s.updateBlockedDomainStats(domain)
response := new(dns.Msg)
response.SetReply(r)
response.RecursionAvailable = true
// 获取屏蔽方法配置
blockMethod := "NXDOMAIN" // 默认值
customBlockIP := "" // 默认值
// 从Server结构体的shieldConfig字段获取配置
if s.shieldConfig != nil {
blockMethod = s.shieldConfig.BlockMethod
customBlockIP = s.shieldConfig.CustomBlockIP
}
// 根据屏蔽方法返回不同的响应
switch blockMethod {
case "refused":
// 返回拒绝查询响应
response.SetRcode(r, dns.RcodeRefused)
case "emptyIP":
// 返回空IP响应
if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA {
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: r.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
answer.A = net.ParseIP("0.0.0.0") // 空IP
response.Answer = append(response.Answer, answer)
}
case "customIP":
// 返回自定义IP响应
if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA && customBlockIP != "" {
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: r.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
answer.A = net.ParseIP(customBlockIP)
response.Answer = append(response.Answer, answer)
}
case "NXDOMAIN", "":
fallthrough // 默认使用NXDOMAIN
default:
// 返回NXDOMAIN响应域名不存在
response.SetRcode(r, dns.RcodeNameError)
}
w.WriteMsg(response)
s.updateStats(func(stats *Stats) {
stats.Blocked++
})
}
// forwardDNSRequest 转发DNS请求到上游服务器
// serverResponse 用于存储服务器响应的结构体
type serverResponse struct {
response *dns.Msg
rtt time.Duration
server string
error error
}
// forwardDNSRequestWithCache 转发DNS请求到上游服务器并返回响应
func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg, time.Duration) {
// 始终支持EDNS
var udpSize uint16 = 4096
var doFlag bool = s.config.EnableDNSSEC
// 检查客户端请求是否包含EDNS记录
if opt := r.IsEdns0(); opt != nil {
// 保留客户端的UDP缓冲区大小
udpSize = opt.UDPSize()
// 移除现有的EDNS记录以便重新添加
for i := range r.Extra {
if r.Extra[i] == opt {
r.Extra = append(r.Extra[:i], r.Extra[i+1:]...)
break
}
}
}
// 添加EDNS记录设置适当的UDPSize和DO标志
r.SetEdns0(udpSize, doFlag)
// DNSSEC专用服务器列表从配置中获取
dnssecServers := s.config.DNSSECUpstreamDNS
// 1. 首先尝试所有配置的上游DNS服务器
var bestResponse *dns.Msg
var bestRtt time.Duration
var hasBestResponse bool
var hasDNSSECResponse bool
var backupResponse *dns.Msg
var backupRtt time.Duration
var hasBackup bool
// 根据查询模式处理请求
switch s.config.QueryMode {
case "parallel":
// 并行请求模式 - 优化版:添加超时处理和快速响应返回
responses := make(chan serverResponse, len(s.config.UpstreamDNS))
var wg sync.WaitGroup
// 超时上下文
timeoutCtx, cancel := context.WithTimeout(s.ctx, time.Duration(s.config.Timeout)*time.Millisecond)
defer cancel()
// 向所有上游服务器并行发送请求
for _, upstream := range s.config.UpstreamDNS {
wg.Add(1)
go func(server string) {
defer wg.Done()
// 发送请求并获取响应
response, rtt, err := s.resolver.Exchange(r, server)
select {
case responses <- serverResponse{response, rtt, server, err}:
// 成功发送响应
case <-timeoutCtx.Done():
// 超时,忽略此响应
logger.Debug("并行请求超时", "server", server, "domain", domain)
return
}
}(upstream)
}
// 等待所有请求完成或超时
go func() {
wg.Wait()
close(responses)
}()
// 处理所有响应
for resp := range responses {
if resp.error == nil && resp.response != nil {
// 更新服务器统计信息
s.updateServerStats(resp.server, true, resp.rtt)
// 设置递归可用标志
resp.response.RecursionAvailable = true
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(resp.response)
// 如果启用了DNSSEC且响应包含DNSSEC记录验证DNSSEC签名
if s.config.EnableDNSSEC && containsDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(resp.response)
// 设置AD标志Authenticated Data
resp.response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
// 如果响应成功根据DNSSEC状态选择最佳响应
if resp.response.Rcode == dns.RcodeSuccess {
// 优先选择带有DNSSEC记录的响应
if containsDNSSEC {
bestResponse = resp.response
bestRtt = resp.rtt
hasBestResponse = true
hasDNSSECResponse = true
logger.Debug("找到带DNSSEC的最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
} else if !hasBestResponse {
// 没有带DNSSEC的响应时保存第一个成功响应
bestResponse = resp.response
bestRtt = resp.rtt
hasBestResponse = true
logger.Debug("找到最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
}
// 保存为备选响应
if !hasBackup {
backupResponse = resp.response
backupRtt = resp.rtt
hasBackup = true
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(resp.server, false, 0)
}
}
case "loadbalance":
// 负载均衡模式 - 使用加权随机选择算法
// 1. 选择一个加权随机服务器
selectedServer := s.selectWeightedRandomServer(s.config.UpstreamDNS)
if selectedServer != "" {
response, rtt, err := s.resolver.Exchange(r, selectedServer)
if err == nil && response != nil {
// 更新服务器统计信息
s.updateServerStats(selectedServer, true, rtt)
// 设置递归可用标志
response.RecursionAvailable = true
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(response)
// 如果启用了DNSSEC且响应包含DNSSEC记录验证DNSSEC签名
if s.config.EnableDNSSEC && containsDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(response)
// 设置AD标志Authenticated Data
response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
// 如果响应成功根据DNSSEC状态选择最佳响应
if response.Rcode == dns.RcodeSuccess {
// 优先选择带有DNSSEC记录的响应
if containsDNSSEC {
bestResponse = response
bestRtt = rtt
hasBestResponse = true
hasDNSSECResponse = true
logger.Debug("找到带DNSSEC的最佳响应", "domain", domain, "server", selectedServer, "rtt", rtt)
} else {
// 没有带DNSSEC的响应时保存成功响应
bestResponse = response
bestRtt = rtt
hasBestResponse = true
logger.Debug("找到最佳响应", "domain", domain, "server", selectedServer, "rtt", rtt)
}
// 保存为备选响应
if !hasBackup {
backupResponse = response
backupRtt = rtt
hasBackup = true
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(selectedServer, false, 0)
}
}
case "fastest-ip":
// 最快的IP地址模式 - 使用TCP连接速度测量选择最快服务器
// 1. 选择最快的服务器
fastestServer := s.selectFastestServer(s.config.UpstreamDNS)
if fastestServer != "" {
response, rtt, err := s.resolver.Exchange(r, fastestServer)
if err == nil && response != nil {
// 更新服务器统计信息
s.updateServerStats(fastestServer, true, rtt)
// 设置递归可用标志
response.RecursionAvailable = true
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(response)
// 如果启用了DNSSEC且响应包含DNSSEC记录验证DNSSEC签名
if s.config.EnableDNSSEC && containsDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(response)
// 设置AD标志Authenticated Data
response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
// 如果响应成功根据DNSSEC状态选择最佳响应
if response.Rcode == dns.RcodeSuccess {
// 优先选择带有DNSSEC记录的响应
if containsDNSSEC {
bestResponse = response
bestRtt = rtt
hasBestResponse = true
hasDNSSECResponse = true
logger.Debug("找到带DNSSEC的最佳响应", "domain", domain, "server", fastestServer, "rtt", rtt)
} else {
// 没有带DNSSEC的响应时保存成功响应
bestResponse = response
bestRtt = rtt
hasBestResponse = true
logger.Debug("找到最佳响应", "domain", domain, "server", fastestServer, "rtt", rtt)
}
// 保存为备选响应
if !hasBackup {
backupResponse = response
backupRtt = rtt
hasBackup = true
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(fastestServer, false, 0)
}
}
default:
// 默认使用并行请求模式
responses := make(chan serverResponse, len(s.config.UpstreamDNS))
var wg sync.WaitGroup
// 向所有上游服务器并行发送请求
for _, upstream := range s.config.UpstreamDNS {
wg.Add(1)
go func(server string) {
defer wg.Done()
response, rtt, err := s.resolver.Exchange(r, server)
responses <- serverResponse{response, rtt, server, err}
}(upstream)
}
// 等待所有请求完成
go func() {
wg.Wait()
close(responses)
}()
// 处理所有响应
for resp := range responses {
if resp.error == nil && resp.response != nil {
// 设置递归可用标志
resp.response.RecursionAvailable = true
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(resp.response)
// 如果启用了DNSSEC且响应包含DNSSEC记录验证DNSSEC签名
if s.config.EnableDNSSEC && containsDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(resp.response)
// 设置AD标志Authenticated Data
resp.response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
// 如果响应成功根据DNSSEC状态选择最佳响应
if resp.response.Rcode == dns.RcodeSuccess {
// 优先选择带有DNSSEC记录的响应
if containsDNSSEC {
bestResponse = resp.response
bestRtt = resp.rtt
hasBestResponse = true
hasDNSSECResponse = true
logger.Debug("找到带DNSSEC的最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
} else if !hasBestResponse {
// 没有带DNSSEC的响应时保存第一个成功响应
bestResponse = resp.response
bestRtt = resp.rtt
hasBestResponse = true
logger.Debug("找到最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
}
// 保存为备选响应
if !hasBackup {
backupResponse = resp.response
backupRtt = resp.rtt
hasBackup = true
}
}
}
}
}
// 2. 当启用DNSSEC且没有找到带DNSSEC的响应时向DNSSEC专用服务器发送请求
if s.config.EnableDNSSEC && !hasDNSSECResponse {
logger.Debug("向DNSSEC专用服务器发送请求", "domain", domain)
// 增加DNSSEC查询计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
})
// 根据查询模式处理DNSSEC服务器请求
switch s.config.QueryMode {
case "parallel":
// 并行请求模式 - 优化版:添加超时处理和服务器统计
responses := make(chan serverResponse, len(dnssecServers))
var wg sync.WaitGroup
// 超时上下文
timeoutCtx, cancel := context.WithTimeout(s.ctx, time.Duration(s.config.Timeout)*time.Millisecond)
defer cancel()
// 向所有DNSSEC服务器并行发送请求
for _, dnssecServer := range dnssecServers {
wg.Add(1)
go func(server string) {
defer wg.Done()
// 发送请求并获取响应
response, rtt, err := s.resolver.Exchange(r, server)
select {
case responses <- serverResponse{response, rtt, server, err}:
// 成功发送响应
case <-timeoutCtx.Done():
// 超时,忽略此响应
logger.Debug("DNSSEC并行请求超时", "server", server, "domain", domain)
return
}
}(dnssecServer)
}
// 等待所有请求完成或超时
go func() {
wg.Wait()
close(responses)
}()
// 处理所有响应
for resp := range responses {
if resp.error == nil && resp.response != nil {
// 更新服务器统计信息
s.updateServerStats(resp.server, true, resp.rtt)
// 设置递归可用标志
resp.response.RecursionAvailable = true
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(resp.response)
if resp.response.Rcode == dns.RcodeSuccess {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(resp.response)
// 设置AD标志Authenticated Data
resp.response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECFailed++
})
}
// 优先使用DNSSEC专用服务器的响应尤其是带有DNSSEC记录的
if containsDNSSEC {
// 即使之前有最佳响应也优先使用DNSSEC专用服务器的DNSSEC响应
bestResponse = resp.response
bestRtt = resp.rtt
hasBestResponse = true
hasDNSSECResponse = true
logger.Debug("DNSSEC专用服务器返回带DNSSEC的响应优先使用", "domain", domain, "server", resp.server, "rtt", resp.rtt)
}
// 注意如果DNSSEC专用服务器返回的响应不包含DNSSEC记录
// 我们不会覆盖之前从upstreamDNS获取的响应
// 这符合"本地解析指的是直接使用上游服务器upstreamDNS进行解析, 而不是dnssecUpstreamDNS"的要求
// 更新备选响应
if !hasBackup {
backupResponse = resp.response
backupRtt = resp.rtt
hasBackup = true
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(resp.server, false, 0)
}
}
case "loadbalance":
// 负载均衡模式 - 使用加权随机选择算法
// 1. 选择一个加权随机DNSSEC服务器
selectedServer := s.selectWeightedRandomServer(dnssecServers)
if selectedServer != "" {
response, rtt, err := s.resolver.Exchange(r, selectedServer)
if err == nil && response != nil {
// 更新服务器统计信息
s.updateServerStats(selectedServer, true, rtt)
// 设置递归可用标志
response.RecursionAvailable = true
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(response)
if response.Rcode == dns.RcodeSuccess {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(response)
// 设置AD标志Authenticated Data
response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECFailed++
})
}
// 优先使用DNSSEC专用服务器的响应尤其是带有DNSSEC记录的
if containsDNSSEC {
// 即使之前有最佳响应也优先使用DNSSEC专用服务器的DNSSEC响应
bestResponse = response
bestRtt = rtt
hasBestResponse = true
hasDNSSECResponse = true
logger.Debug("DNSSEC专用服务器返回带DNSSEC的响应优先使用", "domain", domain, "server", selectedServer, "rtt", rtt)
}
// 注意如果DNSSEC专用服务器返回的响应不包含DNSSEC记录
// 我们不会覆盖之前从upstreamDNS获取的响应
// 这符合"本地解析指的是直接使用上游服务器upstreamDNS进行解析, 而不是dnssecUpstreamDNS"的要求
// 更新备选响应
if !hasBackup {
backupResponse = response
backupRtt = rtt
hasBackup = true
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(selectedServer, false, 0)
}
}
case "fastest-ip":
// 最快的IP地址模式 - 使用TCP连接速度测量选择最快DNSSEC服务器
// 1. 选择最快的DNSSEC服务器
fastestServer := s.selectFastestServer(dnssecServers)
if fastestServer != "" {
response, rtt, err := s.resolver.Exchange(r, fastestServer)
if err == nil && response != nil {
// 更新服务器统计信息
s.updateServerStats(fastestServer, true, rtt)
// 设置递归可用标志
response.RecursionAvailable = true
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(response)
if response.Rcode == dns.RcodeSuccess {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(response)
// 设置AD标志Authenticated Data
response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECFailed++
})
}
// 优先使用DNSSEC专用服务器的响应尤其是带有DNSSEC记录的
if containsDNSSEC {
// 即使之前有最佳响应也优先使用DNSSEC专用服务器的DNSSEC响应
bestResponse = response
bestRtt = rtt
hasBestResponse = true
hasDNSSECResponse = true
logger.Debug("DNSSEC专用服务器返回带DNSSEC的响应优先使用", "domain", domain, "server", fastestServer, "rtt", rtt)
}
// 注意如果DNSSEC专用服务器返回的响应不包含DNSSEC记录
// 我们不会覆盖之前从upstreamDNS获取的响应
// 这符合"本地解析指的是直接使用上游服务器upstreamDNS进行解析, 而不是dnssecUpstreamDNS"的要求
// 更新备选响应
if !hasBackup {
backupResponse = response
backupRtt = rtt
hasBackup = true
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(fastestServer, false, 0)
}
}
default:
// 默认使用顺序请求模式
for _, dnssecServer := range dnssecServers {
response, rtt, err := s.resolver.Exchange(r, dnssecServer)
if err == nil && response != nil {
// 更新服务器统计信息
s.updateServerStats(dnssecServer, true, rtt)
// 设置递归可用标志
response.RecursionAvailable = true
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(response)
if response.Rcode == dns.RcodeSuccess {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(response)
// 设置AD标志Authenticated Data
response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECFailed++
})
}
// 优先使用DNSSEC专用服务器的响应尤其是带有DNSSEC记录的
if containsDNSSEC {
// 即使之前有最佳响应也优先使用DNSSEC专用服务器的DNSSEC响应
bestResponse = response
bestRtt = rtt
hasBestResponse = true
hasDNSSECResponse = true
logger.Debug("DNSSEC专用服务器返回带DNSSEC的响应优先使用", "domain", domain, "server", dnssecServer, "rtt", rtt)
break // 找到带DNSSEC的响应立即返回
}
// 注意如果DNSSEC专用服务器返回的响应不包含DNSSEC记录
// 我们不会覆盖之前从upstreamDNS获取的响应
// 这符合"本地解析指的是直接使用上游服务器upstreamDNS进行解析, 而不是dnssecUpstreamDNS"的要求
// 更新备选响应
if !hasBackup {
backupResponse = response
backupRtt = rtt
hasBackup = true
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(dnssecServer, false, 0)
}
}
}
}
// 3. 返回最佳响应
if hasBestResponse {
// 检查最佳响应是否包含DNSSEC记录
bestHasDNSSEC := s.hasDNSSECRecords(bestResponse)
// 如果启用了DNSSEC且最佳响应不包含DNSSEC记录使用upstreamDNS的解析结果
if s.config.EnableDNSSEC && !bestHasDNSSEC {
logger.Debug("最佳响应不包含DNSSEC记录使用upstreamDNS的解析结果", "domain", domain)
}
// 记录解析域名统计
s.updateResolvedDomainStats(domain)
// 更新域名的DNSSEC状态
if bestHasDNSSEC {
s.updateDomainDNSSECStatus(domain, true)
}
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
return bestResponse, bestRtt
}
// 如果有备选响应,返回该响应
if hasBackup {
logger.Debug("使用备选响应,没有找到更好的结果", "domain", domain)
// 记录解析域名统计
s.updateResolvedDomainStats(domain)
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
return backupResponse, backupRtt
}
// 所有上游服务器都失败,返回服务器失败错误
response := new(dns.Msg)
response.SetReply(r)
response.RecursionAvailable = true
response.SetRcode(r, dns.RcodeServerFailure)
logger.Error("DNS查询失败", "domain", domain)
s.updateStats(func(stats *Stats) {
stats.Errors++
})
return response, 0
}
// forwardDNSRequest 转发DNS请求到上游服务器
func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) {
response, _ := s.forwardDNSRequestWithCache(r, domain)
w.WriteMsg(response)
}
// updateBlockedDomainStats 更新被屏蔽域名统计
func (s *Server) updateBlockedDomainStats(domain string) {
// 更新被屏蔽域名计数
s.blockedDomainsMutex.Lock()
defer s.blockedDomainsMutex.Unlock()
if entry, exists := s.blockedDomains[domain]; exists {
entry.Count++
entry.LastSeen = time.Now()
} else {
s.blockedDomains[domain] = &BlockedDomain{
Domain: domain,
Count: 1,
LastSeen: time.Now(),
}
}
// 更新统计数据
now := time.Now()
// 更新小时统计
hourKey := now.Format("2006-01-02-15")
s.hourlyStatsMutex.Lock()
s.hourlyStats[hourKey]++
s.hourlyStatsMutex.Unlock()
// 更新每日统计
dayKey := now.Format("2006-01-02")
s.dailyStatsMutex.Lock()
s.dailyStats[dayKey]++
s.dailyStatsMutex.Unlock()
// 更新每月统计
monthKey := now.Format("2006-01")
s.monthlyStatsMutex.Lock()
s.monthlyStats[monthKey]++
s.monthlyStatsMutex.Unlock()
}
// updateClientStats 更新客户端统计
func (s *Server) updateClientStats(ip string) {
s.clientStatsMutex.Lock()
defer s.clientStatsMutex.Unlock()
if entry, exists := s.clientStats[ip]; exists {
entry.Count++
entry.LastSeen = time.Now()
} else {
s.clientStats[ip] = &ClientStats{
IP: ip,
Count: 1,
LastSeen: time.Now(),
}
}
}
// hasDNSSECRecords 检查响应是否包含DNSSEC记录
func (s *Server) hasDNSSECRecords(response *dns.Msg) bool {
// 检查响应中是否包含DNSSEC相关记录DNSKEY、RRSIG、DS、NSEC、NSEC3等
for _, rr := range response.Answer {
if _, ok := rr.(*dns.DNSKEY); ok {
return true
}
if _, ok := rr.(*dns.RRSIG); ok {
return true
}
if _, ok := rr.(*dns.DS); ok {
return true
}
if _, ok := rr.(*dns.NSEC); ok {
return true
}
if _, ok := rr.(*dns.NSEC3); ok {
return true
}
}
for _, rr := range response.Ns {
if _, ok := rr.(*dns.DNSKEY); ok {
return true
}
if _, ok := rr.(*dns.RRSIG); ok {
return true
}
if _, ok := rr.(*dns.DS); ok {
return true
}
if _, ok := rr.(*dns.NSEC); ok {
return true
}
if _, ok := rr.(*dns.NSEC3); ok {
return true
}
}
for _, rr := range response.Extra {
if _, ok := rr.(*dns.DNSKEY); ok {
return true
}
if _, ok := rr.(*dns.RRSIG); ok {
return true
}
if _, ok := rr.(*dns.DS); ok {
return true
}
if _, ok := rr.(*dns.NSEC); ok {
return true
}
if _, ok := rr.(*dns.NSEC3); ok {
return true
}
}
return false
}
// verifyDNSSEC 验证DNSSEC签名
func (s *Server) verifyDNSSEC(response *dns.Msg) bool {
// 提取DNSKEY和RRSIG记录
dnskeys := make(map[uint16]*dns.DNSKEY) // KeyTag -> DNSKEY
rrsigs := make([]*dns.RRSIG, 0)
// 从响应中提取所有DNSKEY和RRSIG记录
for _, rr := range response.Answer {
if dnskey, ok := rr.(*dns.DNSKEY); ok {
tag := dnskey.KeyTag()
dnskeys[tag] = dnskey
} else if rrsig, ok := rr.(*dns.RRSIG); ok {
rrsigs = append(rrsigs, rrsig)
}
}
for _, rr := range response.Ns {
if dnskey, ok := rr.(*dns.DNSKEY); ok {
tag := dnskey.KeyTag()
dnskeys[tag] = dnskey
} else if rrsig, ok := rr.(*dns.RRSIG); ok {
rrsigs = append(rrsigs, rrsig)
}
}
for _, rr := range response.Extra {
if dnskey, ok := rr.(*dns.DNSKEY); ok {
tag := dnskey.KeyTag()
dnskeys[tag] = dnskey
} else if rrsig, ok := rr.(*dns.RRSIG); ok {
rrsigs = append(rrsigs, rrsig)
}
}
// 如果没有RRSIG记录验证失败
if len(rrsigs) == 0 {
return false
}
// 验证所有RRSIG记录
signatureValid := true
for _, rrsig := range rrsigs {
// 查找对应的DNSKEY
dnskey, exists := dnskeys[rrsig.KeyTag]
if !exists {
logger.Warn("DNSSEC验证失败找不到对应的DNSKEY", "key_tag", rrsig.KeyTag)
signatureValid = false
continue
}
// 收集需要验证的记录集
rrset := make([]dns.RR, 0)
for _, rr := range response.Answer {
if rr.Header().Name == rrsig.Header().Name && rr.Header().Rrtype == rrsig.TypeCovered {
rrset = append(rrset, rr)
}
}
for _, rr := range response.Ns {
if rr.Header().Name == rrsig.Header().Name && rr.Header().Rrtype == rrsig.TypeCovered {
rrset = append(rrset, rr)
}
}
// 验证签名
if len(rrset) > 0 {
err := rrsig.Verify(dnskey, rrset)
if err != nil {
logger.Warn("DNSSEC签名验证失败", "error", err, "key_tag", rrsig.KeyTag)
signatureValid = false
} else {
logger.Debug("DNSSEC签名验证成功", "key_tag", rrsig.KeyTag)
}
}
}
return signatureValid
}
// updateDomainDNSSECStatus 更新域名的DNSSEC状态
func (s *Server) updateDomainDNSSECStatus(domain string, dnssec bool) {
// 确保域名是小写
domain = strings.ToLower(domain)
// 更新域名的DNSSEC状态
s.resolvedDomainsMutex.Lock()
defer s.resolvedDomainsMutex.Unlock()
// 更新resolvedDomains中的DNSSEC状态
if entry, exists := s.resolvedDomains[domain]; exists {
entry.DNSSEC = dnssec
} else {
s.resolvedDomains[domain] = &BlockedDomain{
Domain: domain,
Count: 1,
LastSeen: time.Now(),
DNSSEC: dnssec,
}
}
// 更新domainDNSSECStatus映射
s.domainDNSSECStatus[domain] = dnssec
}
// updateResolvedDomainStats 更新解析域名统计
func (s *Server) updateResolvedDomainStats(domain string) {
s.resolvedDomainsMutex.Lock()
defer s.resolvedDomainsMutex.Unlock()
if entry, exists := s.resolvedDomains[domain]; exists {
entry.Count++
entry.LastSeen = time.Now()
} else {
s.resolvedDomains[domain] = &BlockedDomain{
Domain: domain,
Count: 1,
LastSeen: time.Now(),
DNSSEC: false,
}
}
}
// getServerStats 获取服务器统计信息,如果不存在则创建
func (s *Server) getServerStats(server string) *ServerStats {
s.serverStatsMutex.RLock()
stats, exists := s.serverStats[server]
s.serverStatsMutex.RUnlock()
if !exists {
// 创建新的服务器统计信息
stats = &ServerStats{
SuccessCount: 0,
FailureCount: 0,
LastResponse: time.Now(),
ResponseTime: 0,
ConnectionSpeed: 0,
}
// 加锁更新服务器统计信息
s.serverStatsMutex.Lock()
s.serverStats[server] = stats
s.serverStatsMutex.Unlock()
}
return stats
}
// updateServerStats 更新服务器统计信息
func (s *Server) updateServerStats(server string, success bool, rtt time.Duration) {
stats := s.getServerStats(server)
s.serverStatsMutex.Lock()
defer s.serverStatsMutex.Unlock()
// 更新统计信息
stats.LastResponse = time.Now()
if success {
stats.SuccessCount++
// 更新平均响应时间(简单移动平均)
// 将所有值转换为纳秒进行计算然后再转换回Duration
if stats.SuccessCount == 1 {
// 第一次成功,直接使用当前响应时间
stats.ResponseTime = rtt
} else {
// 使用纳秒进行计算以避免类型不匹配
prevTotal := stats.ResponseTime.Nanoseconds() * (stats.SuccessCount - 1)
newTotal := prevTotal + rtt.Nanoseconds()
stats.ResponseTime = time.Duration(newTotal / stats.SuccessCount)
}
} else {
stats.FailureCount++
}
}
// selectWeightedRandomServer 加权随机选择服务器
func (s *Server) selectWeightedRandomServer(servers []string) string {
if len(servers) == 0 {
return ""
}
if len(servers) == 1 {
return servers[0]
}
// 计算每个服务器的权重
type serverWeight struct {
server string
weight int64
}
var totalWeight int64
weights := make([]serverWeight, 0, len(servers))
for _, server := range servers {
stats := s.getServerStats(server)
// 计算权重:成功次数 - 失败次数 * 2失败权重更高
// 确保权重至少为1
weight := stats.SuccessCount - stats.FailureCount*2
if weight < 1 {
weight = 1
}
weights = append(weights, serverWeight{server, weight})
totalWeight += weight
}
// 随机选择一个权重
random := time.Now().UnixNano() % totalWeight
if random < 0 {
random += totalWeight
}
// 选择对应的服务器
var currentWeight int64
for _, sw := range weights {
currentWeight += sw.weight
if random < currentWeight {
return sw.server
}
}
// 兜底返回第一个服务器
return servers[0]
}
// measureServerSpeed 测量服务器TCP连接速度
func (s *Server) measureServerSpeed(server string) time.Duration {
// 提取服务器地址和端口
addr := server
if !strings.Contains(server, ":") {
addr = server + ":53"
}
// 测量TCP连接时间
startTime := time.Now()
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
if err != nil {
// 连接失败,返回最大持续时间
return 2 * time.Second
}
defer conn.Close()
// 计算连接建立时间
connTime := time.Since(startTime)
// 更新服务器连接速度
stats := s.getServerStats(server)
s.serverStatsMutex.Lock()
// 使用指数移动平均更新连接速度
stats.ConnectionSpeed = (stats.ConnectionSpeed*3 + connTime) / 4
s.serverStatsMutex.Unlock()
return connTime
}
// selectFastestServer 选择连接速度最快的服务器
func (s *Server) selectFastestServer(servers []string) string {
if len(servers) == 0 {
return ""
}
if len(servers) == 1 {
return servers[0]
}
// 并行测量所有服务器的速度
type speedResult struct {
server string
speed time.Duration
}
results := make(chan speedResult, len(servers))
var wg sync.WaitGroup
for _, server := range servers {
wg.Add(1)
go func(srv string) {
defer wg.Done()
speed := s.measureServerSpeed(srv)
results <- speedResult{srv, speed}
}(server)
}
// 等待所有测量完成
go func() {
wg.Wait()
close(results)
}()
// 找出最快的服务器
var fastestServer string
var fastestSpeed time.Duration = 2 * time.Second
for result := range results {
if result.speed < fastestSpeed {
fastestSpeed = result.speed
fastestServer = result.server
}
}
// 如果没有找到最快服务器(理论上不会发生),返回第一个服务器
if fastestServer == "" {
fastestServer = servers[0]
}
return fastestServer
}
// updateStats 更新统计信息
func (s *Server) updateStats(update func(*Stats)) {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
update(s.stats)
}
// addQueryLog 添加查询日志
func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string, fromCache, dnssec, edns bool) {
// 获取IP地理位置
location := s.getIpGeolocation(clientIP)
// 创建日志记录
log := QueryLog{
Timestamp: time.Now(),
ClientIP: clientIP,
Location: location,
Domain: domain,
QueryType: queryType,
ResponseTime: responseTime,
Result: result,
BlockRule: blockRule,
BlockType: blockType,
FromCache: fromCache,
DNSSEC: dnssec,
EDNS: edns,
}
// 添加到日志列表
s.queryLogsMutex.Lock()
defer s.queryLogsMutex.Unlock()
// 插入到列表开头
s.queryLogs = append([]QueryLog{log}, s.queryLogs...)
// 限制日志数量
if len(s.queryLogs) > s.maxQueryLogs {
s.queryLogs = s.queryLogs[:s.maxQueryLogs]
}
}
// GetStartTime 获取服务器启动时间
func (s *Server) GetStartTime() time.Time {
return s.startTime
}
// GetStats 获取DNS服务器统计信息
func (s *Server) GetStats() *Stats {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
// 复制查询类型统计
queryTypesCopy := make(map[string]int64)
for k, v := range s.stats.QueryTypes {
queryTypesCopy[k] = v
}
// 复制来源IP统计
sourceIPsCopy := make(map[string]bool)
for ip := range s.stats.SourceIPs {
sourceIPsCopy[ip] = true
}
// 返回统计信息的副本
return &Stats{
Queries: s.stats.Queries,
Blocked: s.stats.Blocked,
Allowed: s.stats.Allowed,
Errors: s.stats.Errors,
LastQuery: s.stats.LastQuery,
AvgResponseTime: s.stats.AvgResponseTime,
TotalResponseTime: s.stats.TotalResponseTime,
QueryTypes: queryTypesCopy,
SourceIPs: sourceIPsCopy,
CpuUsage: s.stats.CpuUsage,
DNSSECQueries: s.stats.DNSSECQueries,
DNSSECSuccess: s.stats.DNSSECSuccess,
DNSSECFailed: s.stats.DNSSECFailed,
DNSSECEnabled: s.stats.DNSSECEnabled,
}
}
// GetQueryLogs 获取查询日志
func (s *Server) GetQueryLogs(limit, offset int, sortField, sortDirection, resultFilter, searchTerm string) []QueryLog {
s.queryLogsMutex.RLock()
defer s.queryLogsMutex.RUnlock()
// 确保偏移量和限制值合理
if offset < 0 {
offset = 0
}
if limit <= 0 {
limit = 100 // 默认返回100条日志
}
// 创建日志副本用于过滤和排序
var logsCopy []QueryLog
// 先过滤日志
for _, log := range s.queryLogs {
// 应用结果过滤
if resultFilter != "" && log.Result != resultFilter {
continue
}
// 应用搜索过滤
if searchTerm != "" {
// 搜索域名或客户端IP
if !strings.Contains(log.Domain, searchTerm) && !strings.Contains(log.ClientIP, searchTerm) {
continue
}
}
logsCopy = append(logsCopy, log)
}
// 排序日志
if sortField != "" {
sort.Slice(logsCopy, func(i, j int) bool {
var a, b interface{}
switch sortField {
case "time":
a = logsCopy[i].Timestamp
b = logsCopy[j].Timestamp
case "clientIp":
a = logsCopy[i].ClientIP
b = logsCopy[j].ClientIP
case "domain":
a = logsCopy[i].Domain
b = logsCopy[j].Domain
case "responseTime":
a = logsCopy[i].ResponseTime
b = logsCopy[j].ResponseTime
case "blockRule":
a = logsCopy[i].BlockRule
b = logsCopy[j].BlockRule
default:
// 默认按时间排序
a = logsCopy[i].Timestamp
b = logsCopy[j].Timestamp
}
// 根据排序方向比较
if sortDirection == "asc" {
return compareValues(a, b) < 0
}
return compareValues(a, b) > 0
})
}
// 计算返回范围
start := offset
end := offset + limit
if end > len(logsCopy) {
end = len(logsCopy)
}
if start >= len(logsCopy) {
return []QueryLog{} // 没有数据,返回空切片
}
return logsCopy[start:end]
}
// compareValues 比较两个值
func compareValues(a, b interface{}) int {
switch v1 := a.(type) {
case time.Time:
v2 := b.(time.Time)
if v1.Before(v2) {
return -1
}
if v1.After(v2) {
return 1
}
return 0
case string:
v2 := b.(string)
if v1 < v2 {
return -1
}
if v1 > v2 {
return 1
}
return 0
case int64:
v2 := b.(int64)
if v1 < v2 {
return -1
}
if v1 > v2 {
return 1
}
return 0
default:
return 0
}
}
// GetQueryLogsCount 获取查询日志总数
func (s *Server) GetQueryLogsCount() int {
s.queryLogsMutex.RLock()
defer s.queryLogsMutex.RUnlock()
return len(s.queryLogs)
}
// GetQueryStats 获取查询统计信息
func (s *Server) GetQueryStats() map[string]interface{} {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
// 计算统计数据
return map[string]interface{}{
"totalQueries": s.stats.Queries,
"blockedQueries": s.stats.Blocked,
"allowedQueries": s.stats.Allowed,
"errorQueries": s.stats.Errors,
"avgResponseTime": s.stats.AvgResponseTime,
"activeIPs": len(s.stats.SourceIPs),
}
}
// GetTopBlockedDomains 获取TOP屏蔽域名列表
func (s *Server) GetTopBlockedDomains(limit int) []BlockedDomain {
s.blockedDomainsMutex.RLock()
defer s.blockedDomainsMutex.RUnlock()
// 转换为切片
domains := make([]BlockedDomain, 0, len(s.blockedDomains))
for _, entry := range s.blockedDomains {
domains = append(domains, *entry)
}
// 按计数排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].Count > domains[j].Count
})
// 返回限制数量
if len(domains) > limit {
return domains[:limit]
}
return domains
}
// GetTopResolvedDomains 获取TOP解析域名
func (s *Server) GetTopResolvedDomains(limit int) []BlockedDomain {
s.resolvedDomainsMutex.RLock()
defer s.resolvedDomainsMutex.RUnlock()
// 转换为切片
domains := make([]BlockedDomain, 0, len(s.resolvedDomains))
for _, entry := range s.resolvedDomains {
domains = append(domains, *entry)
}
// 按数量排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].Count > domains[j].Count
})
// 返回限制数量
if len(domains) > limit {
return domains[:limit]
}
return domains
}
// GetRecentBlockedDomains 获取最近屏蔽的域名列表
func (s *Server) GetRecentBlockedDomains(limit int) []BlockedDomain {
s.blockedDomainsMutex.RLock()
defer s.blockedDomainsMutex.RUnlock()
// 转换为切片
domains := make([]BlockedDomain, 0, len(s.blockedDomains))
for _, entry := range s.blockedDomains {
domains = append(domains, *entry)
}
// 按时间排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].LastSeen.After(domains[j].LastSeen)
})
// 返回限制数量
if len(domains) > limit {
return domains[:limit]
}
return domains
}
// GetTopClients 获取TOP客户端列表
func (s *Server) GetTopClients(limit int) []ClientStats {
s.clientStatsMutex.RLock()
defer s.clientStatsMutex.RUnlock()
// 转换为切片
clients := make([]ClientStats, 0, len(s.clientStats))
for _, entry := range s.clientStats {
clients = append(clients, *entry)
}
// 按请求次数排序
sort.Slice(clients, func(i, j int) bool {
return clients[i].Count > clients[j].Count
})
// 返回限制数量
if len(clients) > limit {
return clients[:limit]
}
return clients
}
// GetHourlyStats 获取每小时统计数据
func (s *Server) GetHourlyStats() map[string]int64 {
s.hourlyStatsMutex.RLock()
defer s.hourlyStatsMutex.RUnlock()
// 返回副本
result := make(map[string]int64)
for k, v := range s.hourlyStats {
result[k] = v
}
return result
}
// GetDailyStats 获取每日统计数据
func (s *Server) GetDailyStats() map[string]int64 {
s.dailyStatsMutex.RLock()
defer s.dailyStatsMutex.RUnlock()
// 返回副本
result := make(map[string]int64)
for k, v := range s.dailyStats {
result[k] = v
}
return result
}
// GetMonthlyStats 获取每月统计数据
func (s *Server) GetMonthlyStats() map[string]int64 {
s.monthlyStatsMutex.RLock()
defer s.monthlyStatsMutex.RUnlock()
// 返回副本
result := make(map[string]int64)
for k, v := range s.monthlyStats {
result[k] = v
}
return result
}
// isPrivateIP 检测IP地址是否为内网IP
func isPrivateIP(ip string) bool {
// 解析IP地址
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return false
}
// 检查IPv4内网地址
if ipv4 := parsedIP.To4(); ipv4 != nil {
// 10.0.0.0/8
if ipv4[0] == 10 {
return true
}
// 172.16.0.0/12
if ipv4[0] == 172 && (ipv4[1] >= 16 && ipv4[1] <= 31) {
return true
}
// 192.168.0.0/16
if ipv4[0] == 192 && ipv4[1] == 168 {
return true
}
// 127.0.0.0/8 (localhost)
if ipv4[0] == 127 {
return true
}
// 169.254.0.0/16 (链路本地地址)
if ipv4[0] == 169 && ipv4[1] == 254 {
return true
}
return false
}
// 检查IPv6内网地址
// ::1/128 (localhost)
if parsedIP.IsLoopback() {
return true
}
// fc00::/7 (唯一本地地址)
if parsedIP[0]&0xfc == 0xfc {
return true
}
// fe80::/10 (链路本地地址)
if parsedIP[0]&0xfe == 0xfe && parsedIP[1]&0xc0 == 0x80 {
return true
}
return false
}
// getIpGeolocation 获取IP地址的地理位置信息
func (s *Server) getIpGeolocation(ip string) string {
// 检查IP是否为本地或内网地址
if isPrivateIP(ip) {
return "内网 内网"
}
// 先检查缓存
s.ipGeolocationCacheMutex.RLock()
geo, exists := s.ipGeolocationCache[ip]
s.ipGeolocationCacheMutex.RUnlock()
// 如果缓存存在且未过期,直接返回
if exists && time.Now().Before(geo.Expiry) {
return fmt.Sprintf("%s %s", geo.Country, geo.City)
}
// 缓存不存在或已过期从API获取
geoInfo, err := s.fetchIpGeolocationFromAPI(ip)
if err != nil {
logger.Error("获取IP地理位置失败", "ip", ip, "error", err)
return "未知 未知"
}
// 保存到缓存
s.ipGeolocationCacheMutex.Lock()
s.ipGeolocationCache[ip] = &IPGeolocation{
Country: geoInfo["country"].(string),
City: geoInfo["city"].(string),
Expiry: time.Now().Add(s.ipGeolocationCacheTTL),
}
s.ipGeolocationCacheMutex.Unlock()
// 返回格式化的地理位置
return fmt.Sprintf("%s %s", geoInfo["country"].(string), geoInfo["city"].(string))
}
// fetchIpGeolocationFromAPI 从第三方API获取IP地理位置信息
func (s *Server) fetchIpGeolocationFromAPI(ip string) (map[string]interface{}, error) {
// 使用ip-api.com获取IP地理位置信息
url := fmt.Sprintf("http://ip-api.com/json/%s?fields=country,city", ip)
resp, err := http.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
// 读取响应内容
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// 解析JSON响应
var result map[string]interface{}
err = json.Unmarshal(body, &result)
if err != nil {
return nil, err
}
// 检查API返回状态
status, ok := result["status"].(string)
if !ok || status != "success" {
return nil, fmt.Errorf("API返回错误状态: %v", result)
}
// 确保国家和城市字段存在
if _, ok := result["country"]; !ok {
result["country"] = "未知"
}
if _, ok := result["city"]; !ok {
result["city"] = "未知"
}
return result, nil
}
// loadStatsData 从文件加载统计数据
func (s *Server) loadStatsData() {
if s.config.StatsFile == "" {
return
}
// 检查文件是否存在
data, err := ioutil.ReadFile(s.config.StatsFile)
if err != nil {
if !os.IsNotExist(err) {
logger.Error("读取统计数据文件失败", "error", err)
}
return
}
var statsData StatsData
err = json.Unmarshal(data, &statsData)
if err != nil {
logger.Error("解析统计数据失败", "error", err)
return
}
// 恢复统计数据
s.statsMutex.Lock()
if statsData.Stats != nil {
s.stats = statsData.Stats
// 确保使用当前配置中的EnableDNSSEC值覆盖从文件加载的值
s.stats.DNSSECEnabled = s.config.EnableDNSSEC
}
s.statsMutex.Unlock()
s.blockedDomainsMutex.Lock()
if statsData.BlockedDomains != nil {
s.blockedDomains = statsData.BlockedDomains
}
s.blockedDomainsMutex.Unlock()
s.resolvedDomainsMutex.Lock()
if statsData.ResolvedDomains != nil {
s.resolvedDomains = statsData.ResolvedDomains
}
s.resolvedDomainsMutex.Unlock()
s.hourlyStatsMutex.Lock()
if statsData.HourlyStats != nil {
s.hourlyStats = statsData.HourlyStats
}
s.hourlyStatsMutex.Unlock()
s.dailyStatsMutex.Lock()
if statsData.DailyStats != nil {
s.dailyStats = statsData.DailyStats
}
s.dailyStatsMutex.Unlock()
s.monthlyStatsMutex.Lock()
if statsData.MonthlyStats != nil {
s.monthlyStats = statsData.MonthlyStats
}
s.monthlyStatsMutex.Unlock()
// 加载客户端统计数据
s.clientStatsMutex.Lock()
if statsData.ClientStats != nil {
s.clientStats = statsData.ClientStats
}
s.clientStatsMutex.Unlock()
logger.Info("统计数据加载成功")
// 加载查询日志
s.loadQueryLogs()
}
// loadQueryLogs 从文件加载查询日志
func (s *Server) loadQueryLogs() {
if s.config.StatsFile == "" {
return
}
// 获取绝对路径
statsFilePath, err := filepath.Abs(s.config.StatsFile)
if err != nil {
logger.Error("获取统计文件绝对路径失败", "path", s.config.StatsFile, "error", err)
return
}
// 构建查询日志文件路径
queryLogPath := filepath.Join(filepath.Dir(statsFilePath), "querylog.json")
// 检查文件是否存在
if _, err := os.Stat(queryLogPath); os.IsNotExist(err) {
logger.Info("查询日志文件不存在,将使用空列表", "file", queryLogPath)
return
}
// 读取文件内容
data, err := ioutil.ReadFile(queryLogPath)
if err != nil {
logger.Error("读取查询日志文件失败", "error", err)
return
}
// 解析数据
var logs []QueryLog
err = json.Unmarshal(data, &logs)
if err != nil {
logger.Error("解析查询日志失败", "error", err)
return
}
// 更新查询日志
s.queryLogsMutex.Lock()
s.queryLogs = logs
// 确保日志数量不超过限制
if len(s.queryLogs) > s.maxQueryLogs {
s.queryLogs = s.queryLogs[:s.maxQueryLogs]
}
s.queryLogsMutex.Unlock()
logger.Info("查询日志加载成功", "count", len(logs))
}
// saveStatsData 保存统计数据到文件
func (s *Server) saveStatsData() {
if s.config.StatsFile == "" {
return
}
// 获取绝对路径以避免工作目录问题
statsFilePath, err := filepath.Abs(s.config.StatsFile)
if err != nil {
logger.Error("获取统计文件绝对路径失败", "path", s.config.StatsFile, "error", err)
return
}
// 创建数据目录
statsDir := filepath.Dir(statsFilePath)
err = os.MkdirAll(statsDir, 0755)
if err != nil {
logger.Error("创建统计数据目录失败", "dir", statsDir, "error", err)
return
}
// 收集所有统计数据
statsData := &StatsData{
Stats: s.GetStats(),
LastSaved: time.Now(),
}
// 复制域名数据
s.blockedDomainsMutex.RLock()
statsData.BlockedDomains = make(map[string]*BlockedDomain)
for k, v := range s.blockedDomains {
statsData.BlockedDomains[k] = v
}
s.blockedDomainsMutex.RUnlock()
s.resolvedDomainsMutex.RLock()
statsData.ResolvedDomains = make(map[string]*BlockedDomain)
for k, v := range s.resolvedDomains {
statsData.ResolvedDomains[k] = v
}
s.resolvedDomainsMutex.RUnlock()
s.hourlyStatsMutex.RLock()
statsData.HourlyStats = make(map[string]int64)
for k, v := range s.hourlyStats {
statsData.HourlyStats[k] = v
}
s.hourlyStatsMutex.RUnlock()
s.dailyStatsMutex.RLock()
statsData.DailyStats = make(map[string]int64)
for k, v := range s.dailyStats {
statsData.DailyStats[k] = v
}
s.dailyStatsMutex.RUnlock()
s.monthlyStatsMutex.RLock()
statsData.MonthlyStats = make(map[string]int64)
for k, v := range s.monthlyStats {
statsData.MonthlyStats[k] = v
}
s.monthlyStatsMutex.RUnlock()
// 复制客户端统计数据
s.clientStatsMutex.RLock()
statsData.ClientStats = make(map[string]*ClientStats)
for k, v := range s.clientStats {
statsData.ClientStats[k] = v
}
s.clientStatsMutex.RUnlock()
// 序列化数据
jsonData, err := json.MarshalIndent(statsData, "", " ")
if err != nil {
logger.Error("序列化统计数据失败", "error", err)
return
}
// 写入文件
err = os.WriteFile(statsFilePath, jsonData, 0644)
if err != nil {
logger.Error("保存统计数据到文件失败", "file", statsFilePath, "error", err)
return
}
logger.Info("统计数据保存成功", "file", statsFilePath)
// 保存查询日志到文件
s.saveQueryLogs(statsDir)
}
// saveQueryLogs 保存查询日志到文件
func (s *Server) saveQueryLogs(dataDir string) {
// 构建查询日志文件路径
queryLogPath := filepath.Join(dataDir, "querylog.json")
// 获取查询日志数据
s.queryLogsMutex.RLock()
logsCopy := make([]QueryLog, len(s.queryLogs))
copy(logsCopy, s.queryLogs)
s.queryLogsMutex.RUnlock()
// 序列化数据
jsonData, err := json.MarshalIndent(logsCopy, "", " ")
if err != nil {
logger.Error("序列化查询日志失败", "error", err)
return
}
// 写入文件
err = os.WriteFile(queryLogPath, jsonData, 0644)
if err != nil {
logger.Error("保存查询日志到文件失败", "file", queryLogPath, "error", err)
return
}
logger.Info("查询日志保存成功", "file", queryLogPath)
}
// startCpuUsageMonitor 启动CPU使用率监控
func (s *Server) startCpuUsageMonitor() {
ticker := time.NewTicker(time.Second * 5) // 每5秒更新一次CPU使用率
defer ticker.Stop()
// 初始化
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
// 存储上一次的CPU时间统计
var prevIdle, prevTotal uint64
for {
select {
case <-ticker.C:
// 获取真实的系统级CPU使用率
cpuUsage, err := getSystemCpuUsage(&prevIdle, &prevTotal)
if err != nil {
// 如果获取失败,使用默认值
cpuUsage = 0.0
logger.Error("获取系统CPU使用率失败", "error", err)
}
s.updateStats(func(stats *Stats) {
stats.CpuUsage = cpuUsage
})
case <-s.ctx.Done():
return
}
}
}
// getSystemCpuUsage 获取系统CPU使用率
func getSystemCpuUsage(prevIdle, prevTotal *uint64) (float64, error) {
// 读取/proc/stat文件获取CPU统计信息
file, err := os.Open("/proc/stat")
if err != nil {
return 0, err
}
defer file.Close()
var cpuUser, cpuNice, cpuSystem, cpuIdle, cpuIowait, cpuIrq, cpuSoftirq, cpuSteal uint64
_, err = fmt.Fscanf(file, "cpu %d %d %d %d %d %d %d %d",
&cpuUser, &cpuNice, &cpuSystem, &cpuIdle, &cpuIowait, &cpuIrq, &cpuSoftirq, &cpuSteal)
if err != nil {
return 0, err
}
// 计算总的CPU时间
total := cpuUser + cpuNice + cpuSystem + cpuIdle + cpuIowait + cpuIrq + cpuSoftirq + cpuSteal
idle := cpuIdle + cpuIowait
// 第一次调用时,只初始化值,不计算使用率
if *prevTotal == 0 || *prevIdle == 0 {
*prevIdle = idle
*prevTotal = total
return 0, nil
}
// 计算CPU使用率
idleDelta := idle - *prevIdle
totalDelta := total - *prevTotal
utilization := float64(totalDelta-idleDelta) / float64(totalDelta) * 100
// 更新上一次的值
*prevIdle = idle
*prevTotal = total
return utilization, nil
}
// startAutoSave 启动自动保存功能
func (s *Server) startAutoSave() {
if s.config.StatsFile == "" || s.config.SaveInterval <= 0 {
return
}
// 设置定时器
s.saveTicker = time.NewTicker(time.Duration(s.config.SaveInterval) * time.Second)
defer s.saveTicker.Stop()
logger.Info("启动统计数据自动保存功能", "interval", s.config.SaveInterval, "file", s.config.StatsFile)
// 定期保存数据
for {
select {
case <-s.saveTicker.C:
s.saveStatsData()
case <-s.saveDone:
return
}
}
}