package dns import ( "sync" "time" "github.com/miekg/dns" ) // DNSCacheItem 表示缓存中的DNS响应项 type DNSCacheItem struct { Response *dns.Msg // DNS响应消息 Expiry time.Time // 过期时间 HasDNSSEC bool // 是否包含DNSSEC记录 } // DNSCache DNS缓存结构 type DNSCache struct { cache map[string]*DNSCacheItem // 缓存映射表 mutex sync.RWMutex // 读写锁,保护缓存 defaultTTL time.Duration // 默认缓存TTL maxSize int // 最大缓存条目数 // 使用链表结构来跟踪缓存条目的访问顺序,用于LRU淘汰 accessList []string // 记录访问顺序,最新访问的放在最后 } // NewDNSCache 创建新的DNS缓存实例 func NewDNSCache(defaultTTL time.Duration) *DNSCache { cache := &DNSCache{ cache: make(map[string]*DNSCacheItem), defaultTTL: defaultTTL, maxSize: 10000, // 默认最大缓存10000条记录 accessList: make([]string, 0, 10000), } // 启动缓存清理协程 go cache.startCleanupLoop() return cache } // cacheKey 生成缓存键 func cacheKey(qName string, qType uint16) string { return qName + "|" + dns.TypeToString[qType] } // hasDNSSECRecords 检查响应是否包含DNSSEC记录 func hasDNSSECRecords(response *dns.Msg) bool { // 检查响应中是否包含DNSSEC相关记录(DNSKEY、RRSIG、DS、NSEC、NSEC3等) for _, rr := range response.Answer { if _, ok := rr.(*dns.DNSKEY); ok { return true } if _, ok := rr.(*dns.RRSIG); ok { return true } if _, ok := rr.(*dns.DS); ok { return true } if _, ok := rr.(*dns.NSEC); ok { return true } if _, ok := rr.(*dns.NSEC3); ok { return true } } for _, rr := range response.Ns { if _, ok := rr.(*dns.DNSKEY); ok { return true } if _, ok := rr.(*dns.RRSIG); ok { return true } if _, ok := rr.(*dns.DS); ok { return true } if _, ok := rr.(*dns.NSEC); ok { return true } if _, ok := rr.(*dns.NSEC3); ok { return true } } for _, rr := range response.Extra { if _, ok := rr.(*dns.DNSKEY); ok { return true } if _, ok := rr.(*dns.RRSIG); ok { return true } if _, ok := rr.(*dns.DS); ok { return true } if _, ok := rr.(*dns.NSEC); ok { return true } if _, ok := rr.(*dns.NSEC3); ok { return true } } return false } // Set 设置缓存项 func (c *DNSCache) Set(qName string, qType uint16, response *dns.Msg, ttl time.Duration) { if ttl <= 0 { ttl = c.defaultTTL } key := cacheKey(qName, qType) item := &DNSCacheItem{ Response: response.Copy(), // 复制响应以避免外部修改 Expiry: time.Now().Add(ttl), HasDNSSEC: hasDNSSECRecords(response), // 检查并设置DNSSEC标志 } c.mutex.Lock() defer c.mutex.Unlock() // 如果条目已存在,先从访问列表中移除 for i, k := range c.accessList { if k == key { // 移除旧位置 c.accessList = append(c.accessList[:i], c.accessList[i+1:]...) break } } // 将新条目添加到访问列表末尾 c.accessList = append(c.accessList, key) c.cache[key] = item // 检查是否超过最大大小限制,如果超过则移除最久未使用的条目 if len(c.cache) > c.maxSize { // 最久未使用的条目是访问列表的第一个 oldestKey := c.accessList[0] // 从缓存和访问列表中移除 delete(c.cache, oldestKey) c.accessList = c.accessList[1:] } } // Get 获取缓存项 func (c *DNSCache) Get(qName string, qType uint16) (*dns.Msg, bool) { key := cacheKey(qName, qType) c.mutex.Lock() defer c.mutex.Unlock() item, found := c.cache[key] if !found { return nil, false } // 检查是否过期 if time.Now().After(item.Expiry) { // 过期了,删除缓存项 delete(c.cache, key) // 从访问列表中移除 for i, k := range c.accessList { if k == key { c.accessList = append(c.accessList[:i], c.accessList[i+1:]...) break } } return nil, false } // 将访问的条目移动到访问列表末尾(标记为最近使用) for i, k := range c.accessList { if k == key { // 移除旧位置 c.accessList = append(c.accessList[:i], c.accessList[i+1:]...) // 添加到末尾 c.accessList = append(c.accessList, key) break } } // 返回缓存的响应副本 response := item.Response.Copy() return response, true } // delete 删除缓存项 func (c *DNSCache) delete(key string) { c.mutex.Lock() defer c.mutex.Unlock() // 从缓存中删除 delete(c.cache, key) // 从访问列表中移除 for i, k := range c.accessList { if k == key { c.accessList = append(c.accessList[:i], c.accessList[i+1:]...) break } } } // Clear 清空缓存 func (c *DNSCache) Clear() { c.mutex.Lock() c.cache = make(map[string]*DNSCacheItem) c.accessList = make([]string, 0, c.maxSize) // 重置访问列表 c.mutex.Unlock() } // Size 获取缓存大小 func (c *DNSCache) Size() int { c.mutex.RLock() defer c.mutex.RUnlock() return len(c.cache) } // startCleanupLoop 启动定期清理过期缓存的协程 func (c *DNSCache) startCleanupLoop() { ticker := time.NewTicker(time.Minute * 5) // 每5分钟清理一次 defer ticker.Stop() for range ticker.C { c.cleanupExpired() } } // cleanupExpired 清理过期的缓存项 func (c *DNSCache) cleanupExpired() { now := time.Now() c.mutex.Lock() defer c.mutex.Unlock() // 收集所有过期的键 var expiredKeys []string for key, item := range c.cache { if now.After(item.Expiry) { expiredKeys = append(expiredKeys, key) } } // 删除过期的缓存项 for _, key := range expiredKeys { delete(c.cache, key) // 从访问列表中移除 for i, k := range c.accessList { if k == key { c.accessList = append(c.accessList[:i], c.accessList[i+1:]...) break } } } }