优化修复

This commit is contained in:
Alex Yang
2026-01-03 01:11:42 +08:00
parent 1dd1f15788
commit f247eaeaa8
16 changed files with 1288 additions and 315 deletions

View File

@@ -48,8 +48,6 @@ type ClientStats struct {
LastSeen time.Time
}
// DNSAnswer DNS解析记录
type DNSAnswer struct {
Type string `json:"type"` // 记录类型
@@ -131,8 +129,6 @@ type Server struct {
stopped bool // 服务器是否已经停止
stoppedMutex sync.Mutex // 保护stopped标志的互斥锁
// DNS查询缓存
DnsCache *DNSCache // DNS响应缓存
@@ -205,7 +201,7 @@ func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shie
maxQueryLogs: 10000, // 最大保存10000条日志
saveDone: make(chan struct{}),
stopped: false, // 初始化为未停止状态
// DNS查询缓存初始化
DnsCache: NewDNSCache(cacheTTL),
// 初始化域名DNSSEC状态映射表
@@ -258,8 +254,6 @@ func (s *Server) Start() error {
// 启动自动保存功能
go s.startAutoSave()
// 启动UDP服务
go func() {
logger.Info(fmt.Sprintf("DNS UDP服务器启动监听端口: %d", s.config.Port))
@@ -856,6 +850,29 @@ func mergeResponses(responses []*dns.Msg) *dns.Msg {
mergedResponse.Ns = []dns.RR{}
mergedResponse.Extra = []dns.RR{}
// 重置Rcode为成功除非所有响应都是NXDOMAIN
mergedResponse.Rcode = dns.RcodeSuccess
// 检查是否所有响应都是NXDOMAIN
allNXDOMAIN := true
// 收集所有成功响应的记录
for _, resp := range responses {
if resp == nil {
continue
}
// 如果有任何响应是成功的就不是allNXDOMAIN
if resp.Rcode == dns.RcodeSuccess {
allNXDOMAIN = false
}
}
// 如果所有响应都是NXDOMAIN设置合并响应为NXDOMAIN
if allNXDOMAIN {
mergedResponse.Rcode = dns.RcodeNameError
}
// 使用map存储唯一记录选择最长TTL
// 预分配map容量减少扩容开销
answerMap := make(map[recordKey]dns.RR, len(responses[0].Answer)*len(responses))
@@ -867,46 +884,51 @@ func mergeResponses(responses []*dns.Msg) *dns.Msg {
continue
}
// 合并Answer部分
for _, rr := range resp.Answer {
key := getRecordKey(rr)
if existing, exists := answerMap[key]; exists {
// 如果存在相同记录选择TTL更长的
if rr.Header().Ttl > existing.Header().Ttl {
// 合并与最终Rcode匹配的响应记录
if (mergedResponse.Rcode == dns.RcodeSuccess && resp.Rcode == dns.RcodeSuccess) ||
(mergedResponse.Rcode == dns.RcodeNameError && resp.Rcode == dns.RcodeNameError) {
// 合并Answer部分
for _, rr := range resp.Answer {
key := getRecordKey(rr)
if existing, exists := answerMap[key]; exists {
// 如果存在相同记录选择TTL更长的
if rr.Header().Ttl > existing.Header().Ttl {
answerMap[key] = rr
}
} else {
answerMap[key] = rr
}
} else {
answerMap[key] = rr
}
}
// 合并Ns部分
for _, rr := range resp.Ns {
key := getRecordKey(rr)
if existing, exists := nsMap[key]; exists {
// 如果存在相同记录选择TTL更长的
if rr.Header().Ttl > existing.Header().Ttl {
// 合并Ns部分
for _, rr := range resp.Ns {
key := getRecordKey(rr)
if existing, exists := nsMap[key]; exists {
// 如果存在相同记录选择TTL更长的
if rr.Header().Ttl > existing.Header().Ttl {
nsMap[key] = rr
}
} else {
nsMap[key] = rr
}
} else {
nsMap[key] = rr
}
}
// 合并Extra部分
for _, rr := range resp.Extra {
// 跳过OPT记录避免重复
if rr.Header().Rrtype == dns.TypeOPT {
continue
}
key := getRecordKey(rr)
if existing, exists := extraMap[key]; exists {
// 如果存在相同记录选择TTL更长的
if rr.Header().Ttl > existing.Header().Ttl {
// 合并Extra部分
for _, rr := range resp.Extra {
// 跳过OPT记录避免重复
if rr.Header().Rrtype == dns.TypeOPT {
continue
}
key := getRecordKey(rr)
if existing, exists := extraMap[key]; exists {
// 如果存在相同记录选择TTL更长的
if rr.Header().Ttl > existing.Header().Ttl {
extraMap[key] = rr
}
} else {
extraMap[key] = rr
}
} else {
extraMap[key] = rr
}
}
}
@@ -1036,14 +1058,21 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
responses := make(chan serverResponse, len(selectedUpstreamDNS))
var wg sync.WaitGroup
// 向所有上游服务器并行发送请求
// 向所有上游服务器并行发送请求,每个请求带有超时
for _, upstream := range selectedUpstreamDNS {
wg.Add(1)
go func(server string) {
defer wg.Done()
// 创建带有超时的resolver
client := &dns.Client{
Net: s.resolver.Net,
UDPSize: s.resolver.UDPSize,
Timeout: defaultTimeout,
}
// 发送请求并获取响应,确保服务器地址包含端口号
response, rtt, err := s.resolver.Exchange(r, normalizeDNSServerAddress(server))
response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(server))
responses <- serverResponse{response, rtt, server, err}
}(upstream)
}
@@ -1054,8 +1083,9 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
close(responses)
}()
// 收集所有有效响应
var validResponses []*dns.Msg
// 收集成功响应和NXDOMAIN响应分开
var successResponses []*dns.Msg
var nxdomainResponses []*dns.Msg
var totalRtt time.Duration
var responseCount int
@@ -1087,9 +1117,9 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
}
}
// 收集有效响应
if resp.response.Rcode == dns.RcodeSuccess || resp.response.Rcode == dns.RcodeNameError {
validResponses = append(validResponses, resp.response)
// 收集响应按Rcode分类
if resp.response.Rcode == dns.RcodeSuccess {
successResponses = append(successResponses, resp.response)
totalRtt += resp.rtt
responseCount++
@@ -1097,6 +1127,8 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
if usedDNSServer == "" {
usedDNSServer = resp.server
}
} else if resp.response.Rcode == dns.RcodeNameError {
nxdomainResponses = append(nxdomainResponses, resp.response)
} else {
// 更新备选响应,确保总有一个可用的响应
if resp.response != nil {
@@ -1114,6 +1146,14 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
}
}
// 合并响应优先使用成功响应只有当没有成功响应时才使用NXDOMAIN响应
var validResponses []*dns.Msg
if len(successResponses) > 0 {
validResponses = successResponses
} else {
validResponses = nxdomainResponses
}
// 合并所有有效响应
if len(validResponses) > 0 {
bestResponse = mergeResponses(validResponses)
@@ -1121,11 +1161,14 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
bestRtt = totalRtt / time.Duration(responseCount)
}
hasBestResponse = true
logger.Debug("合并所有响应返回", "domain", domain, "responseCount", len(validResponses))
// 设置日志的type字段
logType := "success"
if len(successResponses) == 0 {
logType = "nxdomain"
}
logger.Debug("合并所有响应返回", "domain", domain, "responseCount", len(validResponses), "type", logType)
}
case "fastest-ip":
// 最快的IP地址模式 - 使用TCP连接速度测量选择最快服务器
// 1. 选择最快的服务器
@@ -1279,7 +1322,8 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
var fastestServer string
var fastestDnssecServer string
var fastestHasDnssec bool
var validResponses []*dns.Msg
var successResponses []*dns.Msg
var nxdomainResponses []*dns.Msg
// 等待所有请求完成或超时
timer := time.NewTimer(defaultTimeout)
@@ -1301,30 +1345,8 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(resp.response)
// 如果启用了DNSSEC且响应包含DNSSEC记录验证DNSSEC签名
// 但如果域名匹配不验证DNSSEC的模式则跳过验证
if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(resp.response)
// 设置AD标志Authenticated Data
resp.response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECQueries++
stats.DNSSECFailed++
})
}
} else if noDNSSEC {
// 对于不验证DNSSEC的域名始终设置AD标志为false
// 对于不验证DNSSEC的域名始终设置AD标志为false
if noDNSSEC {
resp.response.AuthenticatedData = false
}
@@ -1339,8 +1361,12 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
// 如果响应成功或为NXDOMAIN
if resp.response.Rcode == dns.RcodeSuccess || resp.response.Rcode == dns.RcodeNameError {
// 添加到有效响应列表,用于后续合并
validResponses = append(validResponses, resp.response)
// 按Rcode分类添加到不同列表
if resp.response.Rcode == dns.RcodeSuccess {
successResponses = append(successResponses, resp.response)
} else {
nxdomainResponses = append(nxdomainResponses, resp.response)
}
// 快速返回逻辑:找到第一个有效响应或更快的响应
if resp.response.Rcode == dns.RcodeSuccess {
@@ -1493,6 +1519,14 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
}
doneProcessing:
// 合并响应优先使用成功响应只有当没有成功响应时才使用NXDOMAIN响应
var validResponses []*dns.Msg
if len(successResponses) > 0 {
validResponses = successResponses
} else {
validResponses = nxdomainResponses
}
// 合并所有有效响应,用于缓存
if len(validResponses) > 1 {
mergedResponse := mergeResponses(validResponses)
@@ -2255,7 +2289,6 @@ func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime in
log := QueryLog{
Timestamp: time.Now(),
ClientIP: clientIP,
Location: "", // 客户端IP地理位置由前端处理
Domain: domain,
QueryType: queryType,
ResponseTime: responseTime,
@@ -2644,8 +2677,6 @@ func isPrivateIP(ip string) bool {
return false
}
// loadStatsData 从文件加载统计数据
func (s *Server) loadStatsData() {
// 检查文件是否存在
@@ -2908,10 +2939,6 @@ func (s *Server) startCpuUsageMonitor() {
}
}
// getSystemCpuUsage 获取系统CPU使用率
func getSystemCpuUsage(prevIdle, prevTotal *uint64) (float64, error) {
// 读取/proc/stat文件获取CPU统计信息