优化请求模式设置

This commit is contained in:
Alex Yang
2025-12-17 22:43:31 +08:00
parent 5d0fb6d4fe
commit 0f0aa76662
33 changed files with 692285 additions and 24 deletions

View File

@@ -73,6 +73,15 @@ type StatsData struct {
LastSaved time.Time `json:"lastSaved"`
}
// ServerStats 服务器统计信息
type ServerStats struct {
SuccessCount int64 // 成功查询次数
FailureCount int64 // 失败查询次数
LastResponse time.Time // 最后响应时间
ResponseTime time.Duration // 平均响应时间
ConnectionSpeed time.Duration // TCP连接速度
}
// Server DNS服务器
type Server struct {
config *config.DNSConfig
@@ -116,6 +125,10 @@ type Server struct {
// 域名DNSSEC状态映射表
domainDNSSECStatus map[string]bool // 域名到DNSSEC状态的映射
// 上游服务器状态跟踪
serverStats map[string]*ServerStats // 服务器地址到状态的映射
serverStatsMutex sync.RWMutex // 保护服务器状态的互斥锁
}
// Stats DNS服务器统计信息
@@ -187,6 +200,8 @@ func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shie
dnsCache: NewDNSCache(cacheTTL),
// 初始化域名DNSSEC状态映射表
domainDNSSECStatus: make(map[string]bool),
// 初始化服务器状态跟踪
serverStats: make(map[string]*ServerStats),
}
// 加载已保存的统计数据
@@ -690,21 +705,35 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
// 根据查询模式处理请求
switch s.config.QueryMode {
case "parallel":
// 并行请求模式
// 并行请求模式 - 优化版:添加超时处理和快速响应返回
responses := make(chan serverResponse, len(s.config.UpstreamDNS))
var wg sync.WaitGroup
// 超时上下文
timeoutCtx, cancel := context.WithTimeout(s.ctx, time.Duration(s.config.Timeout)*time.Millisecond)
defer cancel()
// 向所有上游服务器并行发送请求
for _, upstream := range s.config.UpstreamDNS {
wg.Add(1)
go func(server string) {
defer wg.Done()
// 发送请求并获取响应
response, rtt, err := s.resolver.Exchange(r, server)
responses <- serverResponse{response, rtt, server, err}
select {
case responses <- serverResponse{response, rtt, server, err}:
// 成功发送响应
case <-timeoutCtx.Done():
// 超时,忽略此响应
logger.Debug("并行请求超时", "server", server, "domain", domain)
return
}
}(upstream)
}
// 等待所有请求完成
// 等待所有请求完成或超时
go func() {
wg.Wait()
close(responses)
@@ -713,6 +742,9 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
// 处理所有响应
for resp := range responses {
if resp.error == nil && resp.response != nil {
// 更新服务器统计信息
s.updateServerStats(resp.server, true, resp.rtt)
// 设置递归可用标志
resp.response.RecursionAvailable = true
@@ -765,14 +797,22 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
hasBackup = true
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(resp.server, false, 0)
}
}
case "loadbalance":
// 负载均衡模式 - 目前使用简单的轮询,后续可以扩展为更复杂的算法
for _, upstream := range s.config.UpstreamDNS {
response, rtt, err := s.resolver.Exchange(r, upstream)
// 负载均衡模式 - 使用加权随机选择算法
// 1. 选择一个加权随机服务器
selectedServer := s.selectWeightedRandomServer(s.config.UpstreamDNS)
if selectedServer != "" {
response, rtt, err := s.resolver.Exchange(r, selectedServer)
if err == nil && response != nil {
// 更新服务器统计信息
s.updateServerStats(selectedServer, true, rtt)
// 设置递归可用标志
response.RecursionAvailable = true
@@ -810,14 +850,13 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
bestRtt = rtt
hasBestResponse = true
hasDNSSECResponse = true
logger.Debug("找到带DNSSEC的最佳响应", "domain", domain, "server", upstream, "rtt", rtt)
break // 找到带DNSSEC的响应立即返回
} else if !hasBestResponse {
// 没有带DNSSEC的响应时保存第一个成功响应
logger.Debug("找到带DNSSEC的最佳响应", "domain", domain, "server", selectedServer, "rtt", rtt)
} else {
// 没有带DNSSEC的响应时保存成功响应
bestResponse = response
bestRtt = rtt
hasBestResponse = true
logger.Debug("找到最佳响应", "domain", domain, "server", upstream, "rtt", rtt)
logger.Debug("找到最佳响应", "domain", domain, "server", selectedServer, "rtt", rtt)
}
// 保存为备选响应
if !hasBackup {
@@ -826,14 +865,22 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
hasBackup = true
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(selectedServer, false, 0)
}
}
case "fastest-ip":
// 最快的IP地址模式 - 目前使用简单的顺序请求后续可以扩展为测量TCP连接速度
for _, upstream := range s.config.UpstreamDNS {
response, rtt, err := s.resolver.Exchange(r, upstream)
// 最快的IP地址模式 - 使用TCP连接速度测量选择最快服务器
// 1. 选择最快的服务器
fastestServer := s.selectFastestServer(s.config.UpstreamDNS)
if fastestServer != "" {
response, rtt, err := s.resolver.Exchange(r, fastestServer)
if err == nil && response != nil {
// 更新服务器统计信息
s.updateServerStats(fastestServer, true, rtt)
// 设置递归可用标志
response.RecursionAvailable = true
@@ -871,15 +918,13 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
bestRtt = rtt
hasBestResponse = true
hasDNSSECResponse = true
logger.Debug("找到带DNSSEC的最佳响应", "domain", domain, "server", upstream, "rtt", rtt)
break // 找到带DNSSEC的响应立即返回
} else if !hasBestResponse {
// 没有带DNSSEC的响应时保存第一个成功响应
logger.Debug("找到带DNSSEC的最佳响应", "domain", domain, "server", fastestServer, "rtt", rtt)
} else {
// 没有带DNSSEC的响应时保存成功响应
bestResponse = response
bestRtt = rtt
hasBestResponse = true
logger.Debug("找到最佳响应", "domain", domain, "server", upstream, "rtt", rtt)
break // 找到响应,立即返回
logger.Debug("找到最佳响应", "domain", domain, "server", fastestServer, "rtt", rtt)
}
// 保存为备选响应
if !hasBackup {
@@ -888,6 +933,9 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
hasBackup = true
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(fastestServer, false, 0)
}
}
@@ -983,21 +1031,35 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
// 根据查询模式处理DNSSEC服务器请求
switch s.config.QueryMode {
case "parallel":
// 并行请求模式
// 并行请求模式 - 优化版:添加超时处理和服务器统计
responses := make(chan serverResponse, len(dnssecServers))
var wg sync.WaitGroup
// 超时上下文
timeoutCtx, cancel := context.WithTimeout(s.ctx, time.Duration(s.config.Timeout)*time.Millisecond)
defer cancel()
// 向所有DNSSEC服务器并行发送请求
for _, dnssecServer := range dnssecServers {
wg.Add(1)
go func(server string) {
defer wg.Done()
// 发送请求并获取响应
response, rtt, err := s.resolver.Exchange(r, server)
responses <- serverResponse{response, rtt, server, err}
select {
case responses <- serverResponse{response, rtt, server, err}:
// 成功发送响应
case <-timeoutCtx.Done():
// 超时,忽略此响应
logger.Debug("DNSSEC并行请求超时", "server", server, "domain", domain)
return
}
}(dnssecServer)
}
// 等待所有请求完成
// 等待所有请求完成或超时
go func() {
wg.Wait()
close(responses)
@@ -1006,6 +1068,9 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
// 处理所有响应
for resp := range responses {
if resp.error == nil && resp.response != nil {
// 更新服务器统计信息
s.updateServerStats(resp.server, true, resp.rtt)
// 设置递归可用标志
resp.response.RecursionAvailable = true
@@ -1051,6 +1116,131 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
hasBackup = true
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(resp.server, false, 0)
}
}
case "loadbalance":
// 负载均衡模式 - 使用加权随机选择算法
// 1. 选择一个加权随机DNSSEC服务器
selectedServer := s.selectWeightedRandomServer(dnssecServers)
if selectedServer != "" {
response, rtt, err := s.resolver.Exchange(r, selectedServer)
if err == nil && response != nil {
// 更新服务器统计信息
s.updateServerStats(selectedServer, true, rtt)
// 设置递归可用标志
response.RecursionAvailable = true
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(response)
if response.Rcode == dns.RcodeSuccess {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(response)
// 设置AD标志Authenticated Data
response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECFailed++
})
}
// 优先使用DNSSEC专用服务器的响应尤其是带有DNSSEC记录的
if containsDNSSEC {
// 即使之前有最佳响应也优先使用DNSSEC专用服务器的DNSSEC响应
bestResponse = response
bestRtt = rtt
hasBestResponse = true
hasDNSSECResponse = true
logger.Debug("DNSSEC专用服务器返回带DNSSEC的响应优先使用", "domain", domain, "server", selectedServer, "rtt", rtt)
}
// 注意如果DNSSEC专用服务器返回的响应不包含DNSSEC记录
// 我们不会覆盖之前从upstreamDNS获取的响应
// 这符合"本地解析指的是直接使用上游服务器upstreamDNS进行解析, 而不是dnssecUpstreamDNS"的要求
// 更新备选响应
if !hasBackup {
backupResponse = response
backupRtt = rtt
hasBackup = true
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(selectedServer, false, 0)
}
}
case "fastest-ip":
// 最快的IP地址模式 - 使用TCP连接速度测量选择最快DNSSEC服务器
// 1. 选择最快的DNSSEC服务器
fastestServer := s.selectFastestServer(dnssecServers)
if fastestServer != "" {
response, rtt, err := s.resolver.Exchange(r, fastestServer)
if err == nil && response != nil {
// 更新服务器统计信息
s.updateServerStats(fastestServer, true, rtt)
// 设置递归可用标志
response.RecursionAvailable = true
// 检查是否包含DNSSEC记录
containsDNSSEC := s.hasDNSSECRecords(response)
if response.Rcode == dns.RcodeSuccess {
// 验证DNSSEC记录
signatureValid := s.verifyDNSSEC(response)
// 设置AD标志Authenticated Data
response.AuthenticatedData = signatureValid
if signatureValid {
// 更新DNSSEC验证成功计数
s.updateStats(func(stats *Stats) {
stats.DNSSECSuccess++
})
} else {
// 更新DNSSEC验证失败计数
s.updateStats(func(stats *Stats) {
stats.DNSSECFailed++
})
}
// 优先使用DNSSEC专用服务器的响应尤其是带有DNSSEC记录的
if containsDNSSEC {
// 即使之前有最佳响应也优先使用DNSSEC专用服务器的DNSSEC响应
bestResponse = response
bestRtt = rtt
hasBestResponse = true
hasDNSSECResponse = true
logger.Debug("DNSSEC专用服务器返回带DNSSEC的响应优先使用", "domain", domain, "server", fastestServer, "rtt", rtt)
}
// 注意如果DNSSEC专用服务器返回的响应不包含DNSSEC记录
// 我们不会覆盖之前从upstreamDNS获取的响应
// 这符合"本地解析指的是直接使用上游服务器upstreamDNS进行解析, 而不是dnssecUpstreamDNS"的要求
// 更新备选响应
if !hasBackup {
backupResponse = response
backupRtt = rtt
hasBackup = true
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(fastestServer, false, 0)
}
}
@@ -1059,6 +1249,9 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
for _, dnssecServer := range dnssecServers {
response, rtt, err := s.resolver.Exchange(r, dnssecServer)
if err == nil && response != nil {
// 更新服务器统计信息
s.updateServerStats(dnssecServer, true, rtt)
// 设置递归可用标志
response.RecursionAvailable = true
@@ -1105,6 +1298,9 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
hasBackup = true
}
}
} else {
// 更新服务器统计信息(失败)
s.updateServerStats(dnssecServer, false, 0)
}
}
}
@@ -1397,6 +1593,194 @@ func (s *Server) updateResolvedDomainStats(domain string) {
}
}
// getServerStats 获取服务器统计信息,如果不存在则创建
func (s *Server) getServerStats(server string) *ServerStats {
s.serverStatsMutex.RLock()
stats, exists := s.serverStats[server]
s.serverStatsMutex.RUnlock()
if !exists {
// 创建新的服务器统计信息
stats = &ServerStats{
SuccessCount: 0,
FailureCount: 0,
LastResponse: time.Now(),
ResponseTime: 0,
ConnectionSpeed: 0,
}
// 加锁更新服务器统计信息
s.serverStatsMutex.Lock()
s.serverStats[server] = stats
s.serverStatsMutex.Unlock()
}
return stats
}
// updateServerStats 更新服务器统计信息
func (s *Server) updateServerStats(server string, success bool, rtt time.Duration) {
stats := s.getServerStats(server)
s.serverStatsMutex.Lock()
defer s.serverStatsMutex.Unlock()
// 更新统计信息
stats.LastResponse = time.Now()
if success {
stats.SuccessCount++
// 更新平均响应时间(简单移动平均)
// 将所有值转换为纳秒进行计算然后再转换回Duration
if stats.SuccessCount == 1 {
// 第一次成功,直接使用当前响应时间
stats.ResponseTime = rtt
} else {
// 使用纳秒进行计算以避免类型不匹配
prevTotal := stats.ResponseTime.Nanoseconds() * (stats.SuccessCount - 1)
newTotal := prevTotal + rtt.Nanoseconds()
stats.ResponseTime = time.Duration(newTotal / stats.SuccessCount)
}
} else {
stats.FailureCount++
}
}
// selectWeightedRandomServer 加权随机选择服务器
func (s *Server) selectWeightedRandomServer(servers []string) string {
if len(servers) == 0 {
return ""
}
if len(servers) == 1 {
return servers[0]
}
// 计算每个服务器的权重
type serverWeight struct {
server string
weight int64
}
var totalWeight int64
weights := make([]serverWeight, 0, len(servers))
for _, server := range servers {
stats := s.getServerStats(server)
// 计算权重:成功次数 - 失败次数 * 2失败权重更高
// 确保权重至少为1
weight := stats.SuccessCount - stats.FailureCount*2
if weight < 1 {
weight = 1
}
weights = append(weights, serverWeight{server, weight})
totalWeight += weight
}
// 随机选择一个权重
random := time.Now().UnixNano() % totalWeight
if random < 0 {
random += totalWeight
}
// 选择对应的服务器
var currentWeight int64
for _, sw := range weights {
currentWeight += sw.weight
if random < currentWeight {
return sw.server
}
}
// 兜底返回第一个服务器
return servers[0]
}
// measureServerSpeed 测量服务器TCP连接速度
func (s *Server) measureServerSpeed(server string) time.Duration {
// 提取服务器地址和端口
addr := server
if !strings.Contains(server, ":") {
addr = server + ":53"
}
// 测量TCP连接时间
startTime := time.Now()
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
if err != nil {
// 连接失败,返回最大持续时间
return 2 * time.Second
}
defer conn.Close()
// 计算连接建立时间
connTime := time.Since(startTime)
// 更新服务器连接速度
stats := s.getServerStats(server)
s.serverStatsMutex.Lock()
// 使用指数移动平均更新连接速度
stats.ConnectionSpeed = (stats.ConnectionSpeed*3 + connTime) / 4
s.serverStatsMutex.Unlock()
return connTime
}
// selectFastestServer 选择连接速度最快的服务器
func (s *Server) selectFastestServer(servers []string) string {
if len(servers) == 0 {
return ""
}
if len(servers) == 1 {
return servers[0]
}
// 并行测量所有服务器的速度
type speedResult struct {
server string
speed time.Duration
}
results := make(chan speedResult, len(servers))
var wg sync.WaitGroup
for _, server := range servers {
wg.Add(1)
go func(srv string) {
defer wg.Done()
speed := s.measureServerSpeed(srv)
results <- speedResult{srv, speed}
}(server)
}
// 等待所有测量完成
go func() {
wg.Wait()
close(results)
}()
// 找出最快的服务器
var fastestServer string
var fastestSpeed time.Duration = 2 * time.Second
for result := range results {
if result.speed < fastestSpeed {
fastestSpeed = result.speed
fastestServer = result.server
}
}
// 如果没有找到最快服务器(理论上不会发生),返回第一个服务器
if fastestServer == "" {
fastestServer = servers[0]
}
return fastestServer
}
// updateStats 更新统计信息
func (s *Server) updateStats(update func(*Stats)) {
s.statsMutex.Lock()