package dns import ( "encoding/json" "io/ioutil" "math" "os" "sync" "time" "github.com/miekg/dns" ) // DNSCacheItem 表示缓存中的DNS响应项 type DNSCacheItem struct { Response *dns.Msg // DNS响应消息 Expiry time.Time // 过期时间 HasDNSSEC bool // 是否包含DNSSEC记录 Size int // 缓存项大小(字节) } // SerializableDNSCacheItem 用于JSON序列化的缓存项结构 type SerializableDNSCacheItem struct { ResponseBytes []byte `json:"responseBytes"` // 二进制DNS响应 Expiry int64 `json:"expiry"` // 过期时间(纳秒) HasDNSSEC bool `json:"hasDNSSEC"` // 是否包含DNSSEC记录 Size int `json:"size"` // 缓存项大小(字节) } // SerializableDNSCache 可序列化的缓存结构 type SerializableDNSCache struct { Items map[string]*SerializableDNSCacheItem `json:"items"` // 缓存项 TTL int64 `json:"ttl"` // 默认TTL(纳秒) MaxSize int `json:"maxSize"` // 最大缓存大小 CacheMode string `json:"cacheMode"` // 缓存模式 CacheFilePath string `json:"cacheFilePath"` // 缓存文件路径 } // DNSCache DNS缓存结构 type DNSCache struct { cache map[string]*LRUNode // 缓存映射表,直接存储链表节点 mutex sync.RWMutex // 读写锁,保护缓存 ttl time.Duration // 默认缓存TTL maxSize int // 最大缓存条目数 cacheSize int64 // 当前缓存大小(字节) maxCacheSize int64 // 最大缓存大小(字节) cacheMode string // 缓存模式 cacheFilePath string // 缓存文件路径 saveInterval time.Duration // 保存间隔 saveMutex sync.Mutex // 保存互斥锁 maxCacheTTL time.Duration // 最大缓存TTL minCacheTTL time.Duration // 最小缓存TTL saveStopCh chan struct{} // 保存循环停止通道 saveRunning bool // 保存循环是否运行 saveLoopMutex sync.Mutex // 保护保存循环状态的互斥锁 // 双向链表头和尾指针,用于LRU淘汰 head *LRUNode // 头指针,指向最久未使用的节点 tail *LRUNode // 尾指针,指向最近使用的节点 // 缓存变化跟踪,用于智能保存 changeCount int // 缓存变化次数 lastSaveCacheSize int64 // 上次保存时的缓存大小 lastSaveItemCount int // 上次保存时的缓存项数量 lastSaveTime time.Time // 上次保存时间 minSaveInterval time.Duration // 最小保存间隔,避免过于频繁的保存 } // LRUNode 双向链表节点,用于LRU缓存 type LRUNode struct { key string value *DNSCacheItem prev *LRUNode next *LRUNode } // NewDNSCache 创建新的DNS缓存实例 func NewDNSCache(defaultTTL time.Duration, cacheMode string, cacheSizeMB int, cacheFilePath string, saveInterval time.Duration, maxCacheTTL, minCacheTTL time.Duration) *DNSCache { // 计算最大缓存大小(字节) maxCacheSize := int64(cacheSizeMB) * 1024 * 1024 cache := &DNSCache{ cache: make(map[string]*LRUNode), ttl: defaultTTL, maxSize: 10000, // 默认最大缓存10000条记录 cacheSize: 0, maxCacheSize: maxCacheSize, cacheMode: cacheMode, cacheFilePath: cacheFilePath, saveInterval: saveInterval, maxCacheTTL: maxCacheTTL, minCacheTTL: minCacheTTL, saveStopCh: make(chan struct{}), saveRunning: false, head: nil, tail: nil, changeCount: 0, lastSaveCacheSize: 0, lastSaveItemCount: 0, lastSaveTime: time.Now(), minSaveInterval: 30 * time.Second, // 最小保存间隔为30秒,避免过于频繁的保存 } // 加载现有缓存(如果存在) if cacheMode == "file" { cache.LoadFromFile() } // 启动缓存清理协程 go cache.startCleanupLoop() // 启动定期保存协程(如果是文件缓存) if cacheMode == "file" { go cache.startSaveLoop() } return cache } // addNodeToTail 将节点添加到链表尾部(表示最近使用) func (c *DNSCache) addNodeToTail(node *LRUNode) { if c.tail == nil { // 链表为空 c.head = node c.tail = node } else { // 添加到尾部 node.prev = c.tail c.tail.next = node c.tail = node } } // removeNode 从链表中移除指定节点 func (c *DNSCache) removeNode(node *LRUNode) { if node.prev != nil { node.prev.next = node.next } else { // 移除的是头节点 c.head = node.next } if node.next != nil { node.next.prev = node.prev } else { // 移除的是尾节点 c.tail = node.prev } // 清空节点的前后指针 node.prev = nil node.next = nil } // moveNodeToTail 将节点移动到链表尾部(表示最近使用) func (c *DNSCache) moveNodeToTail(node *LRUNode) { // 如果已经是尾节点,不需要移动 if node == c.tail { return } // 从链表中移除节点 c.removeNode(node) // 重新添加到尾部 c.addNodeToTail(node) } // cacheKey 生成缓存键 func cacheKey(qName string, qType uint16) string { return qName + "|" + dns.TypeToString[qType] } // calculateItemSize 计算缓存项大小 func calculateItemSize(item *DNSCacheItem) int { // 使用更高效的方式估算缓存项大小 // 避免使用json.Marshal和rr.String(),因为它们在高频调用时会消耗大量CPU资源 size := 0 // 估算Response大小 if item.Response != nil { // 粗略估算DNS消息大小 // 头部大小约12字节 size += 12 // 问题部分 for _, q := range item.Response.Question { size += len(q.Name) + 4 // 域名长度 + 类型(2) + 类(2) } // 高效估算资源记录大小,避免调用rr.String() estimateRRSize := func(rr dns.RR) int { rrSize := len(rr.Header().Name) + 10 // 域名 + 类型(2) + 类(2) + TTL(4) + 长度(2) switch rr.Header().Rrtype { case dns.TypeA: rrSize += 4 // IPv4地址 case dns.TypeAAAA: rrSize += 16 // IPv6地址 case dns.TypeCNAME, dns.TypePTR, dns.TypeNS: // 对于CNAME、PTR、NS记录,需要估算目标域名长度 if cname, ok := rr.(*dns.CNAME); ok { rrSize += len(cname.Target) } else if ptr, ok := rr.(*dns.PTR); ok { rrSize += len(ptr.Ptr) } else if ns, ok := rr.(*dns.NS); ok { rrSize += len(ns.Ns) } else { // 默认估算 rrSize += 30 } case dns.TypeMX: // MX记录:优先级(2) + 目标域名 rrSize += 2 + 30 // 默认30字节目标域名 case dns.TypeTXT: // TXT记录:文本长度 rrSize += 50 // 默认50字节文本 case dns.TypeSRV: // SRV记录:优先级(2) + 权重(2) + 端口(2) + 目标域名 rrSize += 6 + 30 // 默认30字节目标域名 case dns.TypeSOA: // SOA记录:主NS + 管理员邮箱 + 序列号(4) + 刷新时间(4) + 重试时间(4) + 过期时间(4) + 最小TTL(4) rrSize += 100 // 默认100字节 default: // 其他类型记录,使用默认估算 rrSize += 50 } return rrSize } // 回答部分 for _, rr := range item.Response.Answer { size += estimateRRSize(rr) } // 授权部分 for _, rr := range item.Response.Ns { size += estimateRRSize(rr) } // 附加部分 for _, rr := range item.Response.Extra { if rr.Header().Rrtype == dns.TypeOPT { // OPT记录大小约为40字节(EDNS0) size += 40 } else { size += estimateRRSize(rr) } } } // 其他字段大小 size += 8 // Expiry size += 1 // HasDNSSEC return size } // hasDNSSECRecords 检查响应是否包含DNSSEC记录 func hasDNSSECRecords(response *dns.Msg) bool { // 直接在循环中检查RR类型,避免创建匿名函数的开销 // 检查回答部分 for _, rr := range response.Answer { switch rr.(type) { case *dns.DNSKEY, *dns.RRSIG, *dns.DS, *dns.NSEC, *dns.NSEC3: return true } } // 检查授权部分 for _, rr := range response.Ns { switch rr.(type) { case *dns.DNSKEY, *dns.RRSIG, *dns.DS, *dns.NSEC, *dns.NSEC3: return true } } // 检查附加部分 for _, rr := range response.Extra { switch rr.(type) { case *dns.DNSKEY, *dns.RRSIG, *dns.DS, *dns.NSEC, *dns.NSEC3: return true } } return false } // Set 设置缓存项 func (c *DNSCache) Set(qName string, qType uint16, response *dns.Msg, ttl time.Duration) { // 设置默认TTL if ttl <= 0 { ttl = c.ttl } // 应用maxCacheTTL和minCacheTTL约束 if c.maxCacheTTL > 0 && ttl > c.maxCacheTTL { ttl = c.maxCacheTTL } if c.minCacheTTL > 0 && ttl < c.minCacheTTL { ttl = c.minCacheTTL } key := cacheKey(qName, qType) item := &DNSCacheItem{ Response: response.Copy(), // 复制响应以避免外部修改 Expiry: time.Now().Add(ttl), HasDNSSEC: hasDNSSECRecords(response), // 检查并设置DNSSEC标志 } // 计算缓存项大小 item.Size = calculateItemSize(item) c.mutex.Lock() defer c.mutex.Unlock() // 如果条目已存在,先从链表和缓存中移除,并更新缓存大小 if existingNode, found := c.cache[key]; found { c.cacheSize -= int64(existingNode.value.Size) c.removeNode(existingNode) delete(c.cache, key) } // 创建新的链表节点并添加到尾部 newNode := &LRUNode{ key: key, value: item, } c.addNodeToTail(newNode) c.cache[key] = newNode c.cacheSize += int64(item.Size) // 检查是否超过最大条目数限制,如果超过则移除最久未使用的条目 if len(c.cache) > c.maxSize { // 最久未使用的条目是链表的头节点 if c.head != nil { c.cacheSize -= int64(c.head.value.Size) oldestKey := c.head.key // 从缓存和链表中移除头节点 delete(c.cache, oldestKey) c.removeNode(c.head) } } // 检查是否超过最大缓存大小,如果超过则继续移除最久未使用的条目 for c.cacheSize > c.maxCacheSize && c.head != nil { c.cacheSize -= int64(c.head.value.Size) oldestKey := c.head.key delete(c.cache, oldestKey) c.removeNode(c.head) } // 更新缓存变化计数 c.changeCount++ } // Get 获取缓存项 func (c *DNSCache) Get(qName string, qType uint16) (*dns.Msg, bool) { key := cacheKey(qName, qType) // 首先使用读锁检查缓存项是否存在和是否过期 c.mutex.RLock() node, found := c.cache[key] if !found { c.mutex.RUnlock() return nil, false } // 检查是否过期 if time.Now().After(node.value.Expiry) { c.mutex.RUnlock() // 需要删除过期条目,使用写锁 c.mutex.Lock() // 再次检查,防止在读写锁切换期间被其他协程处理 if node, stillExists := c.cache[key]; stillExists && time.Now().After(node.value.Expiry) { delete(c.cache, key) c.removeNode(node) } c.mutex.Unlock() return nil, false } // 返回前释放读锁,避免长时间持有锁 response := node.value.Response.Copy() c.mutex.RUnlock() // 标记为最近使用需要修改链表,使用写锁 c.mutex.Lock() // 再次检查节点是否存在,防止在读写锁切换期间被删除 if node, stillExists := c.cache[key]; stillExists { c.moveNodeToTail(node) } c.mutex.Unlock() return response, true } // delete 删除缓存项 func (c *DNSCache) delete(key string) { c.mutex.Lock() defer c.mutex.Unlock() // 从缓存和链表中删除 if node, found := c.cache[key]; found { delete(c.cache, key) c.removeNode(node) } } // Clear 清空缓存 func (c *DNSCache) Clear() { c.mutex.Lock() c.cache = make(map[string]*LRUNode) // 重置链表指针 c.head = nil c.tail = nil c.mutex.Unlock() } // Size 获取缓存大小 func (c *DNSCache) Size() int { c.mutex.RLock() defer c.mutex.RUnlock() return len(c.cache) } // startCleanupLoop 启动定期清理过期缓存的协程 func (c *DNSCache) startCleanupLoop() { // 初始清理间隔为1分钟 cleanupInterval := time.Minute * 1 ticker := time.NewTicker(cleanupInterval) defer ticker.Stop() for range ticker.C { cleanupInterval = c.cleanupExpired() // 调整下次清理间隔,范围:15秒到5分钟 if cleanupInterval < 15*time.Second { cleanupInterval = 15 * time.Second } else if cleanupInterval > 5*time.Minute { cleanupInterval = 5 * time.Minute } // 更新清理间隔 ticker.Reset(cleanupInterval) } } // startSaveLoop 启动定期保存缓存的协程 func (c *DNSCache) startSaveLoop() { c.saveLoopMutex.Lock() // 如果已经在运行,直接返回 if c.saveRunning { c.saveLoopMutex.Unlock() return } // 重置停止通道 c.saveStopCh = make(chan struct{}) c.saveRunning = true c.saveLoopMutex.Unlock() go func() { ticker := time.NewTicker(c.saveInterval) // 根据配置的间隔保存 defer ticker.Stop() for { select { case <-ticker.C: // 检查缓存模式,如果不是file模式则不保存 c.mutex.RLock() mode := c.cacheMode c.mutex.RUnlock() if mode == "file" { c.SaveToFile() } case <-c.saveStopCh: // 停止保存循环 c.saveLoopMutex.Lock() c.saveRunning = false c.saveLoopMutex.Unlock() return } } }() } // saveCacheToFile 保存缓存到文件的底层实现,不检查缓存模式 func (c *DNSCache) saveCacheToFile() { c.saveMutex.Lock() defer c.saveMutex.Unlock() // 智能保存策略 // 1. 如果缓存变化次数少于10次,跳过保存 // 2. 如果距离上次保存时间不足最小保存间隔,跳过保存 c.mutex.RLock() changeCount := c.changeCount lastSaveTime := c.lastSaveTime lastSaveCacheSize := c.lastSaveCacheSize lastSaveItemCount := c.lastSaveItemCount currentCacheSize := c.cacheSize currentItemCount := len(c.cache) c.mutex.RUnlock() // 检查是否需要保存 if changeCount < 10 { return } if time.Since(lastSaveTime) < c.minSaveInterval { return } if currentItemCount > 0 { cacheSizeChange := float64(currentCacheSize-lastSaveCacheSize) / float64(lastSaveCacheSize+1) // +1避免除以零 itemCountChange := float64(currentItemCount-lastSaveItemCount) / float64(lastSaveItemCount+1) // +1避免除以零 if math.Abs(cacheSizeChange) < 0.1 && math.Abs(itemCountChange) < 0.1 { return } } // 开始保存缓存 c.mutex.RLock() // 收集有效的缓存项 validItems := make(map[string]*SerializableDNSCacheItem) now := time.Now() for key, node := range c.cache { // 只保存未过期的缓存项 if now.Before(node.value.Expiry) { // 序列化DNS响应为二进制 responseBytes, err := node.value.Response.Pack() if err != nil { continue // 跳过无法序列化的响应 } // 创建可序列化的缓存项 validItems[key] = &SerializableDNSCacheItem{ ResponseBytes: responseBytes, Expiry: node.value.Expiry.UnixNano(), HasDNSSEC: node.value.HasDNSSEC, Size: node.value.Size, } } } // 创建可序列化的缓存结构 serializableCache := &SerializableDNSCache{ Items: validItems, TTL: int64(c.ttl), MaxSize: c.maxSize, CacheMode: c.cacheMode, CacheFilePath: c.cacheFilePath, } c.mutex.RUnlock() // 序列化到JSON data, err := json.MarshalIndent(serializableCache, "", " ") if err != nil { return } // 确保目录存在 os.MkdirAll(cacheDir(), 0755) // 保存到文件 err = ioutil.WriteFile(c.cacheFilePath, data, 0644) if err != nil { return } // 更新保存状态 c.mutex.Lock() c.changeCount = 0 c.lastSaveTime = time.Now() c.lastSaveCacheSize = currentCacheSize c.lastSaveItemCount = currentItemCount c.mutex.Unlock() } // SaveToFile 保存缓存到文件 func (c *DNSCache) SaveToFile() { // 检查缓存模式,如果不是file模式,直接返回 c.mutex.RLock() mode := c.cacheMode c.mutex.RUnlock() if mode != "file" { return } // 调用底层保存逻辑 c.saveCacheToFile() } // LoadFromFile 从文件加载缓存 func (c *DNSCache) LoadFromFile() { c.mutex.Lock() defer c.mutex.Unlock() // 检查文件是否存在 if _, err := os.Stat(c.cacheFilePath); os.IsNotExist(err) { return // 文件不存在,跳过加载 } // 读取文件内容 data, err := ioutil.ReadFile(c.cacheFilePath) if err != nil { return } // 反序列化JSON var serializableCache SerializableDNSCache err = json.Unmarshal(data, &serializableCache) if err != nil { return } // 加载缓存项 now := time.Now() for key, serializableItem := range serializableCache.Items { // 转换过期时间 expiry := time.Unix(0, serializableItem.Expiry) // 只加载未过期的缓存项 if now.Before(expiry) { // 反序列化二进制DNS响应 response := &dns.Msg{} err := response.Unpack(serializableItem.ResponseBytes) if err != nil { continue // 跳过无法反序列化的响应 } // 创建缓存项 item := &DNSCacheItem{ Response: response, Expiry: expiry, HasDNSSEC: serializableItem.HasDNSSEC, Size: serializableItem.Size, } // 创建新的链表节点并添加到尾部 newNode := &LRUNode{ key: key, value: item, } c.addNodeToTail(newNode) c.cache[key] = newNode c.cacheSize += int64(item.Size) } } } // cacheDir 返回缓存目录 func cacheDir() string { return "data" } // SetMaxCacheTTL 设置最大缓存TTL func (c *DNSCache) SetMaxCacheTTL(ttl time.Duration) { c.mutex.Lock() defer c.mutex.Unlock() c.maxCacheTTL = ttl } // SetMinCacheTTL 设置最小缓存TTL func (c *DNSCache) SetMinCacheTTL(ttl time.Duration) { c.mutex.Lock() defer c.mutex.Unlock() c.minCacheTTL = ttl } // SetCacheMode 设置缓存模式 func (c *DNSCache) SetCacheMode(mode string) { c.mutex.Lock() oldMode := c.cacheMode c.mutex.Unlock() // 根据模式变化决定是否启动或停止保存循环 if oldMode != mode { if oldMode == "file" { // 从file模式切换到其他模式,先保存当前缓存到文件 // 直接调用底层保存逻辑,不检查缓存模式 c.saveCacheToFile() } c.mutex.Lock() c.cacheMode = mode c.mutex.Unlock() if mode == "file" { // 切换到file模式,启动保存循环 c.startSaveLoop() } else { // 切换到非file模式,停止保存循环 c.saveLoopMutex.Lock() if c.saveRunning { close(c.saveStopCh) c.saveRunning = false } c.saveLoopMutex.Unlock() } } } // SetMaxCacheSize 设置最大缓存大小 func (c *DNSCache) SetMaxCacheSize(size int64) { c.mutex.Lock() defer c.mutex.Unlock() c.maxCacheSize = size } // cleanupExpired 清理过期的缓存项,并返回下一次清理间隔的建议值 func (c *DNSCache) cleanupExpired() time.Duration { now := time.Now() c.mutex.Lock() defer c.mutex.Unlock() // 收集所有过期的键 var expiredKeys []string totalItems := len(c.cache) // 遍历缓存,收集过期项 for key, node := range c.cache { if now.After(node.value.Expiry) { expiredKeys = append(expiredKeys, key) } } expiredCount := len(expiredKeys) // 智能清理策略 // 1. 如果过期项比例超过50%,立即清理 // 2. 如果缓存大小超过最大缓存大小的80%,清理过期项 // 3. 如果缓存项数量超过最大条目数的80%,清理过期项 needCleanup := false if totalItems > 0 { if float64(expiredCount)/float64(totalItems) > 0.5 { needCleanup = true } else if c.cacheSize > c.maxCacheSize*8/10 { needCleanup = true } else if totalItems > c.maxSize*8/10 { needCleanup = true } } // 如果没有过期项或不需要清理,根据过期项比例返回建议的清理间隔 if expiredCount == 0 || !needCleanup { // 计算下一次清理间隔 var nextInterval time.Duration if totalItems == 0 { // 空缓存,下一次清理间隔可以长一些 nextInterval = 5 * time.Minute } else { expireRatio := float64(expiredCount) / float64(totalItems) // 过期项比例越高,清理间隔越短 if expireRatio < 0.1 { nextInterval = 5 * time.Minute } else if expireRatio < 0.3 { nextInterval = 2 * time.Minute } else { nextInterval = 1 * time.Minute } } return nextInterval } // 删除过期的缓存项 for _, key := range expiredKeys { if node, found := c.cache[key]; found { // 减去缓存项大小 c.cacheSize -= int64(node.value.Size) delete(c.cache, key) c.removeNode(node) } } // 清理后,如果缓存大小仍然超过最大缓存大小,继续清理最久未使用的项 if c.cacheSize > c.maxCacheSize { // 计算需要清理的额外大小 overflow := c.cacheSize - c.maxCacheSize cleanedSize := int64(0) // 从链表头开始清理(最久未使用的项) current := c.head for current != nil && cleanedSize < overflow { nextNode := current.next cleanedSize += int64(current.value.Size) // 删除节点 delete(c.cache, current.key) c.removeNode(current) current = nextNode } } // 清理后,如果缓存项数量仍然超过最大条目数,继续清理最久未使用的项 if len(c.cache) > c.maxSize { // 计算需要清理的额外数量 overflowCount := len(c.cache) - c.maxSize // 从链表头开始清理(最久未使用的项) current := c.head for i := 0; i < overflowCount && current != nil; i++ { nextNode := current.next // 删除节点 c.cacheSize -= int64(current.value.Size) delete(c.cache, current.key) c.removeNode(current) current = nextNode } } // 清理后,根据剩余过期项比例返回建议的清理间隔 // 重新计算剩余过期项 var remainingExpired int for _, node := range c.cache { if now.After(node.value.Expiry) { remainingExpired++ } } remainingItems := len(c.cache) var nextInterval time.Duration if remainingItems == 0 { nextInterval = 5 * time.Minute } else { remainingRatio := float64(remainingExpired) / float64(remainingItems) // 剩余过期项比例越高,清理间隔越短 if remainingRatio < 0.1 { nextInterval = 5 * time.Minute } else if remainingRatio < 0.3 { nextInterval = 2 * time.Minute } else { nextInterval = 1 * time.Minute } } return nextInterval }