Files
dns-server/dns/server.go
2026-01-03 01:11:42 +08:00

3003 lines
84 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package dns
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"os"
"path/filepath"
"runtime"
"sort"
"strings"
"sync"
"time"
"dns-server/config"
"dns-server/logger"
"dns-server/shield"
"github.com/miekg/dns"
)
// 确保DNS服务器地址包含端口号默认添加53端口
func normalizeDNSServerAddress(address string) string {
// 检查地址是否已经包含端口号
if _, _, err := net.SplitHostPort(address); err != nil {
// 如果没有端口号添加默认的53端口
return net.JoinHostPort(address, "53")
}
// 已经有端口号,直接返回
return address
}
// BlockedDomain 屏蔽域名统计
type BlockedDomain struct {
Domain string
Count int64
LastSeen time.Time
DNSSEC bool // 是否使用了DNSSEC
}
// ClientStats 客户端统计
type ClientStats struct {
IP string
Count int64
LastSeen time.Time
}
// DNSAnswer DNS解析记录
type DNSAnswer struct {
Type string `json:"type"` // 记录类型
Value string `json:"value"` // 记录值
TTL uint32 `json:"ttl"` // 生存时间
}
// QueryLog 查询日志记录
type QueryLog struct {
Timestamp time.Time `json:"timestamp"` // 查询时间
ClientIP string `json:"clientIP"` // 客户端IP
Location string `json:"location"` // IP地理位置国家 城市)
Domain string `json:"domain"` // 查询域名
QueryType string `json:"queryType"` // 查询类型
ResponseTime int64 `json:"responseTime"` // 响应时间(ms)
Result string `json:"result"` // 查询结果allowed, blocked, error
BlockRule string `json:"blockRule"` // 屏蔽规则(如果被屏蔽)
BlockType string `json:"blockType"` // 屏蔽类型(如果被屏蔽)
FromCache bool `json:"fromCache"` // 是否来自缓存
DNSSEC bool `json:"dnssec"` // 是否使用了DNSSEC
EDNS bool `json:"edns"` // 是否使用了EDNS
DNSServer string `json:"dnsServer"` // 使用的DNS服务器
DNSSECServer string `json:"dnssecServer"` // 使用的DNSSEC专用服务器
Answers []DNSAnswer `json:"answers"` // 解析记录
ResponseCode int `json:"responseCode"` // DNS响应代码
}
// StatsData 用于持久化的统计数据结构
type StatsData struct {
Stats *Stats `json:"stats"`
BlockedDomains map[string]*BlockedDomain `json:"blockedDomains"`
ResolvedDomains map[string]*BlockedDomain `json:"resolvedDomains"`
ClientStats map[string]*ClientStats `json:"clientStats"`
HourlyStats map[string]int64 `json:"hourlyStats"`
DailyStats map[string]int64 `json:"dailyStats"`
MonthlyStats map[string]int64 `json:"monthlyStats"`
LastSaved time.Time `json:"lastSaved"`
}
// ServerStats 服务器统计信息
type ServerStats struct {
SuccessCount int64 // 成功查询次数
FailureCount int64 // 失败查询次数
LastResponse time.Time // 最后响应时间
ResponseTime time.Duration // 平均响应时间
ConnectionSpeed time.Duration // TCP连接速度
}
// Server DNS服务器
type Server struct {
config *config.DNSConfig
shieldConfig *config.ShieldConfig
shieldManager *shield.ShieldManager
server *dns.Server
tcpServer *dns.Server
resolver *dns.Client
ctx context.Context
cancel context.CancelFunc
statsMutex sync.Mutex
stats *Stats
blockedDomainsMutex sync.RWMutex
blockedDomains map[string]*BlockedDomain
resolvedDomainsMutex sync.RWMutex
resolvedDomains map[string]*BlockedDomain // 用于记录解析的域名
clientStatsMutex sync.RWMutex
clientStats map[string]*ClientStats // 用于记录客户端统计
hourlyStatsMutex sync.RWMutex
hourlyStats map[string]int64 // 按小时统计屏蔽数量
dailyStatsMutex sync.RWMutex
dailyStats map[string]int64 // 按天统计屏蔽数量
monthlyStatsMutex sync.RWMutex
monthlyStats map[string]int64 // 按月统计屏蔽数量
queryLogsMutex sync.RWMutex
queryLogs []QueryLog // 查询日志列表
maxQueryLogs int // 最大保存日志数量
saveTicker *time.Ticker // 用于定时保存数据
startTime time.Time // 服务器启动时间
saveDone chan struct{} // 用于通知保存协程停止
stopped bool // 服务器是否已经停止
stoppedMutex sync.Mutex // 保护stopped标志的互斥锁
// DNS查询缓存
DnsCache *DNSCache // DNS响应缓存
// 域名DNSSEC状态映射表
domainDNSSECStatus map[string]bool // 域名到DNSSEC状态的映射
// 上游服务器状态跟踪
serverStats map[string]*ServerStats // 服务器地址到状态的映射
serverStatsMutex sync.RWMutex // 保护服务器状态的互斥锁
}
// Stats DNS服务器统计信息
type Stats struct {
Queries int64
Blocked int64
Allowed int64
Errors int64
LastQuery time.Time
AvgResponseTime float64 // 平均响应时间(ms)
TotalResponseTime int64 // 总响应时间
QueryTypes map[string]int64 // 查询类型统计
SourceIPs map[string]bool // 活跃来源IP
CpuUsage float64 // CPU使用率(%)
DNSSECQueries int64 // DNSSEC查询总数
DNSSECSuccess int64 // DNSSEC验证成功数
DNSSECFailed int64 // DNSSEC验证失败数
DNSSECEnabled bool // 是否启用了DNSSEC
}
// NewServer 创建DNS服务器实例
func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager) *Server {
ctx, cancel := context.WithCancel(context.Background())
// 从配置中读取DNS缓存TTL值分钟
cacheTTL := time.Duration(config.CacheTTL) * time.Minute
server := &Server{
config: config,
shieldConfig: shieldConfig,
shieldManager: shieldManager,
resolver: &dns.Client{
Net: "udp",
UDPSize: 4096, // 增加UDP缓冲区大小支持更大的DNSSEC响应
},
ctx: ctx,
cancel: cancel,
startTime: time.Now(), // 记录服务器启动时间
stats: &Stats{
Queries: 0,
Blocked: 0,
Allowed: 0,
Errors: 0,
AvgResponseTime: 0,
TotalResponseTime: 0,
QueryTypes: make(map[string]int64),
SourceIPs: make(map[string]bool),
CpuUsage: 0,
DNSSECQueries: 0,
DNSSECSuccess: 0,
DNSSECFailed: 0,
DNSSECEnabled: config.EnableDNSSEC,
},
blockedDomains: make(map[string]*BlockedDomain),
resolvedDomains: make(map[string]*BlockedDomain),
clientStats: make(map[string]*ClientStats),
hourlyStats: make(map[string]int64),
dailyStats: make(map[string]int64),
monthlyStats: make(map[string]int64),
queryLogs: make([]QueryLog, 0, 1000), // 初始化查询日志切片容量1000
maxQueryLogs: 10000, // 最大保存10000条日志
saveDone: make(chan struct{}),
stopped: false, // 初始化为未停止状态
// DNS查询缓存初始化
DnsCache: NewDNSCache(cacheTTL),
// 初始化域名DNSSEC状态映射表
domainDNSSECStatus: make(map[string]bool),
// 初始化服务器状态跟踪
serverStats: make(map[string]*ServerStats),
}
// 加载已保存的统计数据
server.loadStatsData()
return server
}
// Start 启动DNS服务器
func (s *Server) Start() error {
// 重新初始化上下文和取消函数
ctx, cancel := context.WithCancel(context.Background())
s.ctx = ctx
s.cancel = cancel
// 重新初始化saveDone通道
s.saveDone = make(chan struct{})
// 重置stopped标志
s.stoppedMutex.Lock()
s.stopped = false
s.stoppedMutex.Unlock()
// 更新服务器启动时间
s.startTime = time.Now()
s.server = &dns.Server{
Addr: fmt.Sprintf("0.0.0.0:%d", s.config.Port),
Net: "udp",
Handler: dns.HandlerFunc(s.handleDNSRequest),
}
// 保存TCP服务器实例以便在Stop方法中关闭
s.tcpServer = &dns.Server{
Addr: fmt.Sprintf("0.0.0.0:%d", s.config.Port),
Net: "tcp",
Handler: dns.HandlerFunc(s.handleDNSRequest),
}
// 启动CPU使用率监控
go s.startCpuUsageMonitor()
// 启动自动保存功能
go s.startAutoSave()
// 启动UDP服务
go func() {
logger.Info(fmt.Sprintf("DNS UDP服务器启动监听端口: %d", s.config.Port))
if err := s.server.ListenAndServe(); err != nil {
logger.Error("DNS UDP服务器启动失败", "error", err)
s.cancel()
}
}()
// 启动TCP服务
go func() {
logger.Info(fmt.Sprintf("DNS TCP服务器启动监听端口: %d", s.config.Port))
if err := s.tcpServer.ListenAndServe(); err != nil {
logger.Error("DNS TCP服务器启动失败", "error", err)
s.cancel()
}
}()
// 等待停止信号
<-s.ctx.Done()
return nil
}
// Stop 停止DNS服务器
func (s *Server) Stop() {
// 检查服务器是否已经停止
s.stoppedMutex.Lock()
if s.stopped {
s.stoppedMutex.Unlock()
return // 服务器已经停止,直接返回
}
// 标记服务器为已停止状态
s.stopped = true
s.stoppedMutex.Unlock()
// 发送停止信号给保存协程
close(s.saveDone)
// 最后保存一次数据
s.saveStatsData()
// 停止服务器
s.cancel()
if s.server != nil {
s.server.Shutdown()
}
if s.tcpServer != nil {
s.tcpServer.Shutdown()
}
logger.Info("DNS服务器已停止")
}
// handleDNSRequest 处理DNS请求
func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
startTime := time.Now()
// 获取来源IP
sourceIP := w.RemoteAddr().String()
// 提取IP地址部分去掉端口
if strings.HasPrefix(sourceIP, "[") {
// IPv6地址格式: [::1]:53
if idx := strings.Index(sourceIP, "]"); idx >= 0 {
sourceIP = sourceIP[1:idx] // 去掉方括号
}
} else {
// IPv4地址格式: 127.0.0.1:53
if idx := strings.LastIndex(sourceIP, ":"); idx >= 0 {
sourceIP = sourceIP[:idx]
}
}
// 更新来源IP统计
s.updateStats(func(stats *Stats) {
stats.Queries++
stats.LastQuery = time.Now()
stats.SourceIPs[sourceIP] = true
})
// 更新客户端统计
s.updateClientStats(sourceIP)
// 获取查询域名和类型
var domain string
var queryType string
var qType uint16
if len(r.Question) > 0 {
domain = r.Question[0].Name
// 移除末尾的点
if len(domain) > 0 && domain[len(domain)-1] == '.' {
domain = domain[:len(domain)-1]
}
// 获取查询类型
queryType = dns.TypeToString[r.Question[0].Qtype]
qType = r.Question[0].Qtype
// 更新查询类型统计
s.updateStats(func(stats *Stats) {
stats.QueryTypes[queryType]++
})
// 检查是否是AAAA记录查询且IPv6解析已禁用
if qType == dns.TypeAAAA && !s.config.EnableIPv6 {
// 返回NXDOMAIN响应域名不存在
response := new(dns.Msg)
response.SetReply(r)
response.SetRcode(r, dns.RcodeNameError)
w.WriteMsg(response)
// 更新统计信息
responseTime := int64(0)
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
if stats.Queries > 0 {
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
}
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "", false, false, true, "", "", nil, dns.RcodeNameError)
logger.Debug("IPv6解析已禁用拒绝AAAA记录查询", "domain", domain)
return
}
}
logger.Debug("接收到DNS查询", "domain", domain, "type", queryType, "client", w.RemoteAddr())
// 只处理递归查询
if r.RecursionDesired == false {
response := new(dns.Msg)
response.SetReply(r)
// 不再硬编码RecursionAvailable使用默认值或上游返回的值
response.SetRcode(r, dns.RcodeRefused)
w.WriteMsg(response)
// 缓存命中响应时间设为0ms
responseTime := int64(0)
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
if stats.Queries > 0 {
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
}
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "", false, false, true, "", "", nil, dns.RcodeRefused)
return
}
// 检查hosts文件是否有匹配
if ip, exists := s.shieldManager.GetHostsIP(domain); exists {
s.handleHostsResponse(w, r, ip)
// 缓存命中响应时间设为0ms
responseTime := int64(0)
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
if stats.Queries > 0 {
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
}
})
// 该方法内部未直接调用addQueryLog而是在handleDNSRequest中处理
return
}
// 检查是否被屏蔽
if s.shieldManager.IsBlocked(domain) {
// 获取屏蔽详情
blockDetails := s.shieldManager.CheckDomainBlockDetails(domain)
blockRule, _ := blockDetails["blockRule"].(string)
blockType, _ := blockDetails["blockRuleType"].(string)
s.handleBlockedResponse(w, r, domain)
// 计算响应时间
responseTime := time.Since(startTime).Milliseconds()
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
if stats.Queries > 0 {
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
}
})
// 添加查询日志 - 被屏蔽域名
blockedAnswers := []DNSAnswer{}
// 根据屏蔽方法确定响应代码
blockedRcode := dns.RcodeNameError // 默认NXDOMAIN
if blockMethod := s.shieldConfig.BlockMethod; blockMethod == "refused" {
blockedRcode = dns.RcodeRefused
} else if blockMethod == "emptyIP" || blockMethod == "customIP" {
blockedRcode = dns.RcodeSuccess
}
s.addQueryLog(sourceIP, domain, queryType, responseTime, "blocked", blockRule, blockType, false, false, true, "无", "无", blockedAnswers, blockedRcode)
return
}
// 检查缓存中是否有响应优先查找带DNSSEC的缓存项
var cachedResponse *dns.Msg
var found bool
var cachedDNSSEC bool
// 1. 首先检查是否有普通缓存项
if tempResponse, tempFound := s.DnsCache.Get(r.Question[0].Name, qType); tempFound {
cachedResponse = tempResponse
found = tempFound
cachedDNSSEC = s.hasDNSSECRecords(tempResponse)
}
// 2. 如果启用了DNSSEC且没有找到带DNSSEC的缓存项
// 尝试从所有缓存中查找是否有其他响应包含DNSSEC记录
// 这里可以进一步优化比如在缓存中标记DNSSEC状态快速查找
if s.config.EnableDNSSEC && !cachedDNSSEC {
// 目前的缓存实现不支持按DNSSEC状态查找所以这里暂时跳过
// 后续可以考虑改进缓存实现添加DNSSEC状态标记
}
if found {
// 缓存命中,直接返回缓存的响应
cachedResponseCopy := cachedResponse.Copy() // 创建响应副本避免并发修改问题
cachedResponseCopy.Id = r.Id // 更新ID以匹配请求
cachedResponseCopy.Compress = true
// 如果客户端请求包含EDNS记录确保响应也包含EDNS
if opt := r.IsEdns0(); opt != nil {
// 检查响应是否已经包含EDNS记录
if respOpt := cachedResponseCopy.IsEdns0(); respOpt == nil {
// 添加EDNS记录使用客户端的UDP缓冲区大小
cachedResponseCopy.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
} else {
// 确保响应的UDP缓冲区大小不超过客户端请求的大小
if respOpt.UDPSize() > opt.UDPSize() {
// 移除现有的EDNS记录
for i := range cachedResponseCopy.Extra {
if cachedResponseCopy.Extra[i] == respOpt {
cachedResponseCopy.Extra = append(cachedResponseCopy.Extra[:i], cachedResponseCopy.Extra[i+1:]...)
break
}
}
// 添加新的EDNS记录使用客户端的UDP缓冲区大小
cachedResponseCopy.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
}
}
}
w.WriteMsg(cachedResponseCopy)
// 计算响应时间
responseTime := time.Since(startTime).Milliseconds()
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
if stats.Queries > 0 {
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
}
})
// 如果缓存响应包含DNSSEC记录更新DNSSEC查询计数
if cachedDNSSEC {
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
// 缓存响应视为DNSSEC成功
stats.DNSSECSuccess++
})
}
// 从缓存响应中提取解析记录
cachedAnswers := []DNSAnswer{}
if cachedResponse != nil {
for _, rr := range cachedResponse.Answer {
cachedAnswers = append(cachedAnswers, DNSAnswer{
Type: dns.TypeToString[rr.Header().Rrtype],
Value: rr.String(),
TTL: rr.Header().Ttl,
})
}
}
// 添加查询日志 - 标记为缓存
// 从缓存响应中获取响应代码
cacheRcode := dns.RcodeSuccess // 默认成功
if cachedResponse != nil {
cacheRcode = cachedResponse.Rcode
}
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", true, cachedDNSSEC, true, "缓存", "无", cachedAnswers, cacheRcode)
logger.Debug("从缓存返回DNS响应", "domain", domain, "type", queryType, "dnssec", cachedDNSSEC)
return
}
// 缓存未命中处理DNS请求
var response *dns.Msg
var rtt time.Duration
var queryAttempts []string
var dnsServer string
var dnssecServer string
// 直接查询原始域名
queryAttempts = append(queryAttempts, domain)
response, rtt, dnsServer, dnssecServer = s.forwardDNSRequestWithCache(r, domain)
if response != nil {
// 如果客户端请求包含EDNS记录确保响应也包含EDNS
if opt := r.IsEdns0(); opt != nil {
// 检查响应是否已经包含EDNS记录
if respOpt := response.IsEdns0(); respOpt == nil {
// 添加EDNS记录使用客户端的UDP缓冲区大小
response.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
} else {
// 确保响应的UDP缓冲区大小不超过客户端请求的大小
if respOpt.UDPSize() > opt.UDPSize() {
// 移除现有的EDNS记录
for i := range response.Extra {
if response.Extra[i] == respOpt {
response.Extra = append(response.Extra[:i], response.Extra[i+1:]...)
break
}
}
// 添加新的EDNS记录使用客户端的UDP缓冲区大小
response.SetEdns0(opt.UDPSize(), s.config.EnableDNSSEC)
}
}
}
// 写入响应给客户端
w.WriteMsg(response)
}
// 使用上游服务器的实际响应时间(转换为毫秒)
responseTime := int64(rtt.Milliseconds())
// 如果rtt为0查询失败则使用本地计算的时间
if responseTime == 0 {
responseTime = time.Since(startTime).Milliseconds()
}
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
if stats.Queries > 0 {
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
}
})
// 检查响应是否包含DNSSEC记录并验证结果
responseDNSSEC := false
if response != nil {
// 使用hasDNSSECRecords函数检查是否包含DNSSEC记录
responseDNSSEC = s.hasDNSSECRecords(response)
// 检查AD标志确认DNSSEC验证是否成功
if response.AuthenticatedData {
responseDNSSEC = true
}
// 更新域名的DNSSEC状态
if responseDNSSEC {
s.updateDomainDNSSECStatus(domain, true)
}
}
// 如果响应成功,缓存结果(增强版缓存存储)
if response != nil && response.Rcode == dns.RcodeSuccess {
// 创建响应副本以避免后续修改影响缓存
responseCopy := response.Copy()
// 设置合理的TTL不超过默认的30分钟
defaultCacheTTL := 30 * time.Minute
s.DnsCache.Set(r.Question[0].Name, qType, responseCopy, defaultCacheTTL)
logger.Debug("DNS响应已缓存", "domain", domain, "type", queryType, "ttl", defaultCacheTTL, "dnssec", responseDNSSEC)
}
// 从响应中提取解析记录
responseAnswers := []DNSAnswer{}
if response != nil {
for _, rr := range response.Answer {
responseAnswers = append(responseAnswers, DNSAnswer{
Type: dns.TypeToString[rr.Header().Rrtype],
Value: rr.String(),
TTL: rr.Header().Ttl,
})
}
}
// 添加查询日志 - 标记为实时
// 从响应中获取响应代码
realRcode := dns.RcodeSuccess // 默认成功
if response != nil {
realRcode = response.Rcode
}
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", false, responseDNSSEC, true, dnsServer, dnssecServer, responseAnswers, realRcode)
}
// handleHostsResponse 处理hosts文件匹配的响应
func (s *Server) handleHostsResponse(w dns.ResponseWriter, r *dns.Msg, ip string) {
response := new(dns.Msg)
response.SetReply(r)
// 不再硬编码RecursionAvailable使用默认值或上游返回的值
if len(r.Question) > 0 {
q := r.Question[0]
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
answer.A = net.ParseIP(ip)
response.Answer = append(response.Answer, answer)
}
// 记录解析域名统计
domain := ""
if len(r.Question) > 0 {
domain = r.Question[0].Name
if len(domain) > 0 && domain[len(domain)-1] == '.' {
domain = domain[:len(domain)-1]
}
s.updateResolvedDomainStats(domain)
}
w.WriteMsg(response)
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
}
// handleBlockedResponse 处理被屏蔽的域名响应
func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain string) {
logger.Info("域名被屏蔽", "domain", domain, "client", w.RemoteAddr())
// 更新被屏蔽域名统计
s.updateBlockedDomainStats(domain)
response := new(dns.Msg)
response.SetReply(r)
// 不再硬编码RecursionAvailable使用默认值或上游返回的值
// 获取屏蔽方法配置
blockMethod := "NXDOMAIN" // 默认值
customBlockIP := "" // 默认值
// 从Server结构体的shieldConfig字段获取配置
if s.shieldConfig != nil {
blockMethod = s.shieldConfig.BlockMethod
customBlockIP = s.shieldConfig.CustomBlockIP
}
// 根据屏蔽方法返回不同的响应
switch blockMethod {
case "refused":
// 返回拒绝查询响应
response.SetRcode(r, dns.RcodeRefused)
case "emptyIP":
// 返回空IP响应
if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA {
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: r.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
answer.A = net.ParseIP("0.0.0.0") // 空IP
response.Answer = append(response.Answer, answer)
}
case "customIP":
// 返回自定义IP响应
if len(r.Question) > 0 && r.Question[0].Qtype == dns.TypeA && customBlockIP != "" {
answer := new(dns.A)
answer.Hdr = dns.RR_Header{
Name: r.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
}
answer.A = net.ParseIP(customBlockIP)
response.Answer = append(response.Answer, answer)
}
case "NXDOMAIN", "":
fallthrough // 默认使用NXDOMAIN
default:
// 返回NXDOMAIN响应域名不存在
response.SetRcode(r, dns.RcodeNameError)
}
w.WriteMsg(response)
s.updateStats(func(stats *Stats) {
stats.Blocked++
})
}
// forwardDNSRequest 转发DNS请求到上游服务器
// serverResponse 用于存储服务器响应的结构体
type serverResponse struct {
response *dns.Msg
rtt time.Duration
server string
error error
}
// recordKey 用于唯一标识DNS记录的结构体
type recordKey struct {
name string
rtype uint16
class uint16
data string
}
// getRecordKey 获取DNS记录的唯一标识
func getRecordKey(rr dns.RR) recordKey {
// 对于同一域名的同一类型记录只保留一个选择最长TTL
// 所以对于A、AAAA、CNAME等记录只使用name、rtype、class作为键
// 对于MX记录还需要考虑Preference字段
// 对于TXT记录需要考虑实际文本内容
// 对于NS记录需要考虑目标服务器
switch rr.Header().Rrtype {
case dns.TypeA, dns.TypeAAAA, dns.TypeCNAME, dns.TypePTR:
// 对于A、AAAA、CNAME、PTR记录同一域名只保留一个
return recordKey{
name: rr.Header().Name,
rtype: rr.Header().Rrtype,
class: rr.Header().Class,
data: "",
}
case dns.TypeMX:
// 对于MX记录同一域名的同一Preference只保留一个
if mx, ok := rr.(*dns.MX); ok {
return recordKey{
name: rr.Header().Name,
rtype: rr.Header().Rrtype,
class: rr.Header().Class,
data: fmt.Sprintf("%d", mx.Preference),
}
}
case dns.TypeTXT:
// 对于TXT记录需要考虑实际文本内容
if txt, ok := rr.(*dns.TXT); ok {
return recordKey{
name: rr.Header().Name,
rtype: rr.Header().Rrtype,
class: rr.Header().Class,
data: strings.Join(txt.Txt, " "),
}
}
case dns.TypeNS:
// 对于NS记录需要考虑目标服务器
if ns, ok := rr.(*dns.NS); ok {
return recordKey{
name: rr.Header().Name,
rtype: rr.Header().Rrtype,
class: rr.Header().Class,
data: ns.Ns,
}
}
case dns.TypeSOA:
// 对于SOA记录同一域名只保留一个
return recordKey{
name: rr.Header().Name,
rtype: rr.Header().Rrtype,
class: rr.Header().Class,
data: "",
}
}
// 对于其他类型使用原始rr.String()但移除TTL部分
parts := strings.Split(rr.String(), " ")
if len(parts) >= 5 {
// 跳过TTL字段第3个字段
data := strings.Join(append(parts[:2], parts[3:]...), " ")
return recordKey{
name: rr.Header().Name,
rtype: rr.Header().Rrtype,
class: rr.Header().Class,
data: data,
}
}
return recordKey{
name: rr.Header().Name,
rtype: rr.Header().Rrtype,
class: rr.Header().Class,
data: rr.String(),
}
}
// mergeResponses 合并多个DNS响应
func mergeResponses(responses []*dns.Msg) *dns.Msg {
if len(responses) == 0 {
return nil
}
// 如果只有一个响应,直接返回,避免不必要的合并操作
if len(responses) == 1 {
return responses[0].Copy()
}
// 使用第一个响应作为基础
mergedResponse := responses[0].Copy()
mergedResponse.Answer = []dns.RR{}
mergedResponse.Ns = []dns.RR{}
mergedResponse.Extra = []dns.RR{}
// 重置Rcode为成功除非所有响应都是NXDOMAIN
mergedResponse.Rcode = dns.RcodeSuccess
// 检查是否所有响应都是NXDOMAIN
allNXDOMAIN := true
// 收集所有成功响应的记录
for _, resp := range responses {
if resp == nil {
continue
}
// 如果有任何响应是成功的就不是allNXDOMAIN
if resp.Rcode == dns.RcodeSuccess {
allNXDOMAIN = false
}
}
// 如果所有响应都是NXDOMAIN设置合并响应为NXDOMAIN
if allNXDOMAIN {
mergedResponse.Rcode = dns.RcodeNameError
}
// 使用map存储唯一记录选择最长TTL
// 预分配map容量减少扩容开销
answerMap := make(map[recordKey]dns.RR, len(responses[0].Answer)*len(responses))
nsMap := make(map[recordKey]dns.RR, len(responses[0].Ns)*len(responses))
extraMap := make(map[recordKey]dns.RR, len(responses[0].Extra)*len(responses))
for _, resp := range responses {
if resp == nil {
continue
}
// 只合并与最终Rcode匹配的响应记录
if (mergedResponse.Rcode == dns.RcodeSuccess && resp.Rcode == dns.RcodeSuccess) ||
(mergedResponse.Rcode == dns.RcodeNameError && resp.Rcode == dns.RcodeNameError) {
// 合并Answer部分
for _, rr := range resp.Answer {
key := getRecordKey(rr)
if existing, exists := answerMap[key]; exists {
// 如果存在相同记录选择TTL更长的
if rr.Header().Ttl > existing.Header().Ttl {
answerMap[key] = rr
}
} else {
answerMap[key] = rr
}
}
// 合并Ns部分
for _, rr := range resp.Ns {
key := getRecordKey(rr)
if existing, exists := nsMap[key]; exists {
// 如果存在相同记录选择TTL更长的
if rr.Header().Ttl > existing.Header().Ttl {
nsMap[key] = rr
}
} else {
nsMap[key] = rr
}
}
// 合并Extra部分
for _, rr := range resp.Extra {
// 跳过OPT记录避免重复
if rr.Header().Rrtype == dns.TypeOPT {
continue
}
key := getRecordKey(rr)
if existing, exists := extraMap[key]; exists {
// 如果存在相同记录选择TTL更长的
if rr.Header().Ttl > existing.Header().Ttl {
extraMap[key] = rr
}
} else {
extraMap[key] = rr
}
}
}
}
// 预分配切片容量,减少扩容开销
mergedResponse.Answer = make([]dns.RR, 0, len(answerMap))
mergedResponse.Ns = make([]dns.RR, 0, len(nsMap))
mergedResponse.Extra = make([]dns.RR, 0, len(extraMap))
// 将map转换回切片
for _, rr := range answerMap {
mergedResponse.Answer = append(mergedResponse.Answer, rr)
}
for _, rr := range nsMap {
mergedResponse.Ns = append(mergedResponse.Ns, rr)
}
for _, rr := range extraMap {
mergedResponse.Extra = append(mergedResponse.Extra, rr)
}
return mergedResponse
}
// forwardDNSRequestWithCache 转发DNS请求到上游服务器并返回响应
func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg, time.Duration, string, string) {
// 始终支持EDNS
var udpSize uint16 = 4096
var doFlag bool = s.config.EnableDNSSEC
// 检查域名是否匹配不验证DNSSEC的模式
noDNSSEC := false
for _, pattern := range s.config.NoDNSSECDomains {
if strings.Contains(domain, pattern) {
noDNSSEC = true
doFlag = false
logger.Debug("域名匹配到不验证DNSSEC的模式", "domain", domain, "pattern", pattern)
break
}
}
// 检查客户端请求是否包含EDNS记录
if opt := r.IsEdns0(); opt != nil {
// 保留客户端的UDP缓冲区大小
udpSize = opt.UDPSize()
// 移除现有的EDNS记录以便重新添加
for i := range r.Extra {
if r.Extra[i] == opt {
r.Extra = append(r.Extra[:i], r.Extra[i+1:]...)
break
}
}
}
// 添加EDNS记录设置适当的UDPSize和DO标志
r.SetEdns0(udpSize, doFlag)
// DNSSEC专用服务器列表从配置中获取
dnssecServers := s.config.DNSSECUpstreamDNS
// 选择合适的上游DNS服务器列表
// 1. 首先检查是否有域名特定的DNS服务器配置
var selectedUpstreamDNS []string
var domainMatched bool
for matchStr, dnsServers := range s.config.DomainSpecificDNS {
if strings.Contains(domain, matchStr) {
selectedUpstreamDNS = dnsServers
domainMatched = true
logger.Debug("域名匹配到特定DNS服务器配置", "domain", domain, "matchStr", matchStr, "dnsServers", dnsServers)
break
}
}
// 2. 如果没有匹配的域名特定配置
if !domainMatched {
// 创建一个新的切片来存储最终的上游服务器列表
var finalUpstreamDNS []string
// 首先添加用户配置的上游DNS服务器
finalUpstreamDNS = append(finalUpstreamDNS, s.config.UpstreamDNS...)
logger.Debug("使用用户配置的上游DNS服务器", "servers", finalUpstreamDNS)
// 如果启用了DNSSEC且有配置DNSSEC专用服务器并且域名不匹配NoDNSSECDomains则将DNSSEC专用服务器添加到列表中
if s.config.EnableDNSSEC && len(s.config.DNSSECUpstreamDNS) > 0 && !noDNSSEC {
// 合并DNSSEC专用服务器到上游服务器列表避免重复并确保包含端口号
for _, dnssecServer := range s.config.DNSSECUpstreamDNS {
hasDuplicate := false
// 确保DNSSEC服务器地址包含端口号
normalizedDnssecServer := normalizeDNSServerAddress(dnssecServer)
for _, upstream := range finalUpstreamDNS {
if upstream == normalizedDnssecServer {
hasDuplicate = true
break
}
}
if !hasDuplicate {
finalUpstreamDNS = append(finalUpstreamDNS, normalizedDnssecServer)
}
}
logger.Debug("合并DNSSEC专用服务器到上游服务器列表", "servers", finalUpstreamDNS)
}
// 使用最终合并后的服务器列表
selectedUpstreamDNS = finalUpstreamDNS
}
// 1. 首先尝试所有配置的上游DNS服务器
var bestResponse *dns.Msg
var bestRtt time.Duration
var hasBestResponse bool
var hasDNSSECResponse bool
var backupResponse *dns.Msg
var backupRtt time.Duration
var hasBackup bool
var usedDNSServer string
var usedDNSSECServer string
// 使用配置中的超时时间
defaultTimeout := time.Duration(s.config.QueryTimeout) * time.Millisecond
// 根据查询模式处理请求
switch s.config.QueryMode {
case "parallel":
// 并行请求模式 - 收集所有响应并合并
responses := make(chan serverResponse, len(selectedUpstreamDNS))
var wg sync.WaitGroup
// 向所有上游服务器并行发送请求,每个请求带有超时
for _, upstream := range selectedUpstreamDNS {
wg.Add(1)
go func(server string) {
defer wg.Done()
// 创建带有超时的resolver
client := &dns.Client{
Net: s.resolver.Net,
UDPSize: s.resolver.UDPSize,
Timeout: defaultTimeout,
}
// 发送请求并获取响应,确保服务器地址包含端口号
response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(server))
responses <- serverResponse{response, rtt, server, err}
}(upstream)
}
// 等待所有请求完成或超时
go func() {
wg.Wait()
close(responses)
}()
// 收集成功响应和NXDOMAIN响应分开
var successResponses []*dns.Msg
var nxdomainResponses []*dns.Msg
var totalRtt time.Duration
var responseCount int
// 处理所有响应
for resp := range responses {
if resp.error == nil && resp.response != nil {
// 更新服务器统计信息
s.updateServerStats(resp.server, true, resp.rtt)
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(resp.response)
// 对于不验证DNSSEC的域名始终设置AD标志为false
if noDNSSEC {
resp.response.AuthenticatedData = false
}
// 只对将要返回的响应进行DNSSEC验证减少开销
// 这里只设置containsDNSSEC标志实际验证在确定返回响应后进行
if containsDNSSEC && s.config.EnableDNSSEC && !noDNSSEC {
// 暂时不验证,只标记
}
// 检查当前服务器是否是DNSSEC专用服务器
for _, dnssecServer := range dnssecServers {
if dnssecServer == resp.server {
usedDNSSECServer = resp.server
break
}
}
// 收集响应按Rcode分类
if resp.response.Rcode == dns.RcodeSuccess {
successResponses = append(successResponses, resp.response)
totalRtt += resp.rtt
responseCount++
// 记录使用的服务器
if usedDNSServer == "" {
usedDNSServer = resp.server
}
} else if resp.response.Rcode == dns.RcodeNameError {
nxdomainResponses = append(nxdomainResponses, resp.response)
} else {
// 更新备选响应,确保总有一个可用的响应
if resp.response != nil {
if !hasBackup {
// 第一次保存备选响应
backupResponse = resp.response
backupRtt = resp.rtt
hasBackup = true
}
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(resp.server, false, 0)
}
}
// 合并响应优先使用成功响应只有当没有成功响应时才使用NXDOMAIN响应
var validResponses []*dns.Msg
if len(successResponses) > 0 {
validResponses = successResponses
} else {
validResponses = nxdomainResponses
}
// 合并所有有效响应
if len(validResponses) > 0 {
bestResponse = mergeResponses(validResponses)
if responseCount > 0 {
bestRtt = totalRtt / time.Duration(responseCount)
}
hasBestResponse = true
// 设置日志的type字段
logType := "success"
if len(successResponses) == 0 {
logType = "nxdomain"
}
logger.Debug("合并所有响应返回", "domain", domain, "responseCount", len(validResponses), "type", logType)
}
case "fastest-ip":
// 最快的IP地址模式 - 使用TCP连接速度测量选择最快服务器
// 1. 选择最快的服务器
fastestServer := s.selectFastestServer(selectedUpstreamDNS)
if fastestServer != "" {
// 使用带超时的方式执行Exchange
resultChan := make(chan struct {
response *dns.Msg
rtt time.Duration
err error
}, 1)
go func() {
resp, r, e := s.resolver.Exchange(r, normalizeDNSServerAddress(fastestServer))
resultChan <- struct {
response *dns.Msg
rtt time.Duration
err error
}{resp, r, e}
}()
var response *dns.Msg
var rtt time.Duration
var err error
// 直接获取结果,不使用上下文超时
result := <-resultChan
response, rtt, err = result.response, result.rtt, result.err
if err == nil && response != nil {
// 更新服务器统计信息
s.updateServerStats(fastestServer, true, rtt)
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(response)
// 如果启用了DNSSEC且响应包含DNSSEC记录验证DNSSEC签名
// 但如果域名匹配不验证DNSSEC的模式则跳过验证
if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(response)
// 设置AD标志Authenticated Data
response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
} else if noDNSSEC {
// 对于不验证DNSSEC的域名始终设置AD标志为false
response.AuthenticatedData = false
}
// 如果响应成功或为NXDOMAIN根据DNSSEC状态选择最佳响应
if response.Rcode == dns.RcodeSuccess || response.Rcode == dns.RcodeNameError {
if response.Rcode == dns.RcodeSuccess {
// 优先选择带有DNSSEC记录的响应
if containsDNSSEC {
bestResponse = response
bestRtt = rtt
hasBestResponse = true
hasDNSSECResponse = true
usedDNSServer = fastestServer
// 如果当前使用的服务器是DNSSEC专用服务器同时设置usedDNSSECServer
for _, dnssecServer := range dnssecServers {
if dnssecServer == fastestServer {
usedDNSSECServer = fastestServer
break
}
}
logger.Debug("找到带DNSSEC的最佳响应", "domain", domain, "server", fastestServer, "rtt", rtt)
} else {
// 没有带DNSSEC的响应时保存成功响应
bestResponse = response
bestRtt = rtt
hasBestResponse = true
usedDNSServer = fastestServer
// 如果当前使用的服务器是DNSSEC专用服务器同时设置usedDNSSECServer
for _, dnssecServer := range dnssecServers {
if dnssecServer == fastestServer {
usedDNSSECServer = fastestServer
break
}
}
logger.Debug("找到最佳响应", "domain", domain, "server", fastestServer, "rtt", rtt)
}
} else if response.Rcode == dns.RcodeNameError {
// 处理NXDOMAIN响应
bestResponse = response
bestRtt = rtt
hasBestResponse = true
usedDNSServer = fastestServer
logger.Debug("找到NXDOMAIN响应", "domain", domain, "server", fastestServer, "rtt", rtt)
}
// 保存为备选响应
if !hasBackup {
backupResponse = response
backupRtt = rtt
hasBackup = true
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(fastestServer, false, 0)
}
}
default:
// 默认使用并行请求模式 - 实现快速返回和超时机制
responses := make(chan serverResponse, len(selectedUpstreamDNS))
resultChan := make(chan struct {
response *dns.Msg
rtt time.Duration
usedServer string
usedDnssecServer string
}, 1)
var wg sync.WaitGroup
// 向所有上游服务器并行发送请求
for _, upstream := range selectedUpstreamDNS {
wg.Add(1)
go func(server string) {
defer wg.Done()
// 创建带有超时的resolver
client := &dns.Client{
Net: s.resolver.Net,
UDPSize: s.resolver.UDPSize,
Timeout: defaultTimeout,
}
// 发送请求并获取响应,确保服务器地址包含端口号
response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(server))
responses <- serverResponse{response, rtt, server, err}
}(upstream)
}
// 处理响应的协程
go func() {
var fastestResponse *dns.Msg
var fastestRtt time.Duration = defaultTimeout
var fastestServer string
var fastestDnssecServer string
var fastestHasDnssec bool
var successResponses []*dns.Msg
var nxdomainResponses []*dns.Msg
// 等待所有请求完成或超时
timer := time.NewTimer(defaultTimeout)
defer timer.Stop()
// 处理所有响应
for {
select {
case resp, ok := <-responses:
if !ok {
// 所有响应都已处理
goto doneProcessing
}
if resp.error == nil && resp.response != nil {
// 更新服务器统计信息
s.updateServerStats(resp.server, true, resp.rtt)
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(resp.response)
// 对于不验证DNSSEC的域名始终设置AD标志为false
if noDNSSEC {
resp.response.AuthenticatedData = false
}
// 检查当前服务器是否是DNSSEC专用服务器
dnssecServerForResponse := ""
for _, dnssecServer := range dnssecServers {
if dnssecServer == resp.server {
dnssecServerForResponse = resp.server
break
}
}
// 如果响应成功或为NXDOMAIN
if resp.response.Rcode == dns.RcodeSuccess || resp.response.Rcode == dns.RcodeNameError {
// 按Rcode分类添加到不同列表
if resp.response.Rcode == dns.RcodeSuccess {
successResponses = append(successResponses, resp.response)
} else {
nxdomainResponses = append(nxdomainResponses, resp.response)
}
// 快速返回逻辑:找到第一个有效响应或更快的响应
if resp.response.Rcode == dns.RcodeSuccess {
// 优先选择带有DNSSEC的响应
if containsDNSSEC {
// 如果这是第一个DNSSEC响应或者比当前最快的DNSSEC响应更快
if !fastestHasDnssec || resp.rtt < fastestRtt {
fastestResponse = resp.response
fastestRtt = resp.rtt
fastestServer = resp.server
fastestDnssecServer = dnssecServerForResponse
fastestHasDnssec = true
// 只对将要返回的响应进行DNSSEC验证
if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(fastestResponse)
// 设置AD标志Authenticated Data
fastestResponse.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
// 发送结果,快速返回
resultChan <- struct {
response *dns.Msg
rtt time.Duration
usedServer string
usedDnssecServer string
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
}
} else {
// 非DNSSEC响应只有在还没有找到DNSSEC响应且当前响应更快时才更新
if !fastestHasDnssec && resp.rtt < fastestRtt {
fastestResponse = resp.response
fastestRtt = resp.rtt
fastestServer = resp.server
fastestDnssecServer = dnssecServerForResponse
// 检查是否包含DNSSEC记录
respContainsDNSSEC := s.hasDNSSECRecords(fastestResponse)
// 只对将要返回的响应进行DNSSEC验证
if s.config.EnableDNSSEC && respContainsDNSSEC && !noDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(fastestResponse)
// 设置AD标志Authenticated Data
fastestResponse.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
// 发送结果,快速返回
resultChan <- struct {
response *dns.Msg
rtt time.Duration
usedServer string
usedDnssecServer string
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
}
}
} else if resp.response.Rcode == dns.RcodeNameError {
// NXDOMAIN响应只有在还没有找到响应或当前响应更快时才更新
if !fastestHasDnssec && resp.rtt < fastestRtt {
fastestResponse = resp.response
fastestRtt = resp.rtt
fastestServer = resp.server
fastestDnssecServer = dnssecServerForResponse
// 检查是否包含DNSSEC记录
respContainsDNSSEC := s.hasDNSSECRecords(fastestResponse)
// 只对将要返回的响应进行DNSSEC验证
if s.config.EnableDNSSEC && respContainsDNSSEC && !noDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(fastestResponse)
// 设置AD标志Authenticated Data
fastestResponse.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
// 发送结果,快速返回
resultChan <- struct {
response *dns.Msg
rtt time.Duration
usedServer string
usedDnssecServer string
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
}
}
} else {
// 更新备选响应,确保总有一个可用的响应
if resp.response != nil {
if !hasBackup {
// 第一次保存备选响应
backupResponse = resp.response
backupRtt = resp.rtt
hasBackup = true
}
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(resp.server, false, 0)
}
case <-timer.C:
// 超时,停止等待更多响应
goto doneProcessing
}
}
doneProcessing:
// 合并响应优先使用成功响应只有当没有成功响应时才使用NXDOMAIN响应
var validResponses []*dns.Msg
if len(successResponses) > 0 {
validResponses = successResponses
} else {
validResponses = nxdomainResponses
}
// 合并所有有效响应,用于缓存
if len(validResponses) > 1 {
mergedResponse := mergeResponses(validResponses)
if mergedResponse != nil {
// 只在合并后的响应比最快响应更好时才使用
mergedHasDnssec := s.hasDNSSECRecords(mergedResponse)
if mergedHasDnssec && !fastestHasDnssec {
// 合并后的响应有DNSSEC而最快响应没有使用合并后的响应
fastestResponse = mergedResponse
// 使用最快的Rtt作为合并响应的Rtt
fastestHasDnssec = true
}
}
}
// 如果还没有发送结果,发送最快的响应
if fastestResponse != nil {
resultChan <- struct {
response *dns.Msg
rtt time.Duration
usedServer string
usedDnssecServer string
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
}
close(resultChan)
}()
// 等待所有请求完成(不阻塞主流程)
go func() {
wg.Wait()
close(responses)
}()
// 等待结果或超时
select {
case result := <-resultChan:
// 快速返回结果
bestResponse = result.response
bestRtt = result.rtt
usedDNSServer = result.usedServer
usedDNSSECServer = result.usedDnssecServer
hasBestResponse = true
hasDNSSECResponse = s.hasDNSSECRecords(result.response)
logger.Debug("快速返回DNS响应", "domain", domain, "server", result.usedServer, "rtt", result.rtt, "dnssec", hasDNSSECResponse)
case <-time.After(defaultTimeout):
// 超时,使用备选响应
logger.Debug("并行请求超时", "domain", domain, "timeout", defaultTimeout)
}
}
// 2. 当启用DNSSEC且没有找到带DNSSEC的响应时向DNSSEC专用服务器发送请求
// 但如果域名匹配了domainSpecificDNS配置或NoDNSSECDomains则不使用DNSSEC专用服务器只使用指定的DNS服务器
if s.config.EnableDNSSEC && !hasDNSSECResponse && !domainMatched && !noDNSSEC {
logger.Debug("向DNSSEC专用服务器发送请求", "domain", domain)
// 增加DNSSEC查询计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
})
// 无论查询模式是什么DNSSEC验证都只使用加权随机选择一个服务器
selectedDnssecServer := s.selectWeightedRandomServer(dnssecServers)
if selectedDnssecServer != "" {
// 使用带超时的方式执行Exchange
resultChan := make(chan struct {
response *dns.Msg
rtt time.Duration
err error
}, 1)
go func() {
// 创建带有超时的resolver
client := &dns.Client{
Net: s.resolver.Net,
UDPSize: s.resolver.UDPSize,
Timeout: defaultTimeout,
}
response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(selectedDnssecServer))
resultChan <- struct {
response *dns.Msg
rtt time.Duration
err error
}{response, rtt, err}
}()
var response *dns.Msg
var rtt time.Duration
var err error
// 使用超时获取结果
select {
case result := <-resultChan:
response, rtt, err = result.response, result.rtt, result.err
case <-time.After(defaultTimeout):
// 超时,不再等待
logger.Debug("DNSSEC专用服务器请求超时", "domain", domain, "server", selectedDnssecServer, "timeout", defaultTimeout)
return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer
}
if err == nil && response != nil {
// 更新服务器统计信息
s.updateServerStats(selectedDnssecServer, true, rtt)
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(response)
if response.Rcode == dns.RcodeSuccess {
// 无论响应是否包含DNSSEC记录只要使用了DNSSEC专用服务器就设置usedDNSSECServer
usedDNSSECServer = selectedDnssecServer
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(response)
// 设置AD标志Authenticated Data
response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECFailed++
})
}
// 优先使用DNSSEC专用服务器的响应尤其是带有DNSSEC记录的
if containsDNSSEC {
bestResponse = response
bestRtt = rtt
hasBestResponse = true
hasDNSSECResponse = true
logger.Debug("DNSSEC专用服务器返回带DNSSEC的响应优先使用", "domain", domain, "server", selectedDnssecServer, "rtt", rtt)
}
// 注意如果DNSSEC专用服务器返回的响应不包含DNSSEC记录
// 我们不会覆盖之前从upstreamDNS获取的响应
// 这符合"本地解析指的是直接使用上游服务器upstreamDNS进行解析, 而不是dnssecUpstreamDNS"的要求
// 更新备选响应
if !hasBackup {
backupResponse = response
backupRtt = rtt
hasBackup = true
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(selectedDnssecServer, false, 0)
}
}
}
// 3. 返回最佳响应
if hasBestResponse {
// 检查最佳响应是否包含DNSSEC记录
bestHasDNSSEC := s.hasDNSSECRecords(bestResponse)
// 如果启用了DNSSEC且最佳响应不包含DNSSEC记录尝试使用本地解析使用upstreamDNS服务器
// 但如果域名匹配了domainSpecificDNS配置则不执行此逻辑只使用指定的DNS服务器
if s.config.EnableDNSSEC && !bestHasDNSSEC && !domainMatched {
logger.Debug("最佳响应不包含DNSSEC记录尝试使用本地解析upstreamDNS", "domain", domain)
// 选择一个upstreamDNS服务器进行解析使用加权随机算法
localServer := s.selectWeightedRandomServer(s.config.UpstreamDNS)
if localServer != "" {
// 使用带超时的方式执行Exchange
resultChan := make(chan struct {
response *dns.Msg
rtt time.Duration
err error
}, 1)
go func() {
resp, r, e := s.resolver.Exchange(r, normalizeDNSServerAddress(localServer))
resultChan <- struct {
response *dns.Msg
rtt time.Duration
err error
}{resp, r, e}
}()
var localResponse *dns.Msg
var rtt time.Duration
var err error
// 直接获取结果,不使用上下文超时
result := <-resultChan
localResponse, rtt, err = result.response, result.rtt, result.err
if err == nil && localResponse != nil {
// 更新服务器统计信息
s.updateServerStats(localServer, true, rtt)
// 检查是否包含DNSSEC记录
localHasDNSSEC := s.hasDNSSECRecords(localResponse)
// 验证DNSSEC记录如果存在但不影响最终响应
if localHasDNSSEC {
signatureValid := s.verifyDNSSEC(localResponse)
localResponse.AuthenticatedData = signatureValid
if signatureValid {
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
}
// 记录解析域名统计
s.updateResolvedDomainStats(domain)
// 更新域名的DNSSEC状态
s.updateDomainDNSSECStatus(domain, localHasDNSSEC)
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
logger.Debug("使用本地解析结果upstreamDNS", "domain", domain, "server", localServer, "rtt", rtt)
return localResponse, rtt, localServer, ""
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(localServer, false, 0)
}
}
}
// 记录解析域名统计
s.updateResolvedDomainStats(domain)
// 更新域名的DNSSEC状态
if bestHasDNSSEC {
s.updateDomainDNSSECStatus(domain, true)
} else {
s.updateDomainDNSSECStatus(domain, false)
}
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer
}
// 如果有备选响应,返回该响应
if hasBackup {
logger.Debug("使用备选响应,没有找到更好的结果", "domain", domain)
// 记录解析域名统计
s.updateResolvedDomainStats(domain)
// 更新统计信息
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
return backupResponse, backupRtt, "", ""
}
// 所有上游服务器都失败,返回服务器失败错误
response := new(dns.Msg)
response.SetReply(r)
response.SetRcode(r, dns.RcodeServerFailure)
logger.Error("DNS查询失败", "domain", domain)
s.updateStats(func(stats *Stats) {
stats.Errors++
})
return response, 0, "", ""
}
// forwardDNSRequest 转发DNS请求到上游服务器
func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) {
response, _, _, _ := s.forwardDNSRequestWithCache(r, domain)
w.WriteMsg(response)
}
// updateBlockedDomainStats 更新被屏蔽域名统计
func (s *Server) updateBlockedDomainStats(domain string) {
// 更新被屏蔽域名计数
s.blockedDomainsMutex.Lock()
defer s.blockedDomainsMutex.Unlock()
if entry, exists := s.blockedDomains[domain]; exists {
entry.Count++
entry.LastSeen = time.Now()
} else {
s.blockedDomains[domain] = &BlockedDomain{
Domain: domain,
Count: 1,
LastSeen: time.Now(),
}
}
// 更新统计数据
now := time.Now()
// 更新小时统计
hourKey := now.Format("2006-01-02-15")
s.hourlyStatsMutex.Lock()
s.hourlyStats[hourKey]++
s.hourlyStatsMutex.Unlock()
// 更新每日统计
dayKey := now.Format("2006-01-02")
s.dailyStatsMutex.Lock()
s.dailyStats[dayKey]++
s.dailyStatsMutex.Unlock()
// 更新每月统计
monthKey := now.Format("2006-01")
s.monthlyStatsMutex.Lock()
s.monthlyStats[monthKey]++
s.monthlyStatsMutex.Unlock()
}
// updateClientStats 更新客户端统计
func (s *Server) updateClientStats(ip string) {
s.clientStatsMutex.Lock()
defer s.clientStatsMutex.Unlock()
if entry, exists := s.clientStats[ip]; exists {
entry.Count++
entry.LastSeen = time.Now()
} else {
s.clientStats[ip] = &ClientStats{
IP: ip,
Count: 1,
LastSeen: time.Now(),
}
}
}
// hasDNSSECRecords 检查响应是否包含DNSSEC记录
func (s *Server) hasDNSSECRecords(response *dns.Msg) bool {
// 检查响应中是否包含DNSSEC相关记录DNSKEY、RRSIG、DS、NSEC、NSEC3等
for _, rr := range response.Answer {
if _, ok := rr.(*dns.DNSKEY); ok {
return true
}
if _, ok := rr.(*dns.RRSIG); ok {
return true
}
if _, ok := rr.(*dns.DS); ok {
return true
}
if _, ok := rr.(*dns.NSEC); ok {
return true
}
if _, ok := rr.(*dns.NSEC3); ok {
return true
}
}
for _, rr := range response.Ns {
if _, ok := rr.(*dns.DNSKEY); ok {
return true
}
if _, ok := rr.(*dns.RRSIG); ok {
return true
}
if _, ok := rr.(*dns.DS); ok {
return true
}
if _, ok := rr.(*dns.NSEC); ok {
return true
}
if _, ok := rr.(*dns.NSEC3); ok {
return true
}
}
for _, rr := range response.Extra {
if _, ok := rr.(*dns.DNSKEY); ok {
return true
}
if _, ok := rr.(*dns.RRSIG); ok {
return true
}
if _, ok := rr.(*dns.DS); ok {
return true
}
if _, ok := rr.(*dns.NSEC); ok {
return true
}
if _, ok := rr.(*dns.NSEC3); ok {
return true
}
}
return false
}
// verifyDNSSEC 验证DNSSEC签名
func (s *Server) verifyDNSSEC(response *dns.Msg) bool {
// 提取DNSKEY和RRSIG记录
dnskeys := make(map[uint16]*dns.DNSKEY) // KeyTag -> DNSKEY
rrsigs := make([]*dns.RRSIG, 0)
// 从响应中提取所有DNSKEY和RRSIG记录
for _, rr := range response.Answer {
if dnskey, ok := rr.(*dns.DNSKEY); ok {
tag := dnskey.KeyTag()
dnskeys[tag] = dnskey
} else if rrsig, ok := rr.(*dns.RRSIG); ok {
rrsigs = append(rrsigs, rrsig)
}
}
for _, rr := range response.Ns {
if dnskey, ok := rr.(*dns.DNSKEY); ok {
tag := dnskey.KeyTag()
dnskeys[tag] = dnskey
} else if rrsig, ok := rr.(*dns.RRSIG); ok {
rrsigs = append(rrsigs, rrsig)
}
}
for _, rr := range response.Extra {
if dnskey, ok := rr.(*dns.DNSKEY); ok {
tag := dnskey.KeyTag()
dnskeys[tag] = dnskey
} else if rrsig, ok := rr.(*dns.RRSIG); ok {
rrsigs = append(rrsigs, rrsig)
}
}
// 如果没有RRSIG记录验证失败
if len(rrsigs) == 0 {
return false
}
// 验证所有RRSIG记录
signatureValid := true
// 用于记录已经警告过的DNSKEY tag避免重复警告
warnedKeyTags := make(map[uint16]bool)
for _, rrsig := range rrsigs {
// 查找对应的DNSKEY
dnskey, exists := dnskeys[rrsig.KeyTag]
if !exists {
// 仅当该key_tag尚未警告过时才记录警告
if !warnedKeyTags[rrsig.KeyTag] {
logger.Warn("DNSSEC验证失败找不到对应的DNSKEY", "key_tag", rrsig.KeyTag)
warnedKeyTags[rrsig.KeyTag] = true
}
signatureValid = false
continue
}
// 收集需要验证的记录集
rrset := make([]dns.RR, 0)
for _, rr := range response.Answer {
if rr.Header().Name == rrsig.Header().Name && rr.Header().Rrtype == rrsig.TypeCovered {
rrset = append(rrset, rr)
}
}
for _, rr := range response.Ns {
if rr.Header().Name == rrsig.Header().Name && rr.Header().Rrtype == rrsig.TypeCovered {
rrset = append(rrset, rr)
}
}
// 验证签名
if len(rrset) > 0 {
err := rrsig.Verify(dnskey, rrset)
if err != nil {
logger.Warn("DNSSEC签名验证失败", "error", err, "key_tag", rrsig.KeyTag)
signatureValid = false
} else {
logger.Debug("DNSSEC签名验证成功", "key_tag", rrsig.KeyTag)
}
}
}
return signatureValid
}
// updateDomainDNSSECStatus 更新域名的DNSSEC状态
func (s *Server) updateDomainDNSSECStatus(domain string, dnssec bool) {
// 确保域名是小写
domain = strings.ToLower(domain)
// 更新域名的DNSSEC状态
s.resolvedDomainsMutex.Lock()
defer s.resolvedDomainsMutex.Unlock()
// 更新resolvedDomains中的DNSSEC状态
if entry, exists := s.resolvedDomains[domain]; exists {
entry.DNSSEC = dnssec
} else {
s.resolvedDomains[domain] = &BlockedDomain{
Domain: domain,
Count: 1,
LastSeen: time.Now(),
DNSSEC: dnssec,
}
}
// 更新domainDNSSECStatus映射
s.domainDNSSECStatus[domain] = dnssec
}
// updateResolvedDomainStats 更新解析域名统计
func (s *Server) updateResolvedDomainStats(domain string) {
s.resolvedDomainsMutex.Lock()
defer s.resolvedDomainsMutex.Unlock()
if entry, exists := s.resolvedDomains[domain]; exists {
entry.Count++
entry.LastSeen = time.Now()
} else {
s.resolvedDomains[domain] = &BlockedDomain{
Domain: domain,
Count: 1,
LastSeen: time.Now(),
DNSSEC: false,
}
}
}
// getServerStats 获取服务器统计信息,如果不存在则创建
func (s *Server) getServerStats(server string) *ServerStats {
s.serverStatsMutex.RLock()
stats, exists := s.serverStats[server]
s.serverStatsMutex.RUnlock()
if !exists {
// 创建新的服务器统计信息
stats = &ServerStats{
SuccessCount: 0,
FailureCount: 0,
LastResponse: time.Now(),
ResponseTime: 0,
ConnectionSpeed: 0,
}
// 加锁更新服务器统计信息
s.serverStatsMutex.Lock()
s.serverStats[server] = stats
s.serverStatsMutex.Unlock()
}
return stats
}
// updateServerStats 更新服务器统计信息
func (s *Server) updateServerStats(server string, success bool, rtt time.Duration) {
stats := s.getServerStats(server)
s.serverStatsMutex.Lock()
defer s.serverStatsMutex.Unlock()
// 更新统计信息
stats.LastResponse = time.Now()
if success {
stats.SuccessCount++
// 更新平均响应时间(简单移动平均)
// 将所有值转换为纳秒进行计算然后再转换回Duration
if stats.SuccessCount == 1 {
// 第一次成功,直接使用当前响应时间
stats.ResponseTime = rtt
} else {
// 使用纳秒进行计算以避免类型不匹配
prevTotal := stats.ResponseTime.Nanoseconds() * (stats.SuccessCount - 1)
newTotal := prevTotal + rtt.Nanoseconds()
stats.ResponseTime = time.Duration(newTotal / stats.SuccessCount)
}
} else {
stats.FailureCount++
}
}
// selectWeightedRandomServer 加权随机选择服务器
func (s *Server) selectWeightedRandomServer(servers []string) string {
if len(servers) == 0 {
return ""
}
if len(servers) == 1 {
return servers[0]
}
// 计算每个服务器的权重
type serverWeight struct {
server string
weight int64
}
var totalWeight int64
weights := make([]serverWeight, 0, len(servers))
// 获取所有服务器的平均响应时间,用于归一化
var totalResponseTime time.Duration
validServers := 0
for _, server := range servers {
stats := s.getServerStats(server)
if stats.ResponseTime > 0 {
totalResponseTime += stats.ResponseTime
validServers++
}
}
// 计算平均响应时间基准值
var avgResponseTime time.Duration
if validServers > 0 {
avgResponseTime = totalResponseTime / time.Duration(validServers)
} else {
avgResponseTime = 1 * time.Second // 默认基准值
}
for _, server := range servers {
stats := s.getServerStats(server)
// 计算基础权重:成功次数 - 失败次数 * 2失败权重更高
// 确保权重至少为1
baseWeight := stats.SuccessCount - stats.FailureCount*2
if baseWeight < 1 {
baseWeight = 1
}
// 计算响应时间调整因子:响应时间越短,因子越高
// 如果没有响应时间数据使用默认值1
var responseFactor float64 = 1.0
if stats.ResponseTime > 0 {
// 使用平均响应时间作为基准,计算调整因子
// 响应时间越短因子越高最高为2.0最低为0.5
responseFactor = float64(avgResponseTime) / float64(stats.ResponseTime)
// 限制调整因子的范围,避免权重波动过大
if responseFactor > 2.0 {
responseFactor = 2.0
} else if responseFactor < 0.5 {
responseFactor = 0.5
}
}
// 综合计算最终权重,四舍五入到整数
finalWeight := int64(float64(baseWeight) * responseFactor)
// 确保最终权重至少为1
if finalWeight < 1 {
finalWeight = 1
}
weights = append(weights, serverWeight{server, finalWeight})
totalWeight += finalWeight
}
// 随机选择一个权重
random := time.Now().UnixNano() % totalWeight
if random < 0 {
random += totalWeight
}
// 选择对应的服务器
var currentWeight int64
for _, sw := range weights {
currentWeight += sw.weight
if random < currentWeight {
return sw.server
}
}
// 兜底返回第一个服务器
return servers[0]
}
// measureServerSpeed 测量服务器TCP连接速度
func (s *Server) measureServerSpeed(server string) time.Duration {
// 提取服务器地址和端口
addr := server
if !strings.Contains(server, ":") {
addr = server + ":53"
}
// 测量TCP连接时间
startTime := time.Now()
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
if err != nil {
// 连接失败,返回最大持续时间
return 2 * time.Second
}
defer conn.Close()
// 计算连接建立时间
connTime := time.Since(startTime)
// 更新服务器连接速度
stats := s.getServerStats(server)
s.serverStatsMutex.Lock()
// 使用指数移动平均更新连接速度
stats.ConnectionSpeed = (stats.ConnectionSpeed*3 + connTime) / 4
s.serverStatsMutex.Unlock()
return connTime
}
// selectFastestServer 选择连接速度最快的服务器
func (s *Server) selectFastestServer(servers []string) string {
if len(servers) == 0 {
return ""
}
if len(servers) == 1 {
return servers[0]
}
// 并行测量所有服务器的速度
type speedResult struct {
server string
speed time.Duration
}
results := make(chan speedResult, len(servers))
var wg sync.WaitGroup
for _, server := range servers {
wg.Add(1)
go func(srv string) {
defer wg.Done()
speed := s.measureServerSpeed(srv)
results <- speedResult{srv, speed}
}(server)
}
// 等待所有测量完成
go func() {
wg.Wait()
close(results)
}()
// 找出最快的服务器
var fastestServer string
var fastestSpeed time.Duration = 2 * time.Second
for result := range results {
if result.speed < fastestSpeed {
fastestSpeed = result.speed
fastestServer = result.server
}
}
// 如果没有找到最快服务器(理论上不会发生),返回第一个服务器
if fastestServer == "" {
fastestServer = servers[0]
}
return fastestServer
}
// updateStats 更新统计信息
func (s *Server) updateStats(update func(*Stats)) {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
update(s.stats)
}
// addQueryLog 添加查询日志
func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string, fromCache, dnssec, edns bool, dnsServer, dnssecServer string, answers []DNSAnswer, responseCode int) {
// 创建日志记录
log := QueryLog{
Timestamp: time.Now(),
ClientIP: clientIP,
Domain: domain,
QueryType: queryType,
ResponseTime: responseTime,
Result: result,
BlockRule: blockRule,
BlockType: blockType,
FromCache: fromCache,
DNSSEC: dnssec,
EDNS: edns,
DNSServer: dnsServer,
DNSSECServer: dnssecServer,
Answers: answers,
ResponseCode: responseCode,
}
// 添加到日志列表
s.queryLogsMutex.Lock()
defer s.queryLogsMutex.Unlock()
// 插入到列表开头
s.queryLogs = append([]QueryLog{log}, s.queryLogs...)
// 限制日志数量
if len(s.queryLogs) > s.maxQueryLogs {
s.queryLogs = s.queryLogs[:s.maxQueryLogs]
}
}
// GetStartTime 获取服务器启动时间
func (s *Server) GetStartTime() time.Time {
return s.startTime
}
// GetStats 获取DNS服务器统计信息
func (s *Server) GetStats() *Stats {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
// 复制查询类型统计
queryTypesCopy := make(map[string]int64)
for k, v := range s.stats.QueryTypes {
queryTypesCopy[k] = v
}
// 复制来源IP统计
sourceIPsCopy := make(map[string]bool)
for ip := range s.stats.SourceIPs {
sourceIPsCopy[ip] = true
}
// 返回统计信息的副本
return &Stats{
Queries: s.stats.Queries,
Blocked: s.stats.Blocked,
Allowed: s.stats.Allowed,
Errors: s.stats.Errors,
LastQuery: s.stats.LastQuery,
AvgResponseTime: s.stats.AvgResponseTime,
TotalResponseTime: s.stats.TotalResponseTime,
QueryTypes: queryTypesCopy,
SourceIPs: sourceIPsCopy,
CpuUsage: s.stats.CpuUsage,
DNSSECQueries: s.stats.DNSSECQueries,
DNSSECSuccess: s.stats.DNSSECSuccess,
DNSSECFailed: s.stats.DNSSECFailed,
DNSSECEnabled: s.stats.DNSSECEnabled,
}
}
// GetQueryLogs 获取查询日志
func (s *Server) GetQueryLogs(limit, offset int, sortField, sortDirection, resultFilter, searchTerm string) []QueryLog {
s.queryLogsMutex.RLock()
defer s.queryLogsMutex.RUnlock()
// 确保偏移量和限制值合理
if offset < 0 {
offset = 0
}
if limit <= 0 {
limit = 100 // 默认返回100条日志
}
// 创建日志副本用于过滤和排序
var logsCopy []QueryLog
// 先过滤日志
for _, log := range s.queryLogs {
// 应用结果过滤
if resultFilter != "" && log.Result != resultFilter {
continue
}
// 应用搜索过滤
if searchTerm != "" {
// 搜索域名或客户端IP
if !strings.Contains(log.Domain, searchTerm) && !strings.Contains(log.ClientIP, searchTerm) {
continue
}
}
logsCopy = append(logsCopy, log)
}
// 排序日志
if sortField != "" {
sort.Slice(logsCopy, func(i, j int) bool {
var a, b interface{}
switch sortField {
case "time":
a = logsCopy[i].Timestamp
b = logsCopy[j].Timestamp
case "clientIp":
a = logsCopy[i].ClientIP
b = logsCopy[j].ClientIP
case "domain":
a = logsCopy[i].Domain
b = logsCopy[j].Domain
case "responseTime":
a = logsCopy[i].ResponseTime
b = logsCopy[j].ResponseTime
case "blockRule":
a = logsCopy[i].BlockRule
b = logsCopy[j].BlockRule
default:
// 默认按时间排序
a = logsCopy[i].Timestamp
b = logsCopy[j].Timestamp
}
// 根据排序方向比较
if sortDirection == "asc" {
return compareValues(a, b) < 0
}
return compareValues(a, b) > 0
})
}
// 计算返回范围
start := offset
end := offset + limit
if end > len(logsCopy) {
end = len(logsCopy)
}
if start >= len(logsCopy) {
return []QueryLog{} // 没有数据,返回空切片
}
return logsCopy[start:end]
}
// compareValues 比较两个值
func compareValues(a, b interface{}) int {
switch v1 := a.(type) {
case time.Time:
v2 := b.(time.Time)
if v1.Before(v2) {
return -1
}
if v1.After(v2) {
return 1
}
return 0
case string:
v2 := b.(string)
if v1 < v2 {
return -1
}
if v1 > v2 {
return 1
}
return 0
case int64:
v2 := b.(int64)
if v1 < v2 {
return -1
}
if v1 > v2 {
return 1
}
return 0
default:
return 0
}
}
// GetQueryLogsCount 获取查询日志总数
func (s *Server) GetQueryLogsCount() int {
s.queryLogsMutex.RLock()
defer s.queryLogsMutex.RUnlock()
return len(s.queryLogs)
}
// GetQueryStats 获取查询统计信息
func (s *Server) GetQueryStats() map[string]interface{} {
s.statsMutex.Lock()
defer s.statsMutex.Unlock()
// 计算统计数据
return map[string]interface{}{
"totalQueries": s.stats.Queries,
"blockedQueries": s.stats.Blocked,
"allowedQueries": s.stats.Allowed,
"errorQueries": s.stats.Errors,
"avgResponseTime": s.stats.AvgResponseTime,
"activeIPs": len(s.stats.SourceIPs),
}
}
// GetTopBlockedDomains 获取TOP屏蔽域名列表
func (s *Server) GetTopBlockedDomains(limit int) []BlockedDomain {
s.blockedDomainsMutex.RLock()
defer s.blockedDomainsMutex.RUnlock()
// 转换为切片
domains := make([]BlockedDomain, 0, len(s.blockedDomains))
for _, entry := range s.blockedDomains {
domains = append(domains, *entry)
}
// 按计数排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].Count > domains[j].Count
})
// 返回限制数量
if len(domains) > limit {
return domains[:limit]
}
return domains
}
// GetTopResolvedDomains 获取TOP解析域名
func (s *Server) GetTopResolvedDomains(limit int) []BlockedDomain {
s.resolvedDomainsMutex.RLock()
defer s.resolvedDomainsMutex.RUnlock()
// 转换为切片
domains := make([]BlockedDomain, 0, len(s.resolvedDomains))
for _, entry := range s.resolvedDomains {
domains = append(domains, *entry)
}
// 按数量排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].Count > domains[j].Count
})
// 返回限制数量
if len(domains) > limit {
return domains[:limit]
}
return domains
}
// GetRecentBlockedDomains 获取最近屏蔽的域名列表
func (s *Server) GetRecentBlockedDomains(limit int) []BlockedDomain {
s.blockedDomainsMutex.RLock()
defer s.blockedDomainsMutex.RUnlock()
// 转换为切片
domains := make([]BlockedDomain, 0, len(s.blockedDomains))
for _, entry := range s.blockedDomains {
domains = append(domains, *entry)
}
// 按时间排序
sort.Slice(domains, func(i, j int) bool {
return domains[i].LastSeen.After(domains[j].LastSeen)
})
// 返回限制数量
if len(domains) > limit {
return domains[:limit]
}
return domains
}
// GetTopClients 获取TOP客户端列表
func (s *Server) GetTopClients(limit int) []ClientStats {
s.clientStatsMutex.RLock()
defer s.clientStatsMutex.RUnlock()
// 转换为切片
clients := make([]ClientStats, 0, len(s.clientStats))
for _, entry := range s.clientStats {
clients = append(clients, *entry)
}
// 按请求次数排序
sort.Slice(clients, func(i, j int) bool {
return clients[i].Count > clients[j].Count
})
// 返回限制数量
if len(clients) > limit {
return clients[:limit]
}
return clients
}
// GetHourlyStats 获取每小时统计数据
func (s *Server) GetHourlyStats() map[string]int64 {
s.hourlyStatsMutex.RLock()
defer s.hourlyStatsMutex.RUnlock()
// 返回副本
result := make(map[string]int64)
for k, v := range s.hourlyStats {
result[k] = v
}
return result
}
// GetDailyStats 获取每日统计数据
func (s *Server) GetDailyStats() map[string]int64 {
s.dailyStatsMutex.RLock()
defer s.dailyStatsMutex.RUnlock()
// 返回副本
result := make(map[string]int64)
for k, v := range s.dailyStats {
result[k] = v
}
return result
}
// GetMonthlyStats 获取每月统计数据
func (s *Server) GetMonthlyStats() map[string]int64 {
s.monthlyStatsMutex.RLock()
defer s.monthlyStatsMutex.RUnlock()
// 返回副本
result := make(map[string]int64)
for k, v := range s.monthlyStats {
result[k] = v
}
return result
}
// isPrivateIP 检测IP地址是否为内网IP
func isPrivateIP(ip string) bool {
// 解析IP地址
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return false
}
// 检查IPv4内网地址
if ipv4 := parsedIP.To4(); ipv4 != nil {
// 10.0.0.0/8
if ipv4[0] == 10 {
return true
}
// 172.16.0.0/12
if ipv4[0] == 172 && (ipv4[1] >= 16 && ipv4[1] <= 31) {
return true
}
// 192.168.0.0/16
if ipv4[0] == 192 && ipv4[1] == 168 {
return true
}
// 127.0.0.0/8 (localhost)
if ipv4[0] == 127 {
return true
}
// 169.254.0.0/16 (链路本地地址)
if ipv4[0] == 169 && ipv4[1] == 254 {
return true
}
return false
}
// 检查IPv6内网地址
// ::1/128 (localhost)
if parsedIP.IsLoopback() {
return true
}
// fc00::/7 (唯一本地地址)
if parsedIP[0]&0xfc == 0xfc {
return true
}
// fe80::/10 (链路本地地址)
if parsedIP[0]&0xfe == 0xfe && parsedIP[1]&0xc0 == 0x80 {
return true
}
return false
}
// loadStatsData 从文件加载统计数据
func (s *Server) loadStatsData() {
// 检查文件是否存在
data, err := ioutil.ReadFile("data/stats.json")
if err != nil {
if !os.IsNotExist(err) {
logger.Error("读取统计数据文件失败", "error", err)
}
return
}
var statsData StatsData
err = json.Unmarshal(data, &statsData)
if err != nil {
logger.Error("解析统计数据失败", "error", err)
return
}
// 恢复统计数据
s.statsMutex.Lock()
if statsData.Stats != nil {
s.stats = statsData.Stats
// 确保使用当前配置中的EnableDNSSEC值覆盖从文件加载的值
s.stats.DNSSECEnabled = s.config.EnableDNSSEC
}
s.statsMutex.Unlock()
s.blockedDomainsMutex.Lock()
if statsData.BlockedDomains != nil {
s.blockedDomains = statsData.BlockedDomains
}
s.blockedDomainsMutex.Unlock()
s.resolvedDomainsMutex.Lock()
if statsData.ResolvedDomains != nil {
s.resolvedDomains = statsData.ResolvedDomains
}
s.resolvedDomainsMutex.Unlock()
s.hourlyStatsMutex.Lock()
if statsData.HourlyStats != nil {
s.hourlyStats = statsData.HourlyStats
}
s.hourlyStatsMutex.Unlock()
s.dailyStatsMutex.Lock()
if statsData.DailyStats != nil {
s.dailyStats = statsData.DailyStats
}
s.dailyStatsMutex.Unlock()
s.monthlyStatsMutex.Lock()
if statsData.MonthlyStats != nil {
s.monthlyStats = statsData.MonthlyStats
}
s.monthlyStatsMutex.Unlock()
// 加载客户端统计数据
s.clientStatsMutex.Lock()
if statsData.ClientStats != nil {
s.clientStats = statsData.ClientStats
}
s.clientStatsMutex.Unlock()
logger.Info("统计数据加载成功")
// 加载查询日志
s.loadQueryLogs()
}
// loadQueryLogs 从文件加载查询日志
func (s *Server) loadQueryLogs() {
// 获取绝对路径
statsFilePath, err := filepath.Abs("data/stats.json")
if err != nil {
logger.Error("获取统计文件绝对路径失败", "path", "data/stats.json", "error", err)
return
}
// 构建查询日志文件路径
queryLogPath := filepath.Join(filepath.Dir(statsFilePath), "querylog.json")
// 检查文件是否存在
if _, err := os.Stat(queryLogPath); os.IsNotExist(err) {
logger.Info("查询日志文件不存在,将使用空列表", "file", queryLogPath)
return
}
// 读取文件内容
data, err := ioutil.ReadFile(queryLogPath)
if err != nil {
logger.Error("读取查询日志文件失败", "error", err)
return
}
// 解析数据
var logs []QueryLog
err = json.Unmarshal(data, &logs)
if err != nil {
logger.Error("解析查询日志失败", "error", err)
return
}
// 更新查询日志
s.queryLogsMutex.Lock()
s.queryLogs = logs
// 确保日志数量不超过限制
if len(s.queryLogs) > s.maxQueryLogs {
s.queryLogs = s.queryLogs[:s.maxQueryLogs]
}
s.queryLogsMutex.Unlock()
logger.Info("查询日志加载成功", "count", len(logs))
}
// saveStatsData 保存统计数据到文件
func (s *Server) saveStatsData() {
// 获取绝对路径以避免工作目录问题
statsFilePath, err := filepath.Abs("data/stats.json")
if err != nil {
logger.Error("获取统计文件绝对路径失败", "path", "data/stats.json", "error", err)
return
}
// 创建数据目录
statsDir := filepath.Dir(statsFilePath)
err = os.MkdirAll(statsDir, 0755)
if err != nil {
logger.Error("创建统计数据目录失败", "dir", statsDir, "error", err)
return
}
// 收集所有统计数据
statsData := &StatsData{
Stats: s.GetStats(),
LastSaved: time.Now(),
}
// 复制域名数据
s.blockedDomainsMutex.RLock()
statsData.BlockedDomains = make(map[string]*BlockedDomain)
for k, v := range s.blockedDomains {
statsData.BlockedDomains[k] = v
}
s.blockedDomainsMutex.RUnlock()
s.resolvedDomainsMutex.RLock()
statsData.ResolvedDomains = make(map[string]*BlockedDomain)
for k, v := range s.resolvedDomains {
statsData.ResolvedDomains[k] = v
}
s.resolvedDomainsMutex.RUnlock()
s.hourlyStatsMutex.RLock()
statsData.HourlyStats = make(map[string]int64)
for k, v := range s.hourlyStats {
statsData.HourlyStats[k] = v
}
s.hourlyStatsMutex.RUnlock()
s.dailyStatsMutex.RLock()
statsData.DailyStats = make(map[string]int64)
for k, v := range s.dailyStats {
statsData.DailyStats[k] = v
}
s.dailyStatsMutex.RUnlock()
s.monthlyStatsMutex.RLock()
statsData.MonthlyStats = make(map[string]int64)
for k, v := range s.monthlyStats {
statsData.MonthlyStats[k] = v
}
s.monthlyStatsMutex.RUnlock()
// 复制客户端统计数据
s.clientStatsMutex.RLock()
statsData.ClientStats = make(map[string]*ClientStats)
for k, v := range s.clientStats {
statsData.ClientStats[k] = v
}
s.clientStatsMutex.RUnlock()
// 序列化数据
jsonData, err := json.MarshalIndent(statsData, "", " ")
if err != nil {
logger.Error("序列化统计数据失败", "error", err)
return
}
// 写入文件
err = os.WriteFile(statsFilePath, jsonData, 0644)
if err != nil {
logger.Error("保存统计数据到文件失败", "file", statsFilePath, "error", err)
return
}
logger.Info("统计数据保存成功", "file", statsFilePath)
// 保存查询日志到文件
s.saveQueryLogs(statsDir)
}
// saveQueryLogs 保存查询日志到文件
func (s *Server) saveQueryLogs(dataDir string) {
// 构建查询日志文件路径
queryLogPath := filepath.Join(dataDir, "querylog.json")
// 获取查询日志数据
s.queryLogsMutex.RLock()
logsCopy := make([]QueryLog, len(s.queryLogs))
copy(logsCopy, s.queryLogs)
s.queryLogsMutex.RUnlock()
// 序列化数据
jsonData, err := json.MarshalIndent(logsCopy, "", " ")
if err != nil {
logger.Error("序列化查询日志失败", "error", err)
return
}
// 写入文件
err = os.WriteFile(queryLogPath, jsonData, 0644)
if err != nil {
logger.Error("保存查询日志到文件失败", "file", queryLogPath, "error", err)
return
}
logger.Info("查询日志保存成功", "file", queryLogPath)
}
// startCpuUsageMonitor 启动CPU使用率监控
func (s *Server) startCpuUsageMonitor() {
ticker := time.NewTicker(time.Second * 5) // 每5秒更新一次CPU使用率
defer ticker.Stop()
// 初始化
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
// 存储上一次的CPU时间统计
var prevIdle, prevTotal uint64
for {
select {
case <-ticker.C:
// 获取真实的系统级CPU使用率
cpuUsage, err := getSystemCpuUsage(&prevIdle, &prevTotal)
if err != nil {
// 如果获取失败,使用默认值
cpuUsage = 0.0
logger.Error("获取系统CPU使用率失败", "error", err)
}
s.updateStats(func(stats *Stats) {
stats.CpuUsage = cpuUsage
})
case <-s.ctx.Done():
return
}
}
}
// getSystemCpuUsage 获取系统CPU使用率
func getSystemCpuUsage(prevIdle, prevTotal *uint64) (float64, error) {
// 读取/proc/stat文件获取CPU统计信息
file, err := os.Open("/proc/stat")
if err != nil {
return 0, err
}
defer file.Close()
var cpuUser, cpuNice, cpuSystem, cpuIdle, cpuIowait, cpuIrq, cpuSoftirq, cpuSteal uint64
_, err = fmt.Fscanf(file, "cpu %d %d %d %d %d %d %d %d",
&cpuUser, &cpuNice, &cpuSystem, &cpuIdle, &cpuIowait, &cpuIrq, &cpuSoftirq, &cpuSteal)
if err != nil {
return 0, err
}
// 计算总的CPU时间
total := cpuUser + cpuNice + cpuSystem + cpuIdle + cpuIowait + cpuIrq + cpuSoftirq + cpuSteal
idle := cpuIdle + cpuIowait
// 第一次调用时,只初始化值,不计算使用率
if *prevTotal == 0 || *prevIdle == 0 {
*prevIdle = idle
*prevTotal = total
return 0, nil
}
// 计算CPU使用率
idleDelta := idle - *prevIdle
totalDelta := total - *prevTotal
utilization := float64(totalDelta-idleDelta) / float64(totalDelta) * 100
// 更新上一次的值
*prevIdle = idle
*prevTotal = total
return utilization, nil
}
// startAutoSave 启动自动保存功能
func (s *Server) startAutoSave() {
if s.config.SaveInterval <= 0 {
return
}
// 初始化定时器
s.saveTicker = time.NewTicker(time.Duration(s.config.SaveInterval) * time.Second)
defer s.saveTicker.Stop()
logger.Info("启动统计数据自动保存功能", "interval", s.config.SaveInterval, "file", "data/stats.json")
// 定期保存数据
for {
select {
case <-s.saveTicker.C:
s.saveStatsData()
case <-s.saveDone:
return
}
}
}