优化修复
This commit is contained in:
12
dns/cache.go
12
dns/cache.go
@@ -9,9 +9,9 @@ import (
|
||||
|
||||
// DNSCacheItem 表示缓存中的DNS响应项
|
||||
type DNSCacheItem struct {
|
||||
Response *dns.Msg // DNS响应消息
|
||||
Expiry time.Time // 过期时间
|
||||
HasDNSSEC bool // 是否包含DNSSEC记录
|
||||
Response *dns.Msg // DNS响应消息
|
||||
Expiry time.Time // 过期时间
|
||||
HasDNSSEC bool // 是否包含DNSSEC记录
|
||||
}
|
||||
|
||||
// DNSCache DNS缓存结构
|
||||
@@ -21,7 +21,7 @@ type DNSCache struct {
|
||||
defaultTTL time.Duration // 默认缓存TTL
|
||||
maxSize int // 最大缓存条目数
|
||||
// 使用链表结构来跟踪缓存条目的访问顺序,用于LRU淘汰
|
||||
accessList []string // 记录访问顺序,最新访问的放在最后
|
||||
accessList []string // 记录访问顺序,最新访问的放在最后
|
||||
}
|
||||
|
||||
// NewDNSCache 创建新的DNS缓存实例
|
||||
@@ -109,8 +109,8 @@ func (c *DNSCache) Set(qName string, qType uint16, response *dns.Msg, ttl time.D
|
||||
|
||||
key := cacheKey(qName, qType)
|
||||
item := &DNSCacheItem{
|
||||
Response: response.Copy(), // 复制响应以避免外部修改
|
||||
Expiry: time.Now().Add(ttl),
|
||||
Response: response.Copy(), // 复制响应以避免外部修改
|
||||
Expiry: time.Now().Add(ttl),
|
||||
HasDNSSEC: hasDNSSECRecords(response), // 检查并设置DNSSEC标志
|
||||
}
|
||||
|
||||
|
||||
189
dns/server.go
189
dns/server.go
@@ -48,8 +48,6 @@ type ClientStats struct {
|
||||
LastSeen time.Time
|
||||
}
|
||||
|
||||
|
||||
|
||||
// DNSAnswer DNS解析记录
|
||||
type DNSAnswer struct {
|
||||
Type string `json:"type"` // 记录类型
|
||||
@@ -131,8 +129,6 @@ type Server struct {
|
||||
stopped bool // 服务器是否已经停止
|
||||
stoppedMutex sync.Mutex // 保护stopped标志的互斥锁
|
||||
|
||||
|
||||
|
||||
// DNS查询缓存
|
||||
DnsCache *DNSCache // DNS响应缓存
|
||||
|
||||
@@ -205,7 +201,7 @@ func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shie
|
||||
maxQueryLogs: 10000, // 最大保存10000条日志
|
||||
saveDone: make(chan struct{}),
|
||||
stopped: false, // 初始化为未停止状态
|
||||
|
||||
|
||||
// DNS查询缓存初始化
|
||||
DnsCache: NewDNSCache(cacheTTL),
|
||||
// 初始化域名DNSSEC状态映射表
|
||||
@@ -258,8 +254,6 @@ func (s *Server) Start() error {
|
||||
// 启动自动保存功能
|
||||
go s.startAutoSave()
|
||||
|
||||
|
||||
|
||||
// 启动UDP服务
|
||||
go func() {
|
||||
logger.Info(fmt.Sprintf("DNS UDP服务器启动,监听端口: %d", s.config.Port))
|
||||
@@ -856,6 +850,29 @@ func mergeResponses(responses []*dns.Msg) *dns.Msg {
|
||||
mergedResponse.Ns = []dns.RR{}
|
||||
mergedResponse.Extra = []dns.RR{}
|
||||
|
||||
// 重置Rcode为成功,除非所有响应都是NXDOMAIN
|
||||
mergedResponse.Rcode = dns.RcodeSuccess
|
||||
|
||||
// 检查是否所有响应都是NXDOMAIN
|
||||
allNXDOMAIN := true
|
||||
|
||||
// 收集所有成功响应的记录
|
||||
for _, resp := range responses {
|
||||
if resp == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 如果有任何响应是成功的,就不是allNXDOMAIN
|
||||
if resp.Rcode == dns.RcodeSuccess {
|
||||
allNXDOMAIN = false
|
||||
}
|
||||
}
|
||||
|
||||
// 如果所有响应都是NXDOMAIN,设置合并响应为NXDOMAIN
|
||||
if allNXDOMAIN {
|
||||
mergedResponse.Rcode = dns.RcodeNameError
|
||||
}
|
||||
|
||||
// 使用map存储唯一记录,选择最长TTL
|
||||
// 预分配map容量,减少扩容开销
|
||||
answerMap := make(map[recordKey]dns.RR, len(responses[0].Answer)*len(responses))
|
||||
@@ -867,46 +884,51 @@ func mergeResponses(responses []*dns.Msg) *dns.Msg {
|
||||
continue
|
||||
}
|
||||
|
||||
// 合并Answer部分
|
||||
for _, rr := range resp.Answer {
|
||||
key := getRecordKey(rr)
|
||||
if existing, exists := answerMap[key]; exists {
|
||||
// 如果存在相同记录,选择TTL更长的
|
||||
if rr.Header().Ttl > existing.Header().Ttl {
|
||||
// 只合并与最终Rcode匹配的响应记录
|
||||
if (mergedResponse.Rcode == dns.RcodeSuccess && resp.Rcode == dns.RcodeSuccess) ||
|
||||
(mergedResponse.Rcode == dns.RcodeNameError && resp.Rcode == dns.RcodeNameError) {
|
||||
|
||||
// 合并Answer部分
|
||||
for _, rr := range resp.Answer {
|
||||
key := getRecordKey(rr)
|
||||
if existing, exists := answerMap[key]; exists {
|
||||
// 如果存在相同记录,选择TTL更长的
|
||||
if rr.Header().Ttl > existing.Header().Ttl {
|
||||
answerMap[key] = rr
|
||||
}
|
||||
} else {
|
||||
answerMap[key] = rr
|
||||
}
|
||||
} else {
|
||||
answerMap[key] = rr
|
||||
}
|
||||
}
|
||||
|
||||
// 合并Ns部分
|
||||
for _, rr := range resp.Ns {
|
||||
key := getRecordKey(rr)
|
||||
if existing, exists := nsMap[key]; exists {
|
||||
// 如果存在相同记录,选择TTL更长的
|
||||
if rr.Header().Ttl > existing.Header().Ttl {
|
||||
// 合并Ns部分
|
||||
for _, rr := range resp.Ns {
|
||||
key := getRecordKey(rr)
|
||||
if existing, exists := nsMap[key]; exists {
|
||||
// 如果存在相同记录,选择TTL更长的
|
||||
if rr.Header().Ttl > existing.Header().Ttl {
|
||||
nsMap[key] = rr
|
||||
}
|
||||
} else {
|
||||
nsMap[key] = rr
|
||||
}
|
||||
} else {
|
||||
nsMap[key] = rr
|
||||
}
|
||||
}
|
||||
|
||||
// 合并Extra部分
|
||||
for _, rr := range resp.Extra {
|
||||
// 跳过OPT记录,避免重复
|
||||
if rr.Header().Rrtype == dns.TypeOPT {
|
||||
continue
|
||||
}
|
||||
key := getRecordKey(rr)
|
||||
if existing, exists := extraMap[key]; exists {
|
||||
// 如果存在相同记录,选择TTL更长的
|
||||
if rr.Header().Ttl > existing.Header().Ttl {
|
||||
// 合并Extra部分
|
||||
for _, rr := range resp.Extra {
|
||||
// 跳过OPT记录,避免重复
|
||||
if rr.Header().Rrtype == dns.TypeOPT {
|
||||
continue
|
||||
}
|
||||
key := getRecordKey(rr)
|
||||
if existing, exists := extraMap[key]; exists {
|
||||
// 如果存在相同记录,选择TTL更长的
|
||||
if rr.Header().Ttl > existing.Header().Ttl {
|
||||
extraMap[key] = rr
|
||||
}
|
||||
} else {
|
||||
extraMap[key] = rr
|
||||
}
|
||||
} else {
|
||||
extraMap[key] = rr
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1036,14 +1058,21 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
|
||||
responses := make(chan serverResponse, len(selectedUpstreamDNS))
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// 向所有上游服务器并行发送请求
|
||||
// 向所有上游服务器并行发送请求,每个请求带有超时
|
||||
for _, upstream := range selectedUpstreamDNS {
|
||||
wg.Add(1)
|
||||
go func(server string) {
|
||||
defer wg.Done()
|
||||
|
||||
// 创建带有超时的resolver
|
||||
client := &dns.Client{
|
||||
Net: s.resolver.Net,
|
||||
UDPSize: s.resolver.UDPSize,
|
||||
Timeout: defaultTimeout,
|
||||
}
|
||||
|
||||
// 发送请求并获取响应,确保服务器地址包含端口号
|
||||
response, rtt, err := s.resolver.Exchange(r, normalizeDNSServerAddress(server))
|
||||
response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(server))
|
||||
responses <- serverResponse{response, rtt, server, err}
|
||||
}(upstream)
|
||||
}
|
||||
@@ -1054,8 +1083,9 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
|
||||
close(responses)
|
||||
}()
|
||||
|
||||
// 收集所有有效响应
|
||||
var validResponses []*dns.Msg
|
||||
// 收集成功响应和NXDOMAIN响应分开
|
||||
var successResponses []*dns.Msg
|
||||
var nxdomainResponses []*dns.Msg
|
||||
var totalRtt time.Duration
|
||||
var responseCount int
|
||||
|
||||
@@ -1087,9 +1117,9 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
|
||||
}
|
||||
}
|
||||
|
||||
// 收集有效响应
|
||||
if resp.response.Rcode == dns.RcodeSuccess || resp.response.Rcode == dns.RcodeNameError {
|
||||
validResponses = append(validResponses, resp.response)
|
||||
// 收集响应,按Rcode分类
|
||||
if resp.response.Rcode == dns.RcodeSuccess {
|
||||
successResponses = append(successResponses, resp.response)
|
||||
totalRtt += resp.rtt
|
||||
responseCount++
|
||||
|
||||
@@ -1097,6 +1127,8 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
|
||||
if usedDNSServer == "" {
|
||||
usedDNSServer = resp.server
|
||||
}
|
||||
} else if resp.response.Rcode == dns.RcodeNameError {
|
||||
nxdomainResponses = append(nxdomainResponses, resp.response)
|
||||
} else {
|
||||
// 更新备选响应,确保总有一个可用的响应
|
||||
if resp.response != nil {
|
||||
@@ -1114,6 +1146,14 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
|
||||
}
|
||||
}
|
||||
|
||||
// 合并响应:优先使用成功响应,只有当没有成功响应时才使用NXDOMAIN响应
|
||||
var validResponses []*dns.Msg
|
||||
if len(successResponses) > 0 {
|
||||
validResponses = successResponses
|
||||
} else {
|
||||
validResponses = nxdomainResponses
|
||||
}
|
||||
|
||||
// 合并所有有效响应
|
||||
if len(validResponses) > 0 {
|
||||
bestResponse = mergeResponses(validResponses)
|
||||
@@ -1121,11 +1161,14 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
|
||||
bestRtt = totalRtt / time.Duration(responseCount)
|
||||
}
|
||||
hasBestResponse = true
|
||||
logger.Debug("合并所有响应返回", "domain", domain, "responseCount", len(validResponses))
|
||||
// 设置日志的type字段
|
||||
logType := "success"
|
||||
if len(successResponses) == 0 {
|
||||
logType = "nxdomain"
|
||||
}
|
||||
logger.Debug("合并所有响应返回", "domain", domain, "responseCount", len(validResponses), "type", logType)
|
||||
}
|
||||
|
||||
|
||||
|
||||
case "fastest-ip":
|
||||
// 最快的IP地址模式 - 使用TCP连接速度测量选择最快服务器
|
||||
// 1. 选择最快的服务器
|
||||
@@ -1279,7 +1322,8 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
|
||||
var fastestServer string
|
||||
var fastestDnssecServer string
|
||||
var fastestHasDnssec bool
|
||||
var validResponses []*dns.Msg
|
||||
var successResponses []*dns.Msg
|
||||
var nxdomainResponses []*dns.Msg
|
||||
|
||||
// 等待所有请求完成或超时
|
||||
timer := time.NewTimer(defaultTimeout)
|
||||
@@ -1301,30 +1345,8 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
|
||||
// 检查是否包含DNSSEC记录
|
||||
containsDNSSEC := s.hasDNSSECRecords(resp.response)
|
||||
|
||||
// 如果启用了DNSSEC且响应包含DNSSEC记录,验证DNSSEC签名
|
||||
// 但如果域名匹配不验证DNSSEC的模式,则跳过验证
|
||||
if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC {
|
||||
// 验证DNSSEC记录
|
||||
signatureValid := s.verifyDNSSEC(resp.response)
|
||||
|
||||
// 设置AD标志(Authenticated Data)
|
||||
resp.response.AuthenticatedData = signatureValid
|
||||
|
||||
if signatureValid {
|
||||
// 更新DNSSEC验证成功计数
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.DNSSECQueries++
|
||||
stats.DNSSECSuccess++
|
||||
})
|
||||
} else {
|
||||
// 更新DNSSEC验证失败计数
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.DNSSECQueries++
|
||||
stats.DNSSECFailed++
|
||||
})
|
||||
}
|
||||
} else if noDNSSEC {
|
||||
// 对于不验证DNSSEC的域名,始终设置AD标志为false
|
||||
// 对于不验证DNSSEC的域名,始终设置AD标志为false
|
||||
if noDNSSEC {
|
||||
resp.response.AuthenticatedData = false
|
||||
}
|
||||
|
||||
@@ -1339,8 +1361,12 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
|
||||
|
||||
// 如果响应成功或为NXDOMAIN
|
||||
if resp.response.Rcode == dns.RcodeSuccess || resp.response.Rcode == dns.RcodeNameError {
|
||||
// 添加到有效响应列表,用于后续合并
|
||||
validResponses = append(validResponses, resp.response)
|
||||
// 按Rcode分类添加到不同列表
|
||||
if resp.response.Rcode == dns.RcodeSuccess {
|
||||
successResponses = append(successResponses, resp.response)
|
||||
} else {
|
||||
nxdomainResponses = append(nxdomainResponses, resp.response)
|
||||
}
|
||||
|
||||
// 快速返回逻辑:找到第一个有效响应或更快的响应
|
||||
if resp.response.Rcode == dns.RcodeSuccess {
|
||||
@@ -1493,6 +1519,14 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
|
||||
}
|
||||
|
||||
doneProcessing:
|
||||
// 合并响应,优先使用成功响应,只有当没有成功响应时才使用NXDOMAIN响应
|
||||
var validResponses []*dns.Msg
|
||||
if len(successResponses) > 0 {
|
||||
validResponses = successResponses
|
||||
} else {
|
||||
validResponses = nxdomainResponses
|
||||
}
|
||||
|
||||
// 合并所有有效响应,用于缓存
|
||||
if len(validResponses) > 1 {
|
||||
mergedResponse := mergeResponses(validResponses)
|
||||
@@ -2255,7 +2289,6 @@ func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime in
|
||||
log := QueryLog{
|
||||
Timestamp: time.Now(),
|
||||
ClientIP: clientIP,
|
||||
Location: "", // 客户端IP地理位置由前端处理
|
||||
Domain: domain,
|
||||
QueryType: queryType,
|
||||
ResponseTime: responseTime,
|
||||
@@ -2644,8 +2677,6 @@ func isPrivateIP(ip string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
|
||||
// loadStatsData 从文件加载统计数据
|
||||
func (s *Server) loadStatsData() {
|
||||
// 检查文件是否存在
|
||||
@@ -2908,10 +2939,6 @@ func (s *Server) startCpuUsageMonitor() {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
// getSystemCpuUsage 获取系统CPU使用率
|
||||
func getSystemCpuUsage(prevIdle, prevTotal *uint64) (float64, error) {
|
||||
// 读取/proc/stat文件获取CPU统计信息
|
||||
|
||||
Reference in New Issue
Block a user