实现缓存功能

This commit is contained in:
Alex Yang
2025-11-30 20:20:59 +08:00
parent f623654151
commit c16e147931
33 changed files with 276 additions and 445090 deletions

127
dns/cache.go Normal file
View File

@@ -0,0 +1,127 @@
package dns
import (
"sync"
"time"
"github.com/miekg/dns"
)
// DNSCacheItem 表示缓存中的DNS响应项
type DNSCacheItem struct {
Response *dns.Msg // DNS响应消息
Expiry time.Time // 过期时间
}
// DNSCache DNS缓存结构
type DNSCache struct {
cache map[string]*DNSCacheItem // 缓存映射表
mutex sync.RWMutex // 读写锁,保护缓存
defaultTTL time.Duration // 默认缓存TTL
}
// NewDNSCache 创建新的DNS缓存实例
func NewDNSCache(defaultTTL time.Duration) *DNSCache {
cache := &DNSCache{
cache: make(map[string]*DNSCacheItem),
defaultTTL: defaultTTL,
}
// 启动缓存清理协程
go cache.startCleanupLoop()
return cache
}
// cacheKey 生成缓存键
func cacheKey(qName string, qType uint16) string {
return qName + "|" + dns.TypeToString[qType]
}
// Set 设置缓存项
func (c *DNSCache) Set(qName string, qType uint16, response *dns.Msg, ttl time.Duration) {
if ttl <= 0 {
ttl = c.defaultTTL
}
key := cacheKey(qName, qType)
item := &DNSCacheItem{
Response: response.Copy(), // 复制响应以避免外部修改
Expiry: time.Now().Add(ttl),
}
c.mutex.Lock()
c.cache[key] = item
c.mutex.Unlock()
}
// Get 获取缓存项
func (c *DNSCache) Get(qName string, qType uint16) (*dns.Msg, bool) {
key := cacheKey(qName, qType)
c.mutex.RLock()
item, found := c.cache[key]
if !found {
c.mutex.RUnlock()
return nil, false
}
// 检查是否过期
if time.Now().After(item.Expiry) {
c.mutex.RUnlock()
// 过期了,删除缓存项(在写锁中)
c.delete(key)
return nil, false
}
// 返回缓存的响应副本
response := item.Response.Copy()
c.mutex.RUnlock()
return response, true
}
// delete 删除缓存项
func (c *DNSCache) delete(key string) {
c.mutex.Lock()
delete(c.cache, key)
c.mutex.Unlock()
}
// Clear 清空缓存
func (c *DNSCache) Clear() {
c.mutex.Lock()
c.cache = make(map[string]*DNSCacheItem)
c.mutex.Unlock()
}
// Size 获取缓存大小
func (c *DNSCache) Size() int {
c.mutex.RLock()
defer c.mutex.RUnlock()
return len(c.cache)
}
// startCleanupLoop 启动定期清理过期缓存的协程
func (c *DNSCache) startCleanupLoop() {
ticker := time.NewTicker(time.Minute * 5) // 每5分钟清理一次
defer ticker.Stop()
for range ticker.C {
c.cleanupExpired()
}
}
// cleanupExpired 清理过期的缓存项
func (c *DNSCache) cleanupExpired() {
now := time.Now()
c.mutex.Lock()
defer c.mutex.Unlock()
for key, item := range c.cache {
if now.After(item.Expiry) {
delete(c.cache, key)
}
}
}

View File

@@ -56,6 +56,7 @@ type QueryLog struct {
Result string // 查询结果allowed, blocked, error
BlockRule string // 屏蔽规则(如果被屏蔽)
BlockType string // 屏蔽类型(如果被屏蔽)
FromCache bool // 是否来自缓存
}
// StatsData 用于持久化的统计数据结构
@@ -107,6 +108,9 @@ type Server struct {
ipGeolocationCache map[string]*IPGeolocation // IP地址到地理位置的映射
ipGeolocationCacheMutex sync.RWMutex // 保护IP地理位置缓存的互斥锁
ipGeolocationCacheTTL time.Duration // 缓存有效期
// DNS查询缓存
dnsCache *DNSCache // DNS响应缓存
}
// Stats DNS服务器统计信息
@@ -126,6 +130,10 @@ type Stats struct {
// NewServer 创建DNS服务器实例
func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager) *Server {
ctx, cancel := context.WithCancel(context.Background())
// 从配置中读取DNS缓存TTL值分钟
cacheTTL := time.Duration(config.CacheTTL) * time.Minute
server := &Server{
config: config,
shieldConfig: shieldConfig,
@@ -161,6 +169,8 @@ func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shie
// IP地理位置缓存初始化
ipGeolocationCache: make(map[string]*IPGeolocation),
ipGeolocationCacheTTL: 24 * time.Hour, // 缓存有效期24小时
// DNS查询缓存初始化
dnsCache: NewDNSCache(cacheTTL),
}
// 加载已保存的统计数据
@@ -291,6 +301,7 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
// 获取查询域名和类型
var domain string
var queryType string
var qType uint16
if len(r.Question) > 0 {
domain = r.Question[0].Name
// 移除末尾的点
@@ -299,6 +310,7 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
}
// 获取查询类型
queryType = dns.TypeToString[r.Question[0].Qtype]
qType = r.Question[0].Qtype
// 更新查询类型统计
s.updateStats(func(stats *Stats) {
stats.QueryTypes[queryType]++
@@ -325,7 +337,7 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "")
s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "", false)
return
}
@@ -342,7 +354,7 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "")
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", false)
return
}
@@ -364,12 +376,40 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "blocked", blockRule, blockType)
s.addQueryLog(sourceIP, domain, queryType, responseTime, "blocked", blockRule, blockType, false)
return
}
// 转发到上游DNS服务器
s.forwardDNSRequest(w, r, domain)
// 检查缓存中是否有响应(增强版缓存查询)
if cachedResponse, found := s.dnsCache.Get(r.Question[0].Name, qType); found {
// 缓存命中,直接返回缓存的响应
cachedResponseCopy := cachedResponse.Copy() // 创建响应副本避免并发修改问题
cachedResponseCopy.Id = r.Id // 更新ID以匹配请求
cachedResponseCopy.Compress = true
w.WriteMsg(cachedResponseCopy)
// 计算响应时间
responseTime := time.Since(startTime).Milliseconds()
s.updateStats(func(stats *Stats) {
stats.TotalResponseTime += responseTime
if stats.Queries > 0 {
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
}
})
// 添加查询日志 - 标记为缓存
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", true)
logger.Debug("从缓存返回DNS响应", "domain", domain, "type", queryType)
return
}
// 缓存未命中转发到上游DNS服务器
response, _ := s.forwardDNSRequestWithCache(r, domain)
if response != nil {
// 写入响应给客户端
w.WriteMsg(response)
}
// 计算响应时间
responseTime := time.Since(startTime).Milliseconds()
s.updateStats(func(stats *Stats) {
@@ -379,8 +419,18 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
}
})
// 添加查询日志
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "")
// 如果响应成功,缓存结果(增强版缓存存储)
if response != nil && response.Rcode == dns.RcodeSuccess {
// 创建响应副本以避免后续修改影响缓存
responseCopy := response.Copy()
// 设置合理的TTL不超过默认的30分钟
defaultCacheTTL := 30 * time.Minute
s.dnsCache.Set(r.Question[0].Name, qType, responseCopy, defaultCacheTTL)
logger.Debug("DNS响应已缓存", "domain", domain, "type", queryType, "ttl", defaultCacheTTL)
}
// 添加查询日志 - 标记为实时
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", false)
}
// handleHostsResponse 处理hosts文件匹配的响应
@@ -484,7 +534,8 @@ func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain
}
// forwardDNSRequest 转发DNS请求到上游服务器
func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) {
// forwardDNSRequestWithCache 转发DNS请求到上游服务器并返回响应
func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg, time.Duration) {
// 尝试所有上游DNS服务器
for _, upstream := range s.config.UpstreamDNS {
response, rtt, err := s.resolver.Exchange(r, upstream)
@@ -492,7 +543,6 @@ func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain stri
// 设置递归可用标志
response.RecursionAvailable = true
w.WriteMsg(response)
logger.Debug("DNS查询成功", "domain", domain, "rtt", rtt, "server", upstream)
// 记录解析域名统计
@@ -501,7 +551,7 @@ func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain stri
s.updateStats(func(stats *Stats) {
stats.Allowed++
})
return
return response, rtt
}
}
@@ -510,12 +560,18 @@ func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain stri
response.SetReply(r)
response.RecursionAvailable = true
response.SetRcode(r, dns.RcodeServerFailure)
w.WriteMsg(response)
logger.Error("DNS查询失败", "domain", domain)
s.updateStats(func(stats *Stats) {
stats.Errors++
})
return response, 0
}
// forwardDNSRequest 转发DNS请求到上游服务器
func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) {
response, _ := s.forwardDNSRequestWithCache(r, domain)
w.WriteMsg(response)
}
// updateBlockedDomainStats 更新被屏蔽域名统计
@@ -599,7 +655,7 @@ func (s *Server) updateStats(update func(*Stats)) {
}
// addQueryLog 添加查询日志
func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string) {
func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string, fromCache bool) {
// 获取IP地理位置
location := s.getIpGeolocation(clientIP)
@@ -614,6 +670,7 @@ func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime in
Result: result,
BlockRule: blockRule,
BlockType: blockType,
FromCache: fromCache,
}
// 添加到日志列表