package dns import ( "sync" "time" "github.com/miekg/dns" ) // DNSCacheItem 表示缓存中的DNS响应项 type DNSCacheItem struct { Response *dns.Msg // DNS响应消息 Expiry time.Time // 过期时间 HasDNSSEC bool // 是否包含DNSSEC记录 } // LRUNode 双向链表节点,用于LRU缓存 type LRUNode struct { key string value *DNSCacheItem prev *LRUNode next *LRUNode } // DNSCache DNS缓存结构 type DNSCache struct { cache map[string]*LRUNode // 缓存映射表,直接存储链表节点 mutex sync.RWMutex // 读写锁,保护缓存 defaultTTL time.Duration // 默认缓存TTL maxSize int // 最大缓存条目数 // 双向链表头和尾指针,用于LRU淘汰 head *LRUNode // 头指针,指向最久未使用的节点 tail *LRUNode // 尾指针,指向最近使用的节点 } // NewDNSCache 创建新的DNS缓存实例 func NewDNSCache(defaultTTL time.Duration) *DNSCache { cache := &DNSCache{ cache: make(map[string]*LRUNode), defaultTTL: defaultTTL, maxSize: 10000, // 默认最大缓存10000条记录 head: nil, tail: nil, } // 启动缓存清理协程 go cache.startCleanupLoop() return cache } // addNodeToTail 将节点添加到链表尾部(表示最近使用) func (c *DNSCache) addNodeToTail(node *LRUNode) { if c.tail == nil { // 链表为空 c.head = node c.tail = node } else { // 添加到尾部 node.prev = c.tail c.tail.next = node c.tail = node } } // removeNode 从链表中移除指定节点 func (c *DNSCache) removeNode(node *LRUNode) { if node.prev != nil { node.prev.next = node.next } else { // 移除的是头节点 c.head = node.next } if node.next != nil { node.next.prev = node.prev } else { // 移除的是尾节点 c.tail = node.prev } // 清空节点的前后指针 node.prev = nil node.next = nil } // moveNodeToTail 将节点移动到链表尾部(表示最近使用) func (c *DNSCache) moveNodeToTail(node *LRUNode) { // 如果已经是尾节点,不需要移动 if node == c.tail { return } // 从链表中移除节点 c.removeNode(node) // 重新添加到尾部 c.addNodeToTail(node) } // cacheKey 生成缓存键 func cacheKey(qName string, qType uint16) string { return qName + "|" + dns.TypeToString[qType] } // hasDNSSECRecords 检查响应是否包含DNSSEC记录 func hasDNSSECRecords(response *dns.Msg) bool { // 定义检查单个RR是否为DNSSEC记录的辅助函数 isDNSSECRecord := func(rr dns.RR) bool { switch rr.(type) { case *dns.DNSKEY, *dns.RRSIG, *dns.DS, *dns.NSEC, *dns.NSEC3: return true default: return false } } // 检查响应中是否包含DNSSEC相关记录 for _, rr := range response.Answer { if isDNSSECRecord(rr) { return true } } for _, rr := range response.Ns { if isDNSSECRecord(rr) { return true } } for _, rr := range response.Extra { if isDNSSECRecord(rr) { return true } } return false } // Set 设置缓存项 func (c *DNSCache) Set(qName string, qType uint16, response *dns.Msg, ttl time.Duration) { if ttl <= 0 { ttl = c.defaultTTL } key := cacheKey(qName, qType) item := &DNSCacheItem{ Response: response.Copy(), // 复制响应以避免外部修改 Expiry: time.Now().Add(ttl), HasDNSSEC: hasDNSSECRecords(response), // 检查并设置DNSSEC标志 } c.mutex.Lock() defer c.mutex.Unlock() // 如果条目已存在,先从链表和缓存中移除 if existingNode, found := c.cache[key]; found { c.removeNode(existingNode) delete(c.cache, key) } // 创建新的链表节点并添加到尾部 newNode := &LRUNode{ key: key, value: item, } c.addNodeToTail(newNode) c.cache[key] = newNode // 检查是否超过最大大小限制,如果超过则移除最久未使用的条目 if len(c.cache) > c.maxSize { // 最久未使用的条目是链表的头节点 if c.head != nil { oldestKey := c.head.key // 从缓存和链表中移除头节点 delete(c.cache, oldestKey) c.removeNode(c.head) } } } // Get 获取缓存项 func (c *DNSCache) Get(qName string, qType uint16) (*dns.Msg, bool) { key := cacheKey(qName, qType) // 首先使用读锁检查缓存项是否存在和是否过期 c.mutex.RLock() node, found := c.cache[key] if !found { c.mutex.RUnlock() return nil, false } // 检查是否过期 if time.Now().After(node.value.Expiry) { c.mutex.RUnlock() // 需要删除过期条目,使用写锁 c.mutex.Lock() // 再次检查,防止在读写锁切换期间被其他协程处理 if node, stillExists := c.cache[key]; stillExists && time.Now().After(node.value.Expiry) { delete(c.cache, key) c.removeNode(node) } c.mutex.Unlock() return nil, false } // 返回前释放读锁,避免长时间持有锁 response := node.value.Response.Copy() c.mutex.RUnlock() // 标记为最近使用需要修改链表,使用写锁 c.mutex.Lock() // 再次检查节点是否存在,防止在读写锁切换期间被删除 if node, stillExists := c.cache[key]; stillExists { c.moveNodeToTail(node) } c.mutex.Unlock() return response, true } // delete 删除缓存项 func (c *DNSCache) delete(key string) { c.mutex.Lock() defer c.mutex.Unlock() // 从缓存和链表中删除 if node, found := c.cache[key]; found { delete(c.cache, key) c.removeNode(node) } } // Clear 清空缓存 func (c *DNSCache) Clear() { c.mutex.Lock() c.cache = make(map[string]*LRUNode) // 重置链表指针 c.head = nil c.tail = nil c.mutex.Unlock() } // Size 获取缓存大小 func (c *DNSCache) Size() int { c.mutex.RLock() defer c.mutex.RUnlock() return len(c.cache) } // startCleanupLoop 启动定期清理过期缓存的协程 func (c *DNSCache) startCleanupLoop() { ticker := time.NewTicker(time.Minute * 1) // 每1分钟清理一次,减少内存占用 defer ticker.Stop() for range ticker.C { c.cleanupExpired() } } // cleanupExpired 清理过期的缓存项 func (c *DNSCache) cleanupExpired() { now := time.Now() c.mutex.Lock() defer c.mutex.Unlock() // 收集所有过期的键 var expiredKeys []string for key, node := range c.cache { if now.After(node.value.Expiry) { expiredKeys = append(expiredKeys, key) } } // 删除过期的缓存项 for _, key := range expiredKeys { if node, found := c.cache[key]; found { delete(c.cache, key) c.removeNode(node) } } }