实现缓存功能
This commit is contained in:
@@ -56,6 +56,7 @@ type QueryLog struct {
|
||||
Result string // 查询结果(allowed, blocked, error)
|
||||
BlockRule string // 屏蔽规则(如果被屏蔽)
|
||||
BlockType string // 屏蔽类型(如果被屏蔽)
|
||||
FromCache bool // 是否来自缓存
|
||||
}
|
||||
|
||||
// StatsData 用于持久化的统计数据结构
|
||||
@@ -107,6 +108,9 @@ type Server struct {
|
||||
ipGeolocationCache map[string]*IPGeolocation // IP地址到地理位置的映射
|
||||
ipGeolocationCacheMutex sync.RWMutex // 保护IP地理位置缓存的互斥锁
|
||||
ipGeolocationCacheTTL time.Duration // 缓存有效期
|
||||
|
||||
// DNS查询缓存
|
||||
dnsCache *DNSCache // DNS响应缓存
|
||||
}
|
||||
|
||||
// Stats DNS服务器统计信息
|
||||
@@ -126,6 +130,10 @@ type Stats struct {
|
||||
// NewServer 创建DNS服务器实例
|
||||
func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager) *Server {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// 从配置中读取DNS缓存TTL值(分钟)
|
||||
cacheTTL := time.Duration(config.CacheTTL) * time.Minute
|
||||
|
||||
server := &Server{
|
||||
config: config,
|
||||
shieldConfig: shieldConfig,
|
||||
@@ -161,6 +169,8 @@ func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shie
|
||||
// IP地理位置缓存初始化
|
||||
ipGeolocationCache: make(map[string]*IPGeolocation),
|
||||
ipGeolocationCacheTTL: 24 * time.Hour, // 缓存有效期24小时
|
||||
// DNS查询缓存初始化
|
||||
dnsCache: NewDNSCache(cacheTTL),
|
||||
}
|
||||
|
||||
// 加载已保存的统计数据
|
||||
@@ -291,6 +301,7 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
// 获取查询域名和类型
|
||||
var domain string
|
||||
var queryType string
|
||||
var qType uint16
|
||||
if len(r.Question) > 0 {
|
||||
domain = r.Question[0].Name
|
||||
// 移除末尾的点
|
||||
@@ -299,6 +310,7 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
// 获取查询类型
|
||||
queryType = dns.TypeToString[r.Question[0].Qtype]
|
||||
qType = r.Question[0].Qtype
|
||||
// 更新查询类型统计
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.QueryTypes[queryType]++
|
||||
@@ -325,7 +337,7 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
})
|
||||
|
||||
// 添加查询日志
|
||||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "")
|
||||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "error", "", "", false)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -342,7 +354,7 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
})
|
||||
|
||||
// 添加查询日志
|
||||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "")
|
||||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", false)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -364,12 +376,40 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
})
|
||||
|
||||
// 添加查询日志
|
||||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "blocked", blockRule, blockType)
|
||||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "blocked", blockRule, blockType, false)
|
||||
return
|
||||
}
|
||||
|
||||
// 转发到上游DNS服务器
|
||||
s.forwardDNSRequest(w, r, domain)
|
||||
// 检查缓存中是否有响应(增强版缓存查询)
|
||||
if cachedResponse, found := s.dnsCache.Get(r.Question[0].Name, qType); found {
|
||||
// 缓存命中,直接返回缓存的响应
|
||||
cachedResponseCopy := cachedResponse.Copy() // 创建响应副本避免并发修改问题
|
||||
cachedResponseCopy.Id = r.Id // 更新ID以匹配请求
|
||||
cachedResponseCopy.Compress = true
|
||||
w.WriteMsg(cachedResponseCopy)
|
||||
|
||||
// 计算响应时间
|
||||
responseTime := time.Since(startTime).Milliseconds()
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.TotalResponseTime += responseTime
|
||||
if stats.Queries > 0 {
|
||||
stats.AvgResponseTime = float64(stats.TotalResponseTime) / float64(stats.Queries)
|
||||
}
|
||||
})
|
||||
|
||||
// 添加查询日志 - 标记为缓存
|
||||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", true)
|
||||
logger.Debug("从缓存返回DNS响应", "domain", domain, "type", queryType)
|
||||
return
|
||||
}
|
||||
|
||||
// 缓存未命中,转发到上游DNS服务器
|
||||
response, _ := s.forwardDNSRequestWithCache(r, domain)
|
||||
if response != nil {
|
||||
// 写入响应给客户端
|
||||
w.WriteMsg(response)
|
||||
}
|
||||
|
||||
// 计算响应时间
|
||||
responseTime := time.Since(startTime).Milliseconds()
|
||||
s.updateStats(func(stats *Stats) {
|
||||
@@ -379,8 +419,18 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
})
|
||||
|
||||
// 添加查询日志
|
||||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "")
|
||||
// 如果响应成功,缓存结果(增强版缓存存储)
|
||||
if response != nil && response.Rcode == dns.RcodeSuccess {
|
||||
// 创建响应副本以避免后续修改影响缓存
|
||||
responseCopy := response.Copy()
|
||||
// 设置合理的TTL,不超过默认的30分钟
|
||||
defaultCacheTTL := 30 * time.Minute
|
||||
s.dnsCache.Set(r.Question[0].Name, qType, responseCopy, defaultCacheTTL)
|
||||
logger.Debug("DNS响应已缓存", "domain", domain, "type", queryType, "ttl", defaultCacheTTL)
|
||||
}
|
||||
|
||||
// 添加查询日志 - 标记为实时
|
||||
s.addQueryLog(sourceIP, domain, queryType, responseTime, "allowed", "", "", false)
|
||||
}
|
||||
|
||||
// handleHostsResponse 处理hosts文件匹配的响应
|
||||
@@ -484,7 +534,8 @@ func (s *Server) handleBlockedResponse(w dns.ResponseWriter, r *dns.Msg, domain
|
||||
}
|
||||
|
||||
// forwardDNSRequest 转发DNS请求到上游服务器
|
||||
func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) {
|
||||
// forwardDNSRequestWithCache 转发DNS请求到上游服务器并返回响应
|
||||
func (s *Server) forwardDNSRequestWithCache(r *dns.Msg, domain string) (*dns.Msg, time.Duration) {
|
||||
// 尝试所有上游DNS服务器
|
||||
for _, upstream := range s.config.UpstreamDNS {
|
||||
response, rtt, err := s.resolver.Exchange(r, upstream)
|
||||
@@ -492,7 +543,6 @@ func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain stri
|
||||
// 设置递归可用标志
|
||||
response.RecursionAvailable = true
|
||||
|
||||
w.WriteMsg(response)
|
||||
logger.Debug("DNS查询成功", "domain", domain, "rtt", rtt, "server", upstream)
|
||||
|
||||
// 记录解析域名统计
|
||||
@@ -501,7 +551,7 @@ func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain stri
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.Allowed++
|
||||
})
|
||||
return
|
||||
return response, rtt
|
||||
}
|
||||
}
|
||||
|
||||
@@ -510,12 +560,18 @@ func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain stri
|
||||
response.SetReply(r)
|
||||
response.RecursionAvailable = true
|
||||
response.SetRcode(r, dns.RcodeServerFailure)
|
||||
w.WriteMsg(response)
|
||||
|
||||
logger.Error("DNS查询失败", "domain", domain)
|
||||
s.updateStats(func(stats *Stats) {
|
||||
stats.Errors++
|
||||
})
|
||||
return response, 0
|
||||
}
|
||||
|
||||
// forwardDNSRequest 转发DNS请求到上游服务器
|
||||
func (s *Server) forwardDNSRequest(w dns.ResponseWriter, r *dns.Msg, domain string) {
|
||||
response, _ := s.forwardDNSRequestWithCache(r, domain)
|
||||
w.WriteMsg(response)
|
||||
}
|
||||
|
||||
// updateBlockedDomainStats 更新被屏蔽域名统计
|
||||
@@ -599,7 +655,7 @@ func (s *Server) updateStats(update func(*Stats)) {
|
||||
}
|
||||
|
||||
// addQueryLog 添加查询日志
|
||||
func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string) {
|
||||
func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime int64, result, blockRule, blockType string, fromCache bool) {
|
||||
// 获取IP地理位置
|
||||
location := s.getIpGeolocation(clientIP)
|
||||
|
||||
@@ -614,6 +670,7 @@ func (s *Server) addQueryLog(clientIP, domain, queryType string, responseTime in
|
||||
Result: result,
|
||||
BlockRule: blockRule,
|
||||
BlockType: blockType,
|
||||
FromCache: fromCache,
|
||||
}
|
||||
|
||||
// 添加到日志列表
|
||||
|
||||
Reference in New Issue
Block a user