实现缓存功能

This commit is contained in:
Alex Yang
2025-11-30 20:20:59 +08:00
parent f623654151
commit c16e147931
33 changed files with 276 additions and 445090 deletions

View File

@@ -56,6 +56,7 @@ type QueryLog struct {
Result string // 查询结果allowed, blocked, error
BlockRule string // 屏蔽规则(如果被屏蔽)
BlockType string // 屏蔽类型(如果被屏蔽)
FromCache bool // 是否来自缓存
}
// StatsData 用于持久化的统计数据结构
@@ -107,6 +108,9 @@ type Server struct {
ipGeolocationCache map[string]*IPGeolocation // IP地址到地理位置的映射
ipGeolocationCacheMutex sync.RWMutex // 保护IP地理位置缓存的互斥锁
ipGeolocationCacheTTL time.Duration // 缓存有效期
// DNS查询缓存
dnsCache *DNSCache // DNS响应缓存
}
// Stats DNS服务器统计信息
@@ -126,6 +130,10 @@ type Stats struct {
// NewServer 创建DNS服务器实例
func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager) *Server {
ctx, cancel := context.WithCancel(context.Background())
// 从配置中读取DNS缓存TTL值分钟
cacheTTL := time.Duration(config.CacheTTL) * time.Minute
server := &Server{
config: config,
shieldConfig: shieldConfig,
@@ -161,6 +169,8 @@ func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shie
// IP地理位置缓存初始化
ipGeolocationCache: make(map[string]*IPGeolocation),
ipGeolocationCacheTTL: 24 * time.Hour, // 缓存有效期24小时
// DNS查询缓存初始化
dnsCache: NewDNSCache(cacheTTL),
}
// 加载已保存的统计数据
@@ -291,6 +301,7 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
// 获取查询域名和类型
var domain string
var queryType string
var qType uint16
if len(r.Question) > 0 {
domain = r.Question[0].Name
// 移除末尾的点
@@ -299,6 +310,7 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
}
// 获取查询类型
queryType = dns.TypeToString[r.Question[0].Qtype]
qType = r.Question[0].Qtype
// 更新查询类型统计
s.updateStats(func(stats *Stats) {
stats.QueryTypes[queryType]++
@@ -325,7 +337,7 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "")
s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "", false)
return
}
@@ -342,7 +354,7 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "")
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", false)
return
}
@@ -364,12 +376,40 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "blocked", blockRule, blockType)
s.addQueryLog(sourceIP, domain, queryType, responseTime, "blocked", blockRule, blockType, false)
return
}
// 转发到上游DNS服务器
s.forwardDNSRequest(w, r, domain)
// 检查缓存中是否有响应(增强版缓存查询)
if cachedResponse, found := s.dnsCache.Get(r.Question[0].Name, qType); found {
// 缓存命中,直接返回缓存的响应
cachedResponseCopy := cachedResponse.Copy() // 创建响应副本避免并发修改问题
cachedResponseCopy.Id = r.Id // 更新ID以匹配请求
cachedResponseCopy.Compress = true
w.WriteMsg(cachedResponseCopy)
// 计算响应时间
responseTime := time.Since(startTime).Milliseconds()
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
if stats.Queries > 0 {
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
}
})
// 添加查询日志 - 标记为缓存
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", true)
logger.Debug("从缓存返回DNS响应", "domain", domain, "type", queryType)
return
}
// 缓存未命中转发到上游DNS服务器
response, _ := s.forwardDNSRequestWithCache(r, domain)
if response != nil {
// 写入响应给客户端
w.WriteMsg(response)
}
// 计算响应时间
responseTime := time.Since(startTime).Milliseconds()
s.updateStats(func(stats *Stats) {
@@ -379,8 +419,18 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
}
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "")
// 如果响应成功,缓存结果(增强版缓存存储)
if response != nil && response.Rcode == dns.RcodeSuccess {
// 创建响应副本以避免后续修改影响缓存
responseCopy := response.Copy()
// 设置合理的TTL不超过默认的30分钟
defaultCacheTTL := 30 * time.Minute
s.dnsCache.Set(r.Question[0].Name, qType, responseCopy, defaultCacheTTL)
logger.Debug("DNS响应已缓存", "domain", domain, "type", queryType, "ttl", defaultCacheTTL)
}
// 添加查询日志 - 标记为实时
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", false)
}
// handleHostsResponse 处理hosts文件匹配的响应
@@ -484,7 +534,8 @@ func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain
}
// forwardDNSRequest 转发DNS请求到上游服务器
func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) {
// forwardDNSRequestWithCache 转发DNS请求到上游服务器并返回响应
func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg, time.Duration) {
// 尝试所有上游DNS服务器
for _, upstream := range s.config.UpstreamDNS {
response, rtt, err := s.resolver.Exchange(r, upstream)
@@ -492,7 +543,6 @@ func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain stri
// 设置递归可用标志
response.RecursionAvailable = true
w.WriteMsg(response)
logger.Debug("DNS查询成功", "domain", domain, "rtt", rtt, "server", upstream)
// 记录解析域名统计
@@ -501,7 +551,7 @@ func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain stri
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
return
return response, rtt
}
}
@@ -510,12 +560,18 @@ func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain stri
response.SetReply(r)
response.RecursionAvailable = true
response.SetRcode(r, dns.RcodeServerFailure)
w.WriteMsg(response)
logger.Error("DNS查询失败", "domain", domain)
s.updateStats(func(stats *Stats) {
stats.Errors++
})
return response, 0
}
// forwardDNSRequest 转发DNS请求到上游服务器
func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) {
response, _ := s.forwardDNSRequestWithCache(r, domain)
w.WriteMsg(response)
}
// updateBlockedDomainStats 更新被屏蔽域名统计
@@ -599,7 +655,7 @@ func (s *Server) updateStats(update func(*Stats)) {
}
// addQueryLog 添加查询日志
func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string) {
func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string, fromCache bool) {
// 获取IP地理位置
location := s.getIpGeolocation(clientIP)
@@ -614,6 +670,7 @@ func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime in
Result: result,
BlockRule: blockRule,
BlockType: blockType,
FromCache: fromCache,
}
// 添加到日志列表