Files
dns-server/dns/cache.go
2026-01-25 16:13:52 +08:00

839 lines
22 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package dns
import (
"encoding/json"
"io/ioutil"
"math"
"os"
"sync"
"time"
"github.com/miekg/dns"
)
// DNSCacheItem 表示缓存中的DNS响应项
type DNSCacheItem struct {
Response *dns.Msg // DNS响应消息
Expiry time.Time // 过期时间
HasDNSSEC bool // 是否包含DNSSEC记录
Size int // 缓存项大小(字节)
}
// SerializableDNSCacheItem 用于JSON序列化的缓存项结构
type SerializableDNSCacheItem struct {
ResponseBytes []byte `json:"responseBytes"` // 二进制DNS响应
Expiry int64 `json:"expiry"` // 过期时间(纳秒)
HasDNSSEC bool `json:"hasDNSSEC"` // 是否包含DNSSEC记录
Size int `json:"size"` // 缓存项大小(字节)
}
// SerializableDNSCache 可序列化的缓存结构
type SerializableDNSCache struct {
Items map[string]*SerializableDNSCacheItem `json:"items"` // 缓存项
TTL int64 `json:"ttl"` // 默认TTL纳秒
MaxSize int `json:"maxSize"` // 最大缓存大小
CacheMode string `json:"cacheMode"` // 缓存模式
CacheFilePath string `json:"cacheFilePath"` // 缓存文件路径
}
// DNSCache DNS缓存结构
type DNSCache struct {
cache map[string]*LRUNode // 缓存映射表,直接存储链表节点
mutex sync.RWMutex // 读写锁,保护缓存
ttl time.Duration // 默认缓存TTL
maxSize int // 最大缓存条目数
cacheSize int64 // 当前缓存大小(字节)
maxCacheSize int64 // 最大缓存大小(字节)
cacheMode string // 缓存模式
cacheFilePath string // 缓存文件路径
saveInterval time.Duration // 保存间隔
saveMutex sync.Mutex // 保存互斥锁
maxCacheTTL time.Duration // 最大缓存TTL
minCacheTTL time.Duration // 最小缓存TTL
saveStopCh chan struct{} // 保存循环停止通道
saveRunning bool // 保存循环是否运行
saveLoopMutex sync.Mutex // 保护保存循环状态的互斥锁
// 双向链表头和尾指针用于LRU淘汰
head *LRUNode // 头指针,指向最久未使用的节点
tail *LRUNode // 尾指针,指向最近使用的节点
// 缓存变化跟踪,用于智能保存
changeCount int // 缓存变化次数
lastSaveCacheSize int64 // 上次保存时的缓存大小
lastSaveItemCount int // 上次保存时的缓存项数量
lastSaveTime time.Time // 上次保存时间
minSaveInterval time.Duration // 最小保存间隔,避免过于频繁的保存
}
// LRUNode 双向链表节点用于LRU缓存
type LRUNode struct {
key string
value *DNSCacheItem
prev *LRUNode
next *LRUNode
}
// NewDNSCache 创建新的DNS缓存实例
func NewDNSCache(defaultTTL time.Duration, cacheMode string, cacheSizeMB int, cacheFilePath string, saveInterval time.Duration, maxCacheTTL, minCacheTTL time.Duration) *DNSCache {
// 计算最大缓存大小(字节)
maxCacheSize := int64(cacheSizeMB) * 1024 * 1024
cache := &DNSCache{
cache: make(map[string]*LRUNode),
ttl: defaultTTL,
maxSize: 10000, // 默认最大缓存10000条记录
cacheSize: 0,
maxCacheSize: maxCacheSize,
cacheMode: cacheMode,
cacheFilePath: cacheFilePath,
saveInterval: saveInterval,
maxCacheTTL: maxCacheTTL,
minCacheTTL: minCacheTTL,
saveStopCh: make(chan struct{}),
saveRunning: false,
head: nil,
tail: nil,
changeCount: 0,
lastSaveCacheSize: 0,
lastSaveItemCount: 0,
lastSaveTime: time.Now(),
minSaveInterval: 30 * time.Second, // 最小保存间隔为30秒避免过于频繁的保存
}
// 加载现有缓存(如果存在)
if cacheMode == "file" {
cache.LoadFromFile()
}
// 启动缓存清理协程
go cache.startCleanupLoop()
// 启动定期保存协程(如果是文件缓存)
if cacheMode == "file" {
go cache.startSaveLoop()
}
return cache
}
// addNodeToTail 将节点添加到链表尾部(表示最近使用)
func (c *DNSCache) addNodeToTail(node *LRUNode) {
if c.tail == nil {
// 链表为空
c.head = node
c.tail = node
} else {
// 添加到尾部
node.prev = c.tail
c.tail.next = node
c.tail = node
}
}
// removeNode 从链表中移除指定节点
func (c *DNSCache) removeNode(node *LRUNode) {
if node.prev != nil {
node.prev.next = node.next
} else {
// 移除的是头节点
c.head = node.next
}
if node.next != nil {
node.next.prev = node.prev
} else {
// 移除的是尾节点
c.tail = node.prev
}
// 清空节点的前后指针
node.prev = nil
node.next = nil
}
// moveNodeToTail 将节点移动到链表尾部(表示最近使用)
func (c *DNSCache) moveNodeToTail(node *LRUNode) {
// 如果已经是尾节点,不需要移动
if node == c.tail {
return
}
// 从链表中移除节点
c.removeNode(node)
// 重新添加到尾部
c.addNodeToTail(node)
}
// cacheKey 生成缓存键
func cacheKey(qName string, qType uint16) string {
return qName + "|" + dns.TypeToString[qType]
}
// calculateItemSize 计算缓存项大小
func calculateItemSize(item *DNSCacheItem) int {
// 使用更高效的方式估算缓存项大小
// 避免使用json.Marshal和rr.String()因为它们在高频调用时会消耗大量CPU资源
size := 0
// 估算Response大小
if item.Response != nil {
// 粗略估算DNS消息大小
// 头部大小约12字节
size += 12
// 问题部分
for _, q := range item.Response.Question {
size += len(q.Name) + 4 // 域名长度 + 类型(2) + 类(2)
}
// 高效估算资源记录大小避免调用rr.String()
estimateRRSize := func(rr dns.RR) int {
rrSize := len(rr.Header().Name) + 10 // 域名 + 类型(2) + 类(2) + TTL(4) + 长度(2)
switch rr.Header().Rrtype {
case dns.TypeA:
rrSize += 4 // IPv4地址
case dns.TypeAAAA:
rrSize += 16 // IPv6地址
case dns.TypeCNAME, dns.TypePTR, dns.TypeNS:
// 对于CNAME、PTR、NS记录需要估算目标域名长度
if cname, ok := rr.(*dns.CNAME); ok {
rrSize += len(cname.Target)
} else if ptr, ok := rr.(*dns.PTR); ok {
rrSize += len(ptr.Ptr)
} else if ns, ok := rr.(*dns.NS); ok {
rrSize += len(ns.Ns)
} else {
// 默认估算
rrSize += 30
}
case dns.TypeMX:
// MX记录优先级(2) + 目标域名
rrSize += 2 + 30 // 默认30字节目标域名
case dns.TypeTXT:
// TXT记录文本长度
rrSize += 50 // 默认50字节文本
case dns.TypeSRV:
// SRV记录优先级(2) + 权重(2) + 端口(2) + 目标域名
rrSize += 6 + 30 // 默认30字节目标域名
case dns.TypeSOA:
// SOA记录主NS + 管理员邮箱 + 序列号(4) + 刷新时间(4) + 重试时间(4) + 过期时间(4) + 最小TTL(4)
rrSize += 100 // 默认100字节
default:
// 其他类型记录,使用默认估算
rrSize += 50
}
return rrSize
}
// 回答部分
for _, rr := range item.Response.Answer {
size += estimateRRSize(rr)
}
// 授权部分
for _, rr := range item.Response.Ns {
size += estimateRRSize(rr)
}
// 附加部分
for _, rr := range item.Response.Extra {
if rr.Header().Rrtype == dns.TypeOPT {
// OPT记录大小约为40字节EDNS0
size += 40
} else {
size += estimateRRSize(rr)
}
}
}
// 其他字段大小
size += 8 // Expiry
size += 1 // HasDNSSEC
return size
}
// hasDNSSECRecords 检查响应是否包含DNSSEC记录
func hasDNSSECRecords(response *dns.Msg) bool {
// 直接在循环中检查RR类型避免创建匿名函数的开销
// 检查回答部分
for _, rr := range response.Answer {
switch rr.(type) {
case *dns.DNSKEY, *dns.RRSIG, *dns.DS, *dns.NSEC, *dns.NSEC3:
return true
}
}
// 检查授权部分
for _, rr := range response.Ns {
switch rr.(type) {
case *dns.DNSKEY, *dns.RRSIG, *dns.DS, *dns.NSEC, *dns.NSEC3:
return true
}
}
// 检查附加部分
for _, rr := range response.Extra {
switch rr.(type) {
case *dns.DNSKEY, *dns.RRSIG, *dns.DS, *dns.NSEC, *dns.NSEC3:
return true
}
}
return false
}
// Set 设置缓存项
func (c *DNSCache) Set(qName string, qType uint16, response *dns.Msg, ttl time.Duration) {
// 设置默认TTL
if ttl <= 0 {
ttl = c.ttl
}
// 应用maxCacheTTL和minCacheTTL约束
if c.maxCacheTTL > 0 && ttl > c.maxCacheTTL {
ttl = c.maxCacheTTL
}
if c.minCacheTTL > 0 && ttl < c.minCacheTTL {
ttl = c.minCacheTTL
}
key := cacheKey(qName, qType)
item := &DNSCacheItem{
Response: response.Copy(), // 复制响应以避免外部修改
Expiry: time.Now().Add(ttl),
HasDNSSEC: hasDNSSECRecords(response), // 检查并设置DNSSEC标志
}
// 计算缓存项大小
item.Size = calculateItemSize(item)
c.mutex.Lock()
defer c.mutex.Unlock()
// 如果条目已存在,先从链表和缓存中移除,并更新缓存大小
if existingNode, found := c.cache[key]; found {
c.cacheSize -= int64(existingNode.value.Size)
c.removeNode(existingNode)
delete(c.cache, key)
}
// 创建新的链表节点并添加到尾部
newNode := &LRUNode{
key: key,
value: item,
}
c.addNodeToTail(newNode)
c.cache[key] = newNode
c.cacheSize += int64(item.Size)
// 检查是否超过最大条目数限制,如果超过则移除最久未使用的条目
if len(c.cache) > c.maxSize {
// 最久未使用的条目是链表的头节点
if c.head != nil {
c.cacheSize -= int64(c.head.value.Size)
oldestKey := c.head.key
// 从缓存和链表中移除头节点
delete(c.cache, oldestKey)
c.removeNode(c.head)
}
}
// 检查是否超过最大缓存大小,如果超过则继续移除最久未使用的条目
for c.cacheSize > c.maxCacheSize && c.head != nil {
c.cacheSize -= int64(c.head.value.Size)
oldestKey := c.head.key
delete(c.cache, oldestKey)
c.removeNode(c.head)
}
// 更新缓存变化计数
c.changeCount++
}
// Get 获取缓存项
func (c *DNSCache) Get(qName string, qType uint16) (*dns.Msg, bool) {
key := cacheKey(qName, qType)
// 首先使用读锁检查缓存项是否存在和是否过期
c.mutex.RLock()
node, found := c.cache[key]
if !found {
c.mutex.RUnlock()
return nil, false
}
// 检查是否过期
if time.Now().After(node.value.Expiry) {
c.mutex.RUnlock()
// 需要删除过期条目,使用写锁
c.mutex.Lock()
// 再次检查,防止在读写锁切换期间被其他协程处理
if node, stillExists := c.cache[key]; stillExists && time.Now().After(node.value.Expiry) {
delete(c.cache, key)
c.removeNode(node)
}
c.mutex.Unlock()
return nil, false
}
// 返回前释放读锁,避免长时间持有锁
response := node.value.Response.Copy()
c.mutex.RUnlock()
// 标记为最近使用需要修改链表,使用写锁
c.mutex.Lock()
// 再次检查节点是否存在,防止在读写锁切换期间被删除
if node, stillExists := c.cache[key]; stillExists {
c.moveNodeToTail(node)
}
c.mutex.Unlock()
return response, true
}
// delete 删除缓存项
func (c *DNSCache) delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
// 从缓存和链表中删除
if node, found := c.cache[key]; found {
delete(c.cache, key)
c.removeNode(node)
}
}
// Clear 清空缓存
func (c *DNSCache) Clear() {
c.mutex.Lock()
c.cache = make(map[string]*LRUNode)
// 重置链表指针
c.head = nil
c.tail = nil
c.mutex.Unlock()
}
// Size 获取缓存大小
func (c *DNSCache) Size() int {
c.mutex.RLock()
defer c.mutex.RUnlock()
return len(c.cache)
}
// startCleanupLoop 启动定期清理过期缓存的协程
func (c *DNSCache) startCleanupLoop() {
// 初始清理间隔为1分钟
cleanupInterval := time.Minute * 1
ticker := time.NewTicker(cleanupInterval)
defer ticker.Stop()
for range ticker.C {
cleanupInterval = c.cleanupExpired()
// 调整下次清理间隔范围15秒到5分钟
if cleanupInterval < 15*time.Second {
cleanupInterval = 15 * time.Second
} else if cleanupInterval > 5*time.Minute {
cleanupInterval = 5 * time.Minute
}
// 更新清理间隔
ticker.Reset(cleanupInterval)
}
}
// startSaveLoop 启动定期保存缓存的协程
func (c *DNSCache) startSaveLoop() {
c.saveLoopMutex.Lock()
// 如果已经在运行,直接返回
if c.saveRunning {
c.saveLoopMutex.Unlock()
return
}
// 重置停止通道
c.saveStopCh = make(chan struct{})
c.saveRunning = true
c.saveLoopMutex.Unlock()
go func() {
ticker := time.NewTicker(c.saveInterval) // 根据配置的间隔保存
defer ticker.Stop()
for {
select {
case <-ticker.C:
// 检查缓存模式如果不是file模式则不保存
c.mutex.RLock()
mode := c.cacheMode
c.mutex.RUnlock()
if mode == "file" {
c.SaveToFile()
}
case <-c.saveStopCh:
// 停止保存循环
c.saveLoopMutex.Lock()
c.saveRunning = false
c.saveLoopMutex.Unlock()
return
}
}
}()
}
// saveCacheToFile 保存缓存到文件的底层实现,不检查缓存模式
func (c *DNSCache) saveCacheToFile() {
c.saveMutex.Lock()
defer c.saveMutex.Unlock()
// 智能保存策略
// 1. 如果缓存变化次数少于10次跳过保存
// 2. 如果距离上次保存时间不足最小保存间隔,跳过保存
c.mutex.RLock()
changeCount := c.changeCount
lastSaveTime := c.lastSaveTime
lastSaveCacheSize := c.lastSaveCacheSize
lastSaveItemCount := c.lastSaveItemCount
currentCacheSize := c.cacheSize
currentItemCount := len(c.cache)
c.mutex.RUnlock()
// 检查是否需要保存
if changeCount < 10 {
return
}
if time.Since(lastSaveTime) < c.minSaveInterval {
return
}
if currentItemCount > 0 {
cacheSizeChange := float64(currentCacheSize-lastSaveCacheSize) / float64(lastSaveCacheSize+1) // +1避免除以零
itemCountChange := float64(currentItemCount-lastSaveItemCount) / float64(lastSaveItemCount+1) // +1避免除以零
if math.Abs(cacheSizeChange) < 0.1 && math.Abs(itemCountChange) < 0.1 {
return
}
}
// 开始保存缓存
c.mutex.RLock()
// 收集有效的缓存项
validItems := make(map[string]*SerializableDNSCacheItem)
now := time.Now()
for key, node := range c.cache {
// 只保存未过期的缓存项
if now.Before(node.value.Expiry) {
// 序列化DNS响应为二进制
responseBytes, err := node.value.Response.Pack()
if err != nil {
continue // 跳过无法序列化的响应
}
// 创建可序列化的缓存项
validItems[key] = &SerializableDNSCacheItem{
ResponseBytes: responseBytes,
Expiry: node.value.Expiry.UnixNano(),
HasDNSSEC: node.value.HasDNSSEC,
Size: node.value.Size,
}
}
}
// 创建可序列化的缓存结构
serializableCache := &SerializableDNSCache{
Items: validItems,
TTL: int64(c.ttl),
MaxSize: c.maxSize,
CacheMode: c.cacheMode,
CacheFilePath: c.cacheFilePath,
}
c.mutex.RUnlock()
// 序列化到JSON
data, err := json.MarshalIndent(serializableCache, "", " ")
if err != nil {
return
}
// 确保目录存在
os.MkdirAll(cacheDir(), 0755)
// 保存到文件
err = ioutil.WriteFile(c.cacheFilePath, data, 0644)
if err != nil {
return
}
// 更新保存状态
c.mutex.Lock()
c.changeCount = 0
c.lastSaveTime = time.Now()
c.lastSaveCacheSize = currentCacheSize
c.lastSaveItemCount = currentItemCount
c.mutex.Unlock()
}
// SaveToFile 保存缓存到文件
func (c *DNSCache) SaveToFile() {
// 检查缓存模式如果不是file模式直接返回
c.mutex.RLock()
mode := c.cacheMode
c.mutex.RUnlock()
if mode != "file" {
return
}
// 调用底层保存逻辑
c.saveCacheToFile()
}
// LoadFromFile 从文件加载缓存
func (c *DNSCache) LoadFromFile() {
c.mutex.Lock()
defer c.mutex.Unlock()
// 检查文件是否存在
if _, err := os.Stat(c.cacheFilePath); os.IsNotExist(err) {
return // 文件不存在,跳过加载
}
// 读取文件内容
data, err := ioutil.ReadFile(c.cacheFilePath)
if err != nil {
return
}
// 反序列化JSON
var serializableCache SerializableDNSCache
err = json.Unmarshal(data, &serializableCache)
if err != nil {
return
}
// 加载缓存项
now := time.Now()
for key, serializableItem := range serializableCache.Items {
// 转换过期时间
expiry := time.Unix(0, serializableItem.Expiry)
// 只加载未过期的缓存项
if now.Before(expiry) {
// 反序列化二进制DNS响应
response := &dns.Msg{}
err := response.Unpack(serializableItem.ResponseBytes)
if err != nil {
continue // 跳过无法反序列化的响应
}
// 创建缓存项
item := &DNSCacheItem{
Response: response,
Expiry: expiry,
HasDNSSEC: serializableItem.HasDNSSEC,
Size: serializableItem.Size,
}
// 创建新的链表节点并添加到尾部
newNode := &LRUNode{
key: key,
value: item,
}
c.addNodeToTail(newNode)
c.cache[key] = newNode
c.cacheSize += int64(item.Size)
}
}
}
// cacheDir 返回缓存目录
func cacheDir() string {
return "data"
}
// SetMaxCacheTTL 设置最大缓存TTL
func (c *DNSCache) SetMaxCacheTTL(ttl time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.maxCacheTTL = ttl
}
// SetMinCacheTTL 设置最小缓存TTL
func (c *DNSCache) SetMinCacheTTL(ttl time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.minCacheTTL = ttl
}
// SetCacheMode 设置缓存模式
func (c *DNSCache) SetCacheMode(mode string) {
c.mutex.Lock()
oldMode := c.cacheMode
c.mutex.Unlock()
// 根据模式变化决定是否启动或停止保存循环
if oldMode != mode {
if oldMode == "file" {
// 从file模式切换到其他模式先保存当前缓存到文件
// 直接调用底层保存逻辑,不检查缓存模式
c.saveCacheToFile()
}
c.mutex.Lock()
c.cacheMode = mode
c.mutex.Unlock()
if mode == "file" {
// 切换到file模式启动保存循环
c.startSaveLoop()
} else {
// 切换到非file模式停止保存循环
c.saveLoopMutex.Lock()
if c.saveRunning {
close(c.saveStopCh)
c.saveRunning = false
}
c.saveLoopMutex.Unlock()
}
}
}
// SetMaxCacheSize 设置最大缓存大小
func (c *DNSCache) SetMaxCacheSize(size int64) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.maxCacheSize = size
}
// cleanupExpired 清理过期的缓存项,并返回下一次清理间隔的建议值
func (c *DNSCache) cleanupExpired() time.Duration {
now := time.Now()
c.mutex.Lock()
defer c.mutex.Unlock()
// 收集所有过期的键
var expiredKeys []string
totalItems := len(c.cache)
// 遍历缓存,收集过期项
for key, node := range c.cache {
if now.After(node.value.Expiry) {
expiredKeys = append(expiredKeys, key)
}
}
expiredCount := len(expiredKeys)
// 智能清理策略
// 1. 如果过期项比例超过50%,立即清理
// 2. 如果缓存大小超过最大缓存大小的80%,清理过期项
// 3. 如果缓存项数量超过最大条目数的80%,清理过期项
needCleanup := false
if totalItems > 0 {
if float64(expiredCount)/float64(totalItems) > 0.5 {
needCleanup = true
} else if c.cacheSize > c.maxCacheSize*8/10 {
needCleanup = true
} else if totalItems > c.maxSize*8/10 {
needCleanup = true
}
}
// 如果没有过期项或不需要清理,根据过期项比例返回建议的清理间隔
if expiredCount == 0 || !needCleanup {
// 计算下一次清理间隔
var nextInterval time.Duration
if totalItems == 0 {
// 空缓存,下一次清理间隔可以长一些
nextInterval = 5 * time.Minute
} else {
expireRatio := float64(expiredCount) / float64(totalItems)
// 过期项比例越高,清理间隔越短
if expireRatio < 0.1 {
nextInterval = 5 * time.Minute
} else if expireRatio < 0.3 {
nextInterval = 2 * time.Minute
} else {
nextInterval = 1 * time.Minute
}
}
return nextInterval
}
// 删除过期的缓存项
for _, key := range expiredKeys {
if node, found := c.cache[key]; found {
// 减去缓存项大小
c.cacheSize -= int64(node.value.Size)
delete(c.cache, key)
c.removeNode(node)
}
}
// 清理后,如果缓存大小仍然超过最大缓存大小,继续清理最久未使用的项
if c.cacheSize > c.maxCacheSize {
// 计算需要清理的额外大小
overflow := c.cacheSize - c.maxCacheSize
cleanedSize := int64(0)
// 从链表头开始清理(最久未使用的项)
current := c.head
for current != nil && cleanedSize < overflow {
nextNode := current.next
cleanedSize += int64(current.value.Size)
// 删除节点
delete(c.cache, current.key)
c.removeNode(current)
current = nextNode
}
}
// 清理后,如果缓存项数量仍然超过最大条目数,继续清理最久未使用的项
if len(c.cache) > c.maxSize {
// 计算需要清理的额外数量
overflowCount := len(c.cache) - c.maxSize
// 从链表头开始清理(最久未使用的项)
current := c.head
for i := 0; i < overflowCount && current != nil; i++ {
nextNode := current.next
// 删除节点
c.cacheSize -= int64(current.value.Size)
delete(c.cache, current.key)
c.removeNode(current)
current = nextNode
}
}
// 清理后,根据剩余过期项比例返回建议的清理间隔
// 重新计算剩余过期项
var remainingExpired int
for _, node := range c.cache {
if now.After(node.value.Expiry) {
remainingExpired++
}
}
remainingItems := len(c.cache)
var nextInterval time.Duration
if remainingItems == 0 {
nextInterval = 5 * time.Minute
} else {
remainingRatio := float64(remainingExpired) / float64(remainingItems)
// 剩余过期项比例越高,清理间隔越短
if remainingRatio < 0.1 {
nextInterval = 5 * time.Minute
} else if remainingRatio < 0.3 {
nextInterval = 2 * time.Minute
} else {
nextInterval = 1 * time.Minute
}
}
return nextInterval
}