Files
dns-server/dns/cache.go
2026-01-03 01:11:42 +08:00

254 lines
5.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package dns
import (
"sync"
"time"
"github.com/miekg/dns"
)
// DNSCacheItem 表示缓存中的DNS响应项
type DNSCacheItem struct {
Response *dns.Msg // DNS响应消息
Expiry time.Time // 过期时间
HasDNSSEC bool // 是否包含DNSSEC记录
}
// DNSCache DNS缓存结构
type DNSCache struct {
cache map[string]*DNSCacheItem // 缓存映射表
mutex sync.RWMutex // 读写锁,保护缓存
defaultTTL time.Duration // 默认缓存TTL
maxSize int // 最大缓存条目数
// 使用链表结构来跟踪缓存条目的访问顺序用于LRU淘汰
accessList []string // 记录访问顺序,最新访问的放在最后
}
// NewDNSCache 创建新的DNS缓存实例
func NewDNSCache(defaultTTL time.Duration) *DNSCache {
cache := &DNSCache{
cache: make(map[string]*DNSCacheItem),
defaultTTL: defaultTTL,
maxSize: 10000, // 默认最大缓存10000条记录
accessList: make([]string, 0, 10000),
}
// 启动缓存清理协程
go cache.startCleanupLoop()
return cache
}
// cacheKey 生成缓存键
func cacheKey(qName string, qType uint16) string {
return qName + "|" + dns.TypeToString[qType]
}
// hasDNSSECRecords 检查响应是否包含DNSSEC记录
func hasDNSSECRecords(response *dns.Msg) bool {
// 检查响应中是否包含DNSSEC相关记录DNSKEY、RRSIG、DS、NSEC、NSEC3等
for _, rr := range response.Answer {
if _, ok := rr.(*dns.DNSKEY); ok {
return true
}
if _, ok := rr.(*dns.RRSIG); ok {
return true
}
if _, ok := rr.(*dns.DS); ok {
return true
}
if _, ok := rr.(*dns.NSEC); ok {
return true
}
if _, ok := rr.(*dns.NSEC3); ok {
return true
}
}
for _, rr := range response.Ns {
if _, ok := rr.(*dns.DNSKEY); ok {
return true
}
if _, ok := rr.(*dns.RRSIG); ok {
return true
}
if _, ok := rr.(*dns.DS); ok {
return true
}
if _, ok := rr.(*dns.NSEC); ok {
return true
}
if _, ok := rr.(*dns.NSEC3); ok {
return true
}
}
for _, rr := range response.Extra {
if _, ok := rr.(*dns.DNSKEY); ok {
return true
}
if _, ok := rr.(*dns.RRSIG); ok {
return true
}
if _, ok := rr.(*dns.DS); ok {
return true
}
if _, ok := rr.(*dns.NSEC); ok {
return true
}
if _, ok := rr.(*dns.NSEC3); ok {
return true
}
}
return false
}
// Set 设置缓存项
func (c *DNSCache) Set(qName string, qType uint16, response *dns.Msg, ttl time.Duration) {
if ttl <= 0 {
ttl = c.defaultTTL
}
key := cacheKey(qName, qType)
item := &DNSCacheItem{
Response: response.Copy(), // 复制响应以避免外部修改
Expiry: time.Now().Add(ttl),
HasDNSSEC: hasDNSSECRecords(response), // 检查并设置DNSSEC标志
}
c.mutex.Lock()
defer c.mutex.Unlock()
// 如果条目已存在,先从访问列表中移除
for i, k := range c.accessList {
if k == key {
// 移除旧位置
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
break
}
}
// 将新条目添加到访问列表末尾
c.accessList = append(c.accessList, key)
c.cache[key] = item
// 检查是否超过最大大小限制,如果超过则移除最久未使用的条目
if len(c.cache) > c.maxSize {
// 最久未使用的条目是访问列表的第一个
oldestKey := c.accessList[0]
// 从缓存和访问列表中移除
delete(c.cache, oldestKey)
c.accessList = c.accessList[1:]
}
}
// Get 获取缓存项
func (c *DNSCache) Get(qName string, qType uint16) (*dns.Msg, bool) {
key := cacheKey(qName, qType)
c.mutex.Lock()
defer c.mutex.Unlock()
item, found := c.cache[key]
if !found {
return nil, false
}
// 检查是否过期
if time.Now().After(item.Expiry) {
// 过期了,删除缓存项
delete(c.cache, key)
// 从访问列表中移除
for i, k := range c.accessList {
if k == key {
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
break
}
}
return nil, false
}
// 将访问的条目移动到访问列表末尾(标记为最近使用)
for i, k := range c.accessList {
if k == key {
// 移除旧位置
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
// 添加到末尾
c.accessList = append(c.accessList, key)
break
}
}
// 返回缓存的响应副本
response := item.Response.Copy()
return response, true
}
// delete 删除缓存项
func (c *DNSCache) delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
// 从缓存中删除
delete(c.cache, key)
// 从访问列表中移除
for i, k := range c.accessList {
if k == key {
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
break
}
}
}
// Clear 清空缓存
func (c *DNSCache) Clear() {
c.mutex.Lock()
c.cache = make(map[string]*DNSCacheItem)
c.accessList = make([]string, 0, c.maxSize) // 重置访问列表
c.mutex.Unlock()
}
// Size 获取缓存大小
func (c *DNSCache) Size() int {
c.mutex.RLock()
defer c.mutex.RUnlock()
return len(c.cache)
}
// startCleanupLoop 启动定期清理过期缓存的协程
func (c *DNSCache) startCleanupLoop() {
ticker := time.NewTicker(time.Minute * 5) // 每5分钟清理一次
defer ticker.Stop()
for range ticker.C {
c.cleanupExpired()
}
}
// cleanupExpired 清理过期的缓存项
func (c *DNSCache) cleanupExpired() {
now := time.Now()
c.mutex.Lock()
defer c.mutex.Unlock()
// 收集所有过期的键
var expiredKeys []string
for key, item := range c.cache {
if now.After(item.Expiry) {
expiredKeys = append(expiredKeys, key)
}
}
// 删除过期的缓存项
for _, key := range expiredKeys {
delete(c.cache, key)
// 从访问列表中移除
for i, k := range c.accessList {
if k == key {
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
break
}
}
}
}