优化修复
This commit is contained in:
85
dns/cache.go
85
dns/cache.go
@@ -19,6 +19,9 @@ type DNSCache struct {
|
||||
cache map[string]*DNSCacheItem // 缓存映射表
|
||||
mutex sync.RWMutex // 读写锁,保护缓存
|
||||
defaultTTL time.Duration // 默认缓存TTL
|
||||
maxSize int // 最大缓存条目数
|
||||
// 使用链表结构来跟踪缓存条目的访问顺序,用于LRU淘汰
|
||||
accessList []string // 记录访问顺序,最新访问的放在最后
|
||||
}
|
||||
|
||||
// NewDNSCache 创建新的DNS缓存实例
|
||||
@@ -26,6 +29,8 @@ func NewDNSCache(defaultTTL time.Duration) *DNSCache {
|
||||
cache := &DNSCache{
|
||||
cache: make(map[string]*DNSCacheItem),
|
||||
defaultTTL: defaultTTL,
|
||||
maxSize: 10000, // 默认最大缓存10000条记录
|
||||
accessList: make([]string, 0, 10000),
|
||||
}
|
||||
|
||||
// 启动缓存清理协程
|
||||
@@ -110,32 +115,70 @@ func (c *DNSCache) Set(qName string, qType uint16, response *dns.Msg, ttl time.D
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// 如果条目已存在,先从访问列表中移除
|
||||
for i, k := range c.accessList {
|
||||
if k == key {
|
||||
// 移除旧位置
|
||||
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 将新条目添加到访问列表末尾
|
||||
c.accessList = append(c.accessList, key)
|
||||
c.cache[key] = item
|
||||
c.mutex.Unlock()
|
||||
|
||||
// 检查是否超过最大大小限制,如果超过则移除最久未使用的条目
|
||||
if len(c.cache) > c.maxSize {
|
||||
// 最久未使用的条目是访问列表的第一个
|
||||
oldestKey := c.accessList[0]
|
||||
// 从缓存和访问列表中移除
|
||||
delete(c.cache, oldestKey)
|
||||
c.accessList = c.accessList[1:]
|
||||
}
|
||||
}
|
||||
|
||||
// Get 获取缓存项
|
||||
func (c *DNSCache) Get(qName string, qType uint16) (*dns.Msg, bool) {
|
||||
key := cacheKey(qName, qType)
|
||||
|
||||
c.mutex.RLock()
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
item, found := c.cache[key]
|
||||
if !found {
|
||||
c.mutex.RUnlock()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if time.Now().After(item.Expiry) {
|
||||
c.mutex.RUnlock()
|
||||
// 过期了,删除缓存项(在写锁中)
|
||||
c.delete(key)
|
||||
// 过期了,删除缓存项
|
||||
delete(c.cache, key)
|
||||
// 从访问列表中移除
|
||||
for i, k := range c.accessList {
|
||||
if k == key {
|
||||
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 将访问的条目移动到访问列表末尾(标记为最近使用)
|
||||
for i, k := range c.accessList {
|
||||
if k == key {
|
||||
// 移除旧位置
|
||||
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
|
||||
// 添加到末尾
|
||||
c.accessList = append(c.accessList, key)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 返回缓存的响应副本
|
||||
response := item.Response.Copy()
|
||||
c.mutex.RUnlock()
|
||||
|
||||
return response, true
|
||||
}
|
||||
@@ -143,14 +186,24 @@ func (c *DNSCache) Get(qName string, qType uint16) (*dns.Msg, bool) {
|
||||
// delete 删除缓存项
|
||||
func (c *DNSCache) delete(key string) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// 从缓存中删除
|
||||
delete(c.cache, key)
|
||||
c.mutex.Unlock()
|
||||
// 从访问列表中移除
|
||||
for i, k := range c.accessList {
|
||||
if k == key {
|
||||
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clear 清空缓存
|
||||
func (c *DNSCache) Clear() {
|
||||
c.mutex.Lock()
|
||||
c.cache = make(map[string]*DNSCacheItem)
|
||||
c.accessList = make([]string, 0, c.maxSize) // 重置访问列表
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
@@ -178,9 +231,23 @@ func (c *DNSCache) cleanupExpired() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// 收集所有过期的键
|
||||
var expiredKeys []string
|
||||
for key, item := range c.cache {
|
||||
if now.After(item.Expiry) {
|
||||
delete(c.cache, key)
|
||||
expiredKeys = append(expiredKeys, key)
|
||||
}
|
||||
}
|
||||
|
||||
// 删除过期的缓存项
|
||||
for _, key := range expiredKeys {
|
||||
delete(c.cache, key)
|
||||
// 从访问列表中移除
|
||||
for i, k := range c.accessList {
|
||||
if k == key {
|
||||
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
772
dns/server.go
772
dns/server.go
@@ -269,6 +269,9 @@ func (s *Server) Start() error {
|
||||
// 启动自动保存功能
|
||||
go s.startAutoSave()
|
||||
|
||||
// 启动IP地理位置缓存清理协程
|
||||
go s.startIPGeolocationCacheCleanup()
|
||||
|
||||
// 启动UDP服务
|
||||
go func() {
|
||||
logger.Info(fmt.Sprintf("DNS UDP服务器启动,监听端口: %d", s.config.Port))
|
||||
@@ -762,6 +765,185 @@ type serverResponse struct {
|
||||
error error
|
||||
}
|
||||
|
||||
// recordKey 用于唯一标识DNS记录的结构体
|
||||
type recordKey struct {
|
||||
name string
|
||||
rtype uint16
|
||||
class uint16
|
||||
data string
|
||||
}
|
||||
|
||||
// getRecordKey 获取DNS记录的唯一标识
|
||||
func getRecordKey(rr dns.RR) recordKey {
|
||||
// 对于同一域名的同一类型记录,只保留一个,选择最长TTL
|
||||
// 所以对于A、AAAA、CNAME等记录,只使用name、rtype、class作为键
|
||||
// 对于MX记录,还需要考虑Preference字段
|
||||
// 对于TXT记录,需要考虑实际文本内容
|
||||
// 对于NS记录,需要考虑目标服务器
|
||||
|
||||
switch rr.Header().Rrtype {
|
||||
case dns.TypeA, dns.TypeAAAA, dns.TypeCNAME, dns.TypePTR:
|
||||
// 对于A、AAAA、CNAME、PTR记录,同一域名只保留一个
|
||||
return recordKey{
|
||||
name: rr.Header().Name,
|
||||
rtype: rr.Header().Rrtype,
|
||||
class: rr.Header().Class,
|
||||
data: "",
|
||||
}
|
||||
case dns.TypeMX:
|
||||
// 对于MX记录,同一域名的同一Preference只保留一个
|
||||
if mx, ok := rr.(*dns.MX); ok {
|
||||
return recordKey{
|
||||
name: rr.Header().Name,
|
||||
rtype: rr.Header().Rrtype,
|
||||
class: rr.Header().Class,
|
||||
data: fmt.Sprintf("%d", mx.Preference),
|
||||
}
|
||||
}
|
||||
case dns.TypeTXT:
|
||||
// 对于TXT记录,需要考虑实际文本内容
|
||||
if txt, ok := rr.(*dns.TXT); ok {
|
||||
return recordKey{
|
||||
name: rr.Header().Name,
|
||||
rtype: rr.Header().Rrtype,
|
||||
class: rr.Header().Class,
|
||||
data: strings.Join(txt.Txt, " "),
|
||||
}
|
||||
}
|
||||
case dns.TypeNS:
|
||||
// 对于NS记录,需要考虑目标服务器
|
||||
if ns, ok := rr.(*dns.NS); ok {
|
||||
return recordKey{
|
||||
name: rr.Header().Name,
|
||||
rtype: rr.Header().Rrtype,
|
||||
class: rr.Header().Class,
|
||||
data: ns.Ns,
|
||||
}
|
||||
}
|
||||
case dns.TypeSOA:
|
||||
// 对于SOA记录,同一域名只保留一个
|
||||
return recordKey{
|
||||
name: rr.Header().Name,
|
||||
rtype: rr.Header().Rrtype,
|
||||
class: rr.Header().Class,
|
||||
data: "",
|
||||
}
|
||||
}
|
||||
|
||||
// 对于其他类型,使用原始rr.String(),但移除TTL部分
|
||||
parts := strings.Split(rr.String(), " ")
|
||||
if len(parts) >= 5 {
|
||||
// 跳过TTL字段(第3个字段)
|
||||
data := strings.Join(append(parts[:2], parts[3:]...), " ")
|
||||
return recordKey{
|
||||
name: rr.Header().Name,
|
||||
rtype: rr.Header().Rrtype,
|
||||
class: rr.Header().Class,
|
||||
data: data,
|
||||
}
|
||||
}
|
||||
|
||||
return recordKey{
|
||||
name: rr.Header().Name,
|
||||
rtype: rr.Header().Rrtype,
|
||||
class: rr.Header().Class,
|
||||
data: rr.String(),
|
||||
}
|
||||
}
|
||||
|
||||
// mergeResponses 合并多个DNS响应
|
||||
func mergeResponses(responses []*dns.Msg) *dns.Msg {
|
||||
if len(responses) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 如果只有一个响应,直接返回,避免不必要的合并操作
|
||||
if len(responses) == 1 {
|
||||
return responses[0].Copy()
|
||||
}
|
||||
|
||||
// 使用第一个响应作为基础
|
||||
mergedResponse := responses[0].Copy()
|
||||
mergedResponse.Answer = []dns.RR{}
|
||||
mergedResponse.Ns = []dns.RR{}
|
||||
mergedResponse.Extra = []dns.RR{}
|
||||
|
||||
// 使用map存储唯一记录,选择最长TTL
|
||||
// 预分配map容量,减少扩容开销
|
||||
answerMap := make(map[recordKey]dns.RR, len(responses[0].Answer)*len(responses))
|
||||
nsMap := make(map[recordKey]dns.RR, len(responses[0].Ns)*len(responses))
|
||||
extraMap := make(map[recordKey]dns.RR, len(responses[0].Extra)*len(responses))
|
||||
|
||||
for _, resp := range responses {
|
||||
if resp == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 合并Answer部分
|
||||
for _, rr := range resp.Answer {
|
||||
key := getRecordKey(rr)
|
||||
if existing, exists := answerMap[key]; exists {
|
||||
// 如果存在相同记录,选择TTL更长的
|
||||
if rr.Header().Ttl > existing.Header().Ttl {
|
||||
answerMap[key] = rr
|
||||
}
|
||||
} else {
|
||||
answerMap[key] = rr
|
||||
}
|
||||
}
|
||||
|
||||
// 合并Ns部分
|
||||
for _, rr := range resp.Ns {
|
||||
key := getRecordKey(rr)
|
||||
if existing, exists := nsMap[key]; exists {
|
||||
// 如果存在相同记录,选择TTL更长的
|
||||
if rr.Header().Ttl > existing.Header().Ttl {
|
||||
nsMap[key] = rr
|
||||
}
|
||||
} else {
|
||||
nsMap[key] = rr
|
||||
}
|
||||
}
|
||||
|
||||
// 合并Extra部分
|
||||
for _, rr := range resp.Extra {
|
||||
// 跳过OPT记录,避免重复
|
||||
if rr.Header().Rrtype == dns.TypeOPT {
|
||||
continue
|
||||
}
|
||||
key := getRecordKey(rr)
|
||||
if existing, exists := extraMap[key]; exists {
|
||||
// 如果存在相同记录,选择TTL更长的
|
||||
if rr.Header().Ttl > existing.Header().Ttl {
|
||||
extraMap[key] = rr
|
||||
}
|
||||
} else {
|
||||
extraMap[key] = rr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 预分配切片容量,减少扩容开销
|
||||
mergedResponse.Answer = make([]dns.RR, 0, len(answerMap))
|
||||
mergedResponse.Ns = make([]dns.RR, 0, len(nsMap))
|
||||
mergedResponse.Extra = make([]dns.RR, 0, len(extraMap))
|
||||
|
||||
// 将map转换回切片
|
||||
for _, rr := range answerMap {
|
||||
mergedResponse.Answer = append(mergedResponse.Answer, rr)
|
||||
}
|
||||
|
||||
for _, rr := range nsMap {
|
||||
mergedResponse.Ns = append(mergedResponse.Ns, rr)
|
||||
}
|
||||
|
||||
for _, rr := range extraMap {
|
||||
mergedResponse.Extra = append(mergedResponse.Extra, rr)
|
||||
}
|
||||
|
||||
return mergedResponse
|
||||
}
|
||||
|
||||
// forwardDNSRequestWithCache 转发DNS请求到上游服务器并返回响应
|
||||
func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg, time.Duration, string, string) {
|
||||
// 始终支持EDNS
|
||||
@@ -856,10 +1038,13 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
|
||||
var usedDNSServer string
|
||||
var usedDNSSECServer string
|
||||
|
||||
// 使用配置中的超时时间
|
||||
defaultTimeout := time.Duration(s.config.QueryTimeout) * time.Millisecond
|
||||
|
||||
// 根据查询模式处理请求
|
||||
switch s.config.QueryMode {
|
||||
case "parallel":
|
||||
// 并行请求模式 - 优化版:添加超时处理和快速响应返回
|
||||
// 并行请求模式 - 收集所有响应并合并
|
||||
responses := make(chan serverResponse, len(selectedUpstreamDNS))
|
||||
var wg sync.WaitGroup
|
||||
|
||||
@@ -881,7 +1066,12 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
|
||||
close(responses)
|
||||
}()
|
||||
|
||||
// 处理所有响应,实现快速响应返回
|
||||
// 收集所有有效响应
|
||||
var validResponses []*dns.Msg
|
||||
var totalRtt time.Duration
|
||||
var responseCount int
|
||||
|
||||
// 处理所有响应
|
||||
for resp := range responses {
|
||||
if resp.error == nil && resp.response != nil {
|
||||
// 更新服务器统计信息
|
||||
@@ -890,33 +1080,17 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
|
||||
// 检查是否包含DNSSEC记录
|
||||
containsDNSSEC := s.hasDNSSECRecords(resp.response)
|
||||
|
||||
// 如果启用了DNSSEC且响应包含DNSSEC记录,验证DNSSEC签名
|
||||
// 但如果域名匹配不验证DNSSEC的模式,则跳过验证
|
||||
if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC {
|
||||
// 验证DNSSEC记录
|
||||
signatureValid := s.verifyDNSSEC(resp.response)
|
||||
|
||||
// 设置AD标志(Authenticated Data)
|
||||
resp.response.AuthenticatedData = signatureValid
|
||||
|
||||
if signatureValid {
|
||||
// 更新DNSSEC验证成功计数
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.DNSSECQueries++
|
||||
stats.DNSSECSuccess++
|
||||
})
|
||||
} else {
|
||||
// 更新DNSSEC验证失败计数
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.DNSSECQueries++
|
||||
stats.DNSSECFailed++
|
||||
})
|
||||
}
|
||||
} else if noDNSSEC {
|
||||
// 对于不验证DNSSEC的域名,始终设置AD标志为false
|
||||
// 对于不验证DNSSEC的域名,始终设置AD标志为false
|
||||
if noDNSSEC {
|
||||
resp.response.AuthenticatedData = false
|
||||
}
|
||||
|
||||
// 只对将要返回的响应进行DNSSEC验证,减少开销
|
||||
// 这里只设置containsDNSSEC标志,实际验证在确定返回响应后进行
|
||||
if containsDNSSEC && s.config.EnableDNSSEC && !noDNSSEC {
|
||||
// 暂时不验证,只标记
|
||||
}
|
||||
|
||||
// 检查当前服务器是否是DNSSEC专用服务器
|
||||
for _, dnssecServer := range dnssecServers {
|
||||
if dnssecServer == resp.server {
|
||||
@@ -925,111 +1099,43 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
|
||||
}
|
||||
}
|
||||
|
||||
// 检查当前服务器是否是用户配置的上游DNS服务器
|
||||
isUserUpstream := false
|
||||
for _, userServer := range s.config.UpstreamDNS {
|
||||
if userServer == resp.server {
|
||||
isUserUpstream = true
|
||||
break
|
||||
}
|
||||
}
|
||||
// 收集有效响应
|
||||
if resp.response.Rcode == dns.RcodeSuccess || resp.response.Rcode == dns.RcodeNameError {
|
||||
validResponses = append(validResponses, resp.response)
|
||||
totalRtt += resp.rtt
|
||||
responseCount++
|
||||
|
||||
// 处理响应,优先选择用户配置的主DNS服务器
|
||||
if resp.response.Rcode == dns.RcodeSuccess {
|
||||
// 成功响应,优先使用
|
||||
if isUserUpstream {
|
||||
// 用户配置的主DNS服务器响应,直接设置为最佳响应
|
||||
bestResponse = resp.response
|
||||
bestRtt = resp.rtt
|
||||
hasBestResponse = true
|
||||
hasDNSSECResponse = containsDNSSEC
|
||||
// 记录使用的服务器
|
||||
if usedDNSServer == "" {
|
||||
usedDNSServer = resp.server
|
||||
logger.Debug("使用用户配置的上游服务器响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
|
||||
// 快速返回:用户配置的主DNS服务器响应,立即返回
|
||||
continue
|
||||
} else if containsDNSSEC {
|
||||
// 非用户配置服务器,但有DNSSEC记录
|
||||
if !hasBestResponse || !isUserUpstream {
|
||||
// 如果还没有最佳响应,或者当前最佳响应不是用户配置的服务器,则更新
|
||||
bestResponse = resp.response
|
||||
bestRtt = resp.rtt
|
||||
hasBestResponse = true
|
||||
hasDNSSECResponse = true
|
||||
usedDNSServer = resp.server
|
||||
logger.Debug("找到带DNSSEC的最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
|
||||
// 快速返回:找到带DNSSEC的响应,立即返回
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
// 非用户配置服务器,没有DNSSEC记录
|
||||
if !hasBestResponse {
|
||||
// 如果还没有最佳响应,设置为最佳响应
|
||||
bestResponse = resp.response
|
||||
bestRtt = resp.rtt
|
||||
hasBestResponse = true
|
||||
usedDNSServer = resp.server
|
||||
logger.Debug("找到最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
|
||||
// 快速返回:第一次找到成功响应,立即返回
|
||||
continue
|
||||
}
|
||||
}
|
||||
} else if resp.response.Rcode == dns.RcodeNameError {
|
||||
// NXDOMAIN响应
|
||||
if !hasBestResponse || bestResponse.Rcode == dns.RcodeNameError {
|
||||
// 如果还没有最佳响应,或者最佳响应也是NXDOMAIN
|
||||
if isUserUpstream {
|
||||
// 用户配置的服务器,直接使用
|
||||
bestResponse = resp.response
|
||||
bestRtt = resp.rtt
|
||||
hasBestResponse = true
|
||||
usedDNSServer = resp.server
|
||||
logger.Debug("使用用户配置的上游服务器NXDOMAIN响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
|
||||
// 快速返回:用户配置的服务器NXDOMAIN响应,立即返回
|
||||
continue
|
||||
} else if !hasBestResponse || resp.rtt < bestRtt {
|
||||
// 非用户配置服务器,选择更快的响应
|
||||
bestResponse = resp.response
|
||||
bestRtt = resp.rtt
|
||||
hasBestResponse = true
|
||||
usedDNSServer = resp.server
|
||||
logger.Debug("找到NXDOMAIN最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
|
||||
// 快速返回:找到NXDOMAIN响应,立即返回
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 更新备选响应,确保总有一个可用的响应
|
||||
if resp.response != nil {
|
||||
if !hasBackup {
|
||||
// 第一次保存备选响应
|
||||
backupResponse = resp.response
|
||||
backupRtt = resp.rtt
|
||||
hasBackup = true
|
||||
} else {
|
||||
// 后续响应,优先保存用户配置的服务器响应作为备选
|
||||
if isUserUpstream {
|
||||
} else {
|
||||
// 更新备选响应,确保总有一个可用的响应
|
||||
if resp.response != nil {
|
||||
if !hasBackup {
|
||||
// 第一次保存备选响应
|
||||
backupResponse = resp.response
|
||||
backupRtt = resp.rtt
|
||||
hasBackup = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 即使响应不是成功或NXDOMAIN,也保存为最佳响应(如果还没有的话)
|
||||
// 确保总有一个响应返回给客户端
|
||||
if !hasBestResponse {
|
||||
bestResponse = resp.response
|
||||
bestRtt = resp.rtt
|
||||
hasBestResponse = true
|
||||
usedDNSServer = resp.server
|
||||
logger.Debug("使用非成功响应作为最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt, "rcode", resp.response.Rcode)
|
||||
}
|
||||
} else {
|
||||
// 更新服务器统计信息(失败)
|
||||
s.updateServerStats(resp.server, false, 0)
|
||||
}
|
||||
}
|
||||
|
||||
// 合并所有有效响应
|
||||
if len(validResponses) > 0 {
|
||||
bestResponse = mergeResponses(validResponses)
|
||||
if responseCount > 0 {
|
||||
bestRtt = totalRtt / time.Duration(responseCount)
|
||||
}
|
||||
hasBestResponse = true
|
||||
logger.Debug("合并所有响应返回", "domain", domain, "responseCount", len(validResponses))
|
||||
}
|
||||
|
||||
case "loadbalance":
|
||||
// 负载均衡模式 - 使用加权随机选择算法
|
||||
// 1. 尝试所有可用的服务器,直到找到一个能正常工作的
|
||||
@@ -1289,8 +1395,14 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
|
||||
}
|
||||
|
||||
default:
|
||||
// 默认使用并行请求模式 - 添加超时处理和快速响应返回
|
||||
// 默认使用并行请求模式 - 实现快速返回和超时机制
|
||||
responses := make(chan serverResponse, len(selectedUpstreamDNS))
|
||||
resultChan := make(chan struct {
|
||||
response *dns.Msg
|
||||
rtt time.Duration
|
||||
usedServer string
|
||||
usedDnssecServer string
|
||||
}, 1)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// 向所有上游服务器并行发送请求
|
||||
@@ -1299,113 +1411,287 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
|
||||
go func(server string) {
|
||||
defer wg.Done()
|
||||
|
||||
// 发送请求并获取响应
|
||||
response, rtt, err := s.resolver.Exchange(r, normalizeDNSServerAddress(server))
|
||||
// 创建带有超时的resolver
|
||||
client := &dns.Client{
|
||||
Net: s.resolver.Net,
|
||||
UDPSize: s.resolver.UDPSize,
|
||||
Timeout: defaultTimeout,
|
||||
}
|
||||
|
||||
// 发送请求并获取响应,确保服务器地址包含端口号
|
||||
response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(server))
|
||||
responses <- serverResponse{response, rtt, server, err}
|
||||
}(upstream)
|
||||
}
|
||||
|
||||
// 等待所有请求完成
|
||||
// 处理响应的协程
|
||||
go func() {
|
||||
var fastestResponse *dns.Msg
|
||||
var fastestRtt time.Duration = defaultTimeout
|
||||
var fastestServer string
|
||||
var fastestDnssecServer string
|
||||
var fastestHasDnssec bool
|
||||
var validResponses []*dns.Msg
|
||||
|
||||
// 等待所有请求完成或超时
|
||||
timer := time.NewTimer(defaultTimeout)
|
||||
defer timer.Stop()
|
||||
|
||||
// 处理所有响应
|
||||
for {
|
||||
select {
|
||||
case resp, ok := <-responses:
|
||||
if !ok {
|
||||
// 所有响应都已处理
|
||||
goto doneProcessing
|
||||
}
|
||||
|
||||
if resp.error == nil && resp.response != nil {
|
||||
// 更新服务器统计信息
|
||||
s.updateServerStats(resp.server, true, resp.rtt)
|
||||
|
||||
// 检查是否包含DNSSEC记录
|
||||
containsDNSSEC := s.hasDNSSECRecords(resp.response)
|
||||
|
||||
// 如果启用了DNSSEC且响应包含DNSSEC记录,验证DNSSEC签名
|
||||
// 但如果域名匹配不验证DNSSEC的模式,则跳过验证
|
||||
if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC {
|
||||
// 验证DNSSEC记录
|
||||
signatureValid := s.verifyDNSSEC(resp.response)
|
||||
|
||||
// 设置AD标志(Authenticated Data)
|
||||
resp.response.AuthenticatedData = signatureValid
|
||||
|
||||
if signatureValid {
|
||||
// 更新DNSSEC验证成功计数
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.DNSSECQueries++
|
||||
stats.DNSSECSuccess++
|
||||
})
|
||||
} else {
|
||||
// 更新DNSSEC验证失败计数
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.DNSSECQueries++
|
||||
stats.DNSSECFailed++
|
||||
})
|
||||
}
|
||||
} else if noDNSSEC {
|
||||
// 对于不验证DNSSEC的域名,始终设置AD标志为false
|
||||
resp.response.AuthenticatedData = false
|
||||
}
|
||||
|
||||
// 检查当前服务器是否是DNSSEC专用服务器
|
||||
dnssecServerForResponse := ""
|
||||
for _, dnssecServer := range dnssecServers {
|
||||
if dnssecServer == resp.server {
|
||||
dnssecServerForResponse = resp.server
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 如果响应成功或为NXDOMAIN
|
||||
if resp.response.Rcode == dns.RcodeSuccess || resp.response.Rcode == dns.RcodeNameError {
|
||||
// 添加到有效响应列表,用于后续合并
|
||||
validResponses = append(validResponses, resp.response)
|
||||
|
||||
// 快速返回逻辑:找到第一个有效响应或更快的响应
|
||||
if resp.response.Rcode == dns.RcodeSuccess {
|
||||
// 优先选择带有DNSSEC的响应
|
||||
if containsDNSSEC {
|
||||
// 如果这是第一个DNSSEC响应,或者比当前最快的DNSSEC响应更快
|
||||
if !fastestHasDnssec || resp.rtt < fastestRtt {
|
||||
fastestResponse = resp.response
|
||||
fastestRtt = resp.rtt
|
||||
fastestServer = resp.server
|
||||
fastestDnssecServer = dnssecServerForResponse
|
||||
fastestHasDnssec = true
|
||||
|
||||
// 只对将要返回的响应进行DNSSEC验证
|
||||
if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC {
|
||||
// 验证DNSSEC记录
|
||||
signatureValid := s.verifyDNSSEC(fastestResponse)
|
||||
|
||||
// 设置AD标志(Authenticated Data)
|
||||
fastestResponse.AuthenticatedData = signatureValid
|
||||
|
||||
if signatureValid {
|
||||
// 更新DNSSEC验证成功计数
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.DNSSECQueries++
|
||||
stats.DNSSECSuccess++
|
||||
})
|
||||
} else {
|
||||
// 更新DNSSEC验证失败计数
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.DNSSECQueries++
|
||||
stats.DNSSECFailed++
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 发送结果,快速返回
|
||||
resultChan <- struct {
|
||||
response *dns.Msg
|
||||
rtt time.Duration
|
||||
usedServer string
|
||||
usedDnssecServer string
|
||||
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
|
||||
}
|
||||
} else {
|
||||
// 非DNSSEC响应,只有在还没有找到DNSSEC响应且当前响应更快时才更新
|
||||
if !fastestHasDnssec && resp.rtt < fastestRtt {
|
||||
fastestResponse = resp.response
|
||||
fastestRtt = resp.rtt
|
||||
fastestServer = resp.server
|
||||
fastestDnssecServer = dnssecServerForResponse
|
||||
|
||||
// 检查是否包含DNSSEC记录
|
||||
respContainsDNSSEC := s.hasDNSSECRecords(fastestResponse)
|
||||
|
||||
// 只对将要返回的响应进行DNSSEC验证
|
||||
if s.config.EnableDNSSEC && respContainsDNSSEC && !noDNSSEC {
|
||||
// 验证DNSSEC记录
|
||||
signatureValid := s.verifyDNSSEC(fastestResponse)
|
||||
|
||||
// 设置AD标志(Authenticated Data)
|
||||
fastestResponse.AuthenticatedData = signatureValid
|
||||
|
||||
if signatureValid {
|
||||
// 更新DNSSEC验证成功计数
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.DNSSECQueries++
|
||||
stats.DNSSECSuccess++
|
||||
})
|
||||
} else {
|
||||
// 更新DNSSEC验证失败计数
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.DNSSECQueries++
|
||||
stats.DNSSECFailed++
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 发送结果,快速返回
|
||||
resultChan <- struct {
|
||||
response *dns.Msg
|
||||
rtt time.Duration
|
||||
usedServer string
|
||||
usedDnssecServer string
|
||||
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
|
||||
}
|
||||
}
|
||||
} else if resp.response.Rcode == dns.RcodeNameError {
|
||||
// NXDOMAIN响应,只有在还没有找到响应或当前响应更快时才更新
|
||||
if !fastestHasDnssec && resp.rtt < fastestRtt {
|
||||
fastestResponse = resp.response
|
||||
fastestRtt = resp.rtt
|
||||
fastestServer = resp.server
|
||||
fastestDnssecServer = dnssecServerForResponse
|
||||
|
||||
// 检查是否包含DNSSEC记录
|
||||
respContainsDNSSEC := s.hasDNSSECRecords(fastestResponse)
|
||||
|
||||
// 只对将要返回的响应进行DNSSEC验证
|
||||
if s.config.EnableDNSSEC && respContainsDNSSEC && !noDNSSEC {
|
||||
// 验证DNSSEC记录
|
||||
signatureValid := s.verifyDNSSEC(fastestResponse)
|
||||
|
||||
// 设置AD标志(Authenticated Data)
|
||||
fastestResponse.AuthenticatedData = signatureValid
|
||||
|
||||
if signatureValid {
|
||||
// 更新DNSSEC验证成功计数
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.DNSSECQueries++
|
||||
stats.DNSSECSuccess++
|
||||
})
|
||||
} else {
|
||||
// 更新DNSSEC验证失败计数
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.DNSSECQueries++
|
||||
stats.DNSSECFailed++
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 发送结果,快速返回
|
||||
resultChan <- struct {
|
||||
response *dns.Msg
|
||||
rtt time.Duration
|
||||
usedServer string
|
||||
usedDnssecServer string
|
||||
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 更新备选响应,确保总有一个可用的响应
|
||||
if resp.response != nil {
|
||||
if !hasBackup {
|
||||
// 第一次保存备选响应
|
||||
backupResponse = resp.response
|
||||
backupRtt = resp.rtt
|
||||
hasBackup = true
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 更新服务器统计信息(失败)
|
||||
s.updateServerStats(resp.server, false, 0)
|
||||
}
|
||||
case <-timer.C:
|
||||
// 超时,停止等待更多响应
|
||||
goto doneProcessing
|
||||
}
|
||||
}
|
||||
|
||||
doneProcessing:
|
||||
// 合并所有有效响应,用于缓存
|
||||
if len(validResponses) > 1 {
|
||||
mergedResponse := mergeResponses(validResponses)
|
||||
if mergedResponse != nil {
|
||||
// 只在合并后的响应比最快响应更好时才使用
|
||||
mergedHasDnssec := s.hasDNSSECRecords(mergedResponse)
|
||||
if mergedHasDnssec && !fastestHasDnssec {
|
||||
// 合并后的响应有DNSSEC,而最快响应没有,使用合并后的响应
|
||||
fastestResponse = mergedResponse
|
||||
// 使用最快的Rtt作为合并响应的Rtt
|
||||
fastestHasDnssec = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果还没有发送结果,发送最快的响应
|
||||
if fastestResponse != nil {
|
||||
resultChan <- struct {
|
||||
response *dns.Msg
|
||||
rtt time.Duration
|
||||
usedServer string
|
||||
usedDnssecServer string
|
||||
}{fastestResponse, fastestRtt, fastestServer, fastestDnssecServer}
|
||||
}
|
||||
close(resultChan)
|
||||
}()
|
||||
|
||||
// 等待所有请求完成(不阻塞主流程)
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(responses)
|
||||
}()
|
||||
|
||||
// 处理所有响应,实现快速响应返回
|
||||
for resp := range responses {
|
||||
if resp.error == nil && resp.response != nil {
|
||||
|
||||
// 检查是否包含DNSSEC记录
|
||||
containsDNSSEC := s.hasDNSSECRecords(resp.response)
|
||||
|
||||
// 如果启用了DNSSEC且响应包含DNSSEC记录,验证DNSSEC签名
|
||||
// 但如果域名匹配不验证DNSSEC的模式,则跳过验证
|
||||
if s.config.EnableDNSSEC && containsDNSSEC && !noDNSSEC {
|
||||
// 验证DNSSEC记录
|
||||
signatureValid := s.verifyDNSSEC(resp.response)
|
||||
|
||||
// 设置AD标志(Authenticated Data)
|
||||
resp.response.AuthenticatedData = signatureValid
|
||||
|
||||
if signatureValid {
|
||||
// 更新DNSSEC验证成功计数
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.DNSSECQueries++
|
||||
stats.DNSSECSuccess++
|
||||
})
|
||||
} else {
|
||||
// 更新DNSSEC验证失败计数
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.DNSSECQueries++
|
||||
stats.DNSSECFailed++
|
||||
})
|
||||
}
|
||||
} else if noDNSSEC {
|
||||
// 对于不验证DNSSEC的域名,始终设置AD标志为false
|
||||
resp.response.AuthenticatedData = false
|
||||
}
|
||||
|
||||
// 如果响应成功或为NXDOMAIN,根据DNSSEC状态选择最佳响应
|
||||
if resp.response.Rcode == dns.RcodeSuccess || resp.response.Rcode == dns.RcodeNameError {
|
||||
if resp.response.Rcode == dns.RcodeSuccess {
|
||||
// 优先选择带有DNSSEC记录的响应
|
||||
if containsDNSSEC {
|
||||
bestResponse = resp.response
|
||||
bestRtt = resp.rtt
|
||||
hasBestResponse = true
|
||||
hasDNSSECResponse = true
|
||||
usedDNSServer = resp.server
|
||||
// 如果当前使用的服务器是DNSSEC专用服务器,同时设置usedDNSSECServer
|
||||
for _, dnssecServer := range dnssecServers {
|
||||
if dnssecServer == resp.server {
|
||||
usedDNSSECServer = resp.server
|
||||
break
|
||||
}
|
||||
}
|
||||
logger.Debug("找到带DNSSEC的最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
|
||||
// 快速返回:找到带DNSSEC的响应,立即返回
|
||||
continue
|
||||
} else if !hasBestResponse {
|
||||
// 没有带DNSSEC的响应时,保存第一个成功响应
|
||||
bestResponse = resp.response
|
||||
bestRtt = resp.rtt
|
||||
hasBestResponse = true
|
||||
usedDNSServer = resp.server
|
||||
// 如果当前使用的服务器是DNSSEC专用服务器,同时设置usedDNSSECServer
|
||||
for _, dnssecServer := range dnssecServers {
|
||||
if dnssecServer == resp.server {
|
||||
usedDNSSECServer = resp.server
|
||||
break
|
||||
}
|
||||
}
|
||||
logger.Debug("找到最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
|
||||
// 快速返回:第一次找到成功响应,立即返回
|
||||
continue
|
||||
}
|
||||
} else if resp.response.Rcode == dns.RcodeNameError {
|
||||
// 处理NXDOMAIN响应
|
||||
// 如果还没有最佳响应,或者最佳响应也是NXDOMAIN,优先选择更快的NXDOMAIN响应
|
||||
if !hasBestResponse || bestResponse.Rcode == dns.RcodeNameError {
|
||||
// 如果还没有最佳响应,或者当前响应更快,更新最佳响应
|
||||
if !hasBestResponse || resp.rtt < bestRtt {
|
||||
bestResponse = resp.response
|
||||
bestRtt = resp.rtt
|
||||
hasBestResponse = true
|
||||
usedDNSServer = resp.server
|
||||
logger.Debug("找到NXDOMAIN最佳响应", "domain", domain, "server", resp.server, "rtt", resp.rtt)
|
||||
// 快速返回:找到NXDOMAIN响应,立即返回
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
// 保存为备选响应
|
||||
if !hasBackup {
|
||||
backupResponse = resp.response
|
||||
backupRtt = resp.rtt
|
||||
hasBackup = true
|
||||
}
|
||||
}
|
||||
}
|
||||
// 等待结果或超时
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
// 快速返回结果
|
||||
bestResponse = result.response
|
||||
bestRtt = result.rtt
|
||||
usedDNSServer = result.usedServer
|
||||
usedDNSSECServer = result.usedDnssecServer
|
||||
hasBestResponse = true
|
||||
hasDNSSECResponse = s.hasDNSSECRecords(result.response)
|
||||
logger.Debug("快速返回DNS响应", "domain", domain, "server", result.usedServer, "rtt", result.rtt, "dnssec", hasDNSSECResponse)
|
||||
case <-time.After(defaultTimeout):
|
||||
// 超时,使用备选响应
|
||||
logger.Debug("并行请求超时", "domain", domain, "timeout", defaultTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1430,7 +1716,13 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
|
||||
}, 1)
|
||||
|
||||
go func() {
|
||||
response, rtt, err := s.resolver.Exchange(r, normalizeDNSServerAddress(selectedDnssecServer))
|
||||
// 创建带有超时的resolver
|
||||
client := &dns.Client{
|
||||
Net: s.resolver.Net,
|
||||
UDPSize: s.resolver.UDPSize,
|
||||
Timeout: defaultTimeout,
|
||||
}
|
||||
response, rtt, err := client.Exchange(r, normalizeDNSServerAddress(selectedDnssecServer))
|
||||
resultChan <- struct {
|
||||
response *dns.Msg
|
||||
rtt time.Duration
|
||||
@@ -1442,9 +1734,15 @@ func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg
|
||||
var rtt time.Duration
|
||||
var err error
|
||||
|
||||
// 直接获取结果,不使用上下文超时
|
||||
result := <-resultChan
|
||||
response, rtt, err = result.response, result.rtt, result.err
|
||||
// 使用超时获取结果
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
response, rtt, err = result.response, result.rtt, result.err
|
||||
case <-time.After(defaultTimeout):
|
||||
// 超时,不再等待
|
||||
logger.Debug("DNSSEC专用服务器请求超时", "domain", domain, "server", selectedDnssecServer, "timeout", defaultTimeout)
|
||||
return bestResponse, bestRtt, usedDNSServer, usedDNSSECServer
|
||||
}
|
||||
|
||||
if err == nil && response != nil {
|
||||
// 更新服务器统计信息
|
||||
@@ -2840,6 +3138,40 @@ func (s *Server) startCpuUsageMonitor() {
|
||||
}
|
||||
}
|
||||
|
||||
// startIPGeolocationCacheCleanup 启动IP地理位置缓存清理协程
|
||||
func (s *Server) startIPGeolocationCacheCleanup() {
|
||||
ticker := time.NewTicker(time.Hour) // 每小时清理一次
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.cleanupExpiredIPGeolocationCache()
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpiredIPGeolocationCache 清理过期的IP地理位置缓存
|
||||
func (s *Server) cleanupExpiredIPGeolocationCache() {
|
||||
now := time.Now()
|
||||
s.ipGeolocationCacheMutex.Lock()
|
||||
defer s.ipGeolocationCacheMutex.Unlock()
|
||||
|
||||
var deletedCount int
|
||||
for ip, geo := range s.ipGeolocationCache {
|
||||
if now.After(geo.Expiry) {
|
||||
delete(s.ipGeolocationCache, ip)
|
||||
deletedCount++
|
||||
}
|
||||
}
|
||||
|
||||
if deletedCount > 0 {
|
||||
logger.Info("清理过期的IP地理位置缓存", "deleted", deletedCount, "remaining", len(s.ipGeolocationCache))
|
||||
}
|
||||
}
|
||||
|
||||
// getSystemCpuUsage 获取系统CPU使用率
|
||||
func getSystemCpuUsage(prevIdle, prevTotal *uint64) (float64, error) {
|
||||
// 读取/proc/stat文件获取CPU统计信息
|
||||
|
||||
Reference in New Issue
Block a user