254 lines
5.7 KiB
Go
254 lines
5.7 KiB
Go
package dns
|
||
|
||
import (
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/miekg/dns"
|
||
)
|
||
|
||
// DNSCacheItem 表示缓存中的DNS响应项
|
||
type DNSCacheItem struct {
|
||
Response *dns.Msg // DNS响应消息
|
||
Expiry time.Time // 过期时间
|
||
HasDNSSEC bool // 是否包含DNSSEC记录
|
||
}
|
||
|
||
// DNSCache DNS缓存结构
|
||
type DNSCache struct {
|
||
cache map[string]*DNSCacheItem // 缓存映射表
|
||
mutex sync.RWMutex // 读写锁,保护缓存
|
||
defaultTTL time.Duration // 默认缓存TTL
|
||
maxSize int // 最大缓存条目数
|
||
// 使用链表结构来跟踪缓存条目的访问顺序,用于LRU淘汰
|
||
accessList []string // 记录访问顺序,最新访问的放在最后
|
||
}
|
||
|
||
// NewDNSCache 创建新的DNS缓存实例
|
||
func NewDNSCache(defaultTTL time.Duration) *DNSCache {
|
||
cache := &DNSCache{
|
||
cache: make(map[string]*DNSCacheItem),
|
||
defaultTTL: defaultTTL,
|
||
maxSize: 10000, // 默认最大缓存10000条记录
|
||
accessList: make([]string, 0, 10000),
|
||
}
|
||
|
||
// 启动缓存清理协程
|
||
go cache.startCleanupLoop()
|
||
|
||
return cache
|
||
}
|
||
|
||
// cacheKey 生成缓存键
|
||
func cacheKey(qName string, qType uint16) string {
|
||
return qName + "|" + dns.TypeToString[qType]
|
||
}
|
||
|
||
// hasDNSSECRecords 检查响应是否包含DNSSEC记录
|
||
func hasDNSSECRecords(response *dns.Msg) bool {
|
||
// 检查响应中是否包含DNSSEC相关记录(DNSKEY、RRSIG、DS、NSEC、NSEC3等)
|
||
for _, rr := range response.Answer {
|
||
if _, ok := rr.(*dns.DNSKEY); ok {
|
||
return true
|
||
}
|
||
if _, ok := rr.(*dns.RRSIG); ok {
|
||
return true
|
||
}
|
||
if _, ok := rr.(*dns.DS); ok {
|
||
return true
|
||
}
|
||
if _, ok := rr.(*dns.NSEC); ok {
|
||
return true
|
||
}
|
||
if _, ok := rr.(*dns.NSEC3); ok {
|
||
return true
|
||
}
|
||
}
|
||
for _, rr := range response.Ns {
|
||
if _, ok := rr.(*dns.DNSKEY); ok {
|
||
return true
|
||
}
|
||
if _, ok := rr.(*dns.RRSIG); ok {
|
||
return true
|
||
}
|
||
if _, ok := rr.(*dns.DS); ok {
|
||
return true
|
||
}
|
||
if _, ok := rr.(*dns.NSEC); ok {
|
||
return true
|
||
}
|
||
if _, ok := rr.(*dns.NSEC3); ok {
|
||
return true
|
||
}
|
||
}
|
||
for _, rr := range response.Extra {
|
||
if _, ok := rr.(*dns.DNSKEY); ok {
|
||
return true
|
||
}
|
||
if _, ok := rr.(*dns.RRSIG); ok {
|
||
return true
|
||
}
|
||
if _, ok := rr.(*dns.DS); ok {
|
||
return true
|
||
}
|
||
if _, ok := rr.(*dns.NSEC); ok {
|
||
return true
|
||
}
|
||
if _, ok := rr.(*dns.NSEC3); ok {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// Set 设置缓存项
|
||
func (c *DNSCache) Set(qName string, qType uint16, response *dns.Msg, ttl time.Duration) {
|
||
if ttl <= 0 {
|
||
ttl = c.defaultTTL
|
||
}
|
||
|
||
key := cacheKey(qName, qType)
|
||
item := &DNSCacheItem{
|
||
Response: response.Copy(), // 复制响应以避免外部修改
|
||
Expiry: time.Now().Add(ttl),
|
||
HasDNSSEC: hasDNSSECRecords(response), // 检查并设置DNSSEC标志
|
||
}
|
||
|
||
c.mutex.Lock()
|
||
defer c.mutex.Unlock()
|
||
|
||
// 如果条目已存在,先从访问列表中移除
|
||
for i, k := range c.accessList {
|
||
if k == key {
|
||
// 移除旧位置
|
||
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
|
||
break
|
||
}
|
||
}
|
||
|
||
// 将新条目添加到访问列表末尾
|
||
c.accessList = append(c.accessList, key)
|
||
c.cache[key] = item
|
||
|
||
// 检查是否超过最大大小限制,如果超过则移除最久未使用的条目
|
||
if len(c.cache) > c.maxSize {
|
||
// 最久未使用的条目是访问列表的第一个
|
||
oldestKey := c.accessList[0]
|
||
// 从缓存和访问列表中移除
|
||
delete(c.cache, oldestKey)
|
||
c.accessList = c.accessList[1:]
|
||
}
|
||
}
|
||
|
||
// Get 获取缓存项
|
||
func (c *DNSCache) Get(qName string, qType uint16) (*dns.Msg, bool) {
|
||
key := cacheKey(qName, qType)
|
||
|
||
c.mutex.Lock()
|
||
defer c.mutex.Unlock()
|
||
|
||
item, found := c.cache[key]
|
||
if !found {
|
||
return nil, false
|
||
}
|
||
|
||
// 检查是否过期
|
||
if time.Now().After(item.Expiry) {
|
||
// 过期了,删除缓存项
|
||
delete(c.cache, key)
|
||
// 从访问列表中移除
|
||
for i, k := range c.accessList {
|
||
if k == key {
|
||
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
|
||
break
|
||
}
|
||
}
|
||
return nil, false
|
||
}
|
||
|
||
// 将访问的条目移动到访问列表末尾(标记为最近使用)
|
||
for i, k := range c.accessList {
|
||
if k == key {
|
||
// 移除旧位置
|
||
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
|
||
// 添加到末尾
|
||
c.accessList = append(c.accessList, key)
|
||
break
|
||
}
|
||
}
|
||
|
||
// 返回缓存的响应副本
|
||
response := item.Response.Copy()
|
||
|
||
return response, true
|
||
}
|
||
|
||
// delete 删除缓存项
|
||
func (c *DNSCache) delete(key string) {
|
||
c.mutex.Lock()
|
||
defer c.mutex.Unlock()
|
||
|
||
// 从缓存中删除
|
||
delete(c.cache, key)
|
||
// 从访问列表中移除
|
||
for i, k := range c.accessList {
|
||
if k == key {
|
||
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// Clear 清空缓存
|
||
func (c *DNSCache) Clear() {
|
||
c.mutex.Lock()
|
||
c.cache = make(map[string]*DNSCacheItem)
|
||
c.accessList = make([]string, 0, c.maxSize) // 重置访问列表
|
||
c.mutex.Unlock()
|
||
}
|
||
|
||
// Size 获取缓存大小
|
||
func (c *DNSCache) Size() int {
|
||
c.mutex.RLock()
|
||
defer c.mutex.RUnlock()
|
||
return len(c.cache)
|
||
}
|
||
|
||
// startCleanupLoop 启动定期清理过期缓存的协程
|
||
func (c *DNSCache) startCleanupLoop() {
|
||
ticker := time.NewTicker(time.Minute * 5) // 每5分钟清理一次
|
||
defer ticker.Stop()
|
||
|
||
for range ticker.C {
|
||
c.cleanupExpired()
|
||
}
|
||
}
|
||
|
||
// cleanupExpired 清理过期的缓存项
|
||
func (c *DNSCache) cleanupExpired() {
|
||
now := time.Now()
|
||
|
||
c.mutex.Lock()
|
||
defer c.mutex.Unlock()
|
||
|
||
// 收集所有过期的键
|
||
var expiredKeys []string
|
||
for key, item := range c.cache {
|
||
if now.After(item.Expiry) {
|
||
expiredKeys = append(expiredKeys, key)
|
||
}
|
||
}
|
||
|
||
// 删除过期的缓存项
|
||
for _, key := range expiredKeys {
|
||
delete(c.cache, key)
|
||
// 从访问列表中移除
|
||
for i, k := range c.accessList {
|
||
if k == key {
|
||
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
|
||
break
|
||
}
|
||
}
|
||
}
|
||
}
|