Files
dns-server/dns/cache.go
2026-01-17 01:18:03 +08:00

592 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package dns
import (
"encoding/json"
"io/ioutil"
"os"
"sync"
"time"
"github.com/miekg/dns"
)
// DNSCacheItem 表示缓存中的DNS响应项
type DNSCacheItem struct {
Response *dns.Msg // DNS响应消息
Expiry time.Time // 过期时间
HasDNSSEC bool // 是否包含DNSSEC记录
Size int // 缓存项大小(字节)
}
// SerializableDNSCacheItem 用于JSON序列化的缓存项结构
type SerializableDNSCacheItem struct {
ResponseBytes []byte `json:"responseBytes"` // 二进制DNS响应
Expiry int64 `json:"expiry"` // 过期时间(纳秒)
HasDNSSEC bool `json:"hasDNSSEC"` // 是否包含DNSSEC记录
Size int `json:"size"` // 缓存项大小(字节)
}
// SerializableDNSCache 可序列化的缓存结构
type SerializableDNSCache struct {
Items map[string]*SerializableDNSCacheItem `json:"items"` // 缓存项
TTL int64 `json:"ttl"` // 默认TTL纳秒
MaxSize int `json:"maxSize"` // 最大缓存大小
CacheMode string `json:"cacheMode"` // 缓存模式
CacheFilePath string `json:"cacheFilePath"` // 缓存文件路径
}
// DNSCache DNS缓存结构
type DNSCache struct {
cache map[string]*LRUNode // 缓存映射表,直接存储链表节点
mutex sync.RWMutex // 读写锁,保护缓存
ttl time.Duration // 默认缓存TTL
maxSize int // 最大缓存条目数
cacheSize int64 // 当前缓存大小(字节)
maxCacheSize int64 // 最大缓存大小(字节)
cacheMode string // 缓存模式
cacheFilePath string // 缓存文件路径
saveInterval time.Duration // 保存间隔
saveMutex sync.Mutex // 保存互斥锁
maxCacheTTL time.Duration // 最大缓存TTL
minCacheTTL time.Duration // 最小缓存TTL
saveStopCh chan struct{} // 保存循环停止通道
saveRunning bool // 保存循环是否运行
saveLoopMutex sync.Mutex // 保护保存循环状态的互斥锁
// 双向链表头和尾指针用于LRU淘汰
head *LRUNode // 头指针,指向最久未使用的节点
tail *LRUNode // 尾指针,指向最近使用的节点
}
// LRUNode 双向链表节点用于LRU缓存
type LRUNode struct {
key string
value *DNSCacheItem
prev *LRUNode
next *LRUNode
}
// NewDNSCache 创建新的DNS缓存实例
func NewDNSCache(defaultTTL time.Duration, cacheMode string, cacheSizeMB int, cacheFilePath string, saveInterval time.Duration, maxCacheTTL, minCacheTTL time.Duration) *DNSCache {
// 计算最大缓存大小(字节)
maxCacheSize := int64(cacheSizeMB) * 1024 * 1024
cache := &DNSCache{
cache: make(map[string]*LRUNode),
ttl: defaultTTL,
maxSize: 10000, // 默认最大缓存10000条记录
cacheSize: 0,
maxCacheSize: maxCacheSize,
cacheMode: cacheMode,
cacheFilePath: cacheFilePath,
saveInterval: saveInterval,
maxCacheTTL: maxCacheTTL,
minCacheTTL: minCacheTTL,
saveStopCh: make(chan struct{}),
saveRunning: false,
head: nil,
tail: nil,
}
// 加载现有缓存(如果存在)
if cacheMode == "file" {
cache.LoadFromFile()
}
// 启动缓存清理协程
go cache.startCleanupLoop()
// 启动定期保存协程(如果是文件缓存)
if cacheMode == "file" {
go cache.startSaveLoop()
}
return cache
}
// addNodeToTail 将节点添加到链表尾部(表示最近使用)
func (c *DNSCache) addNodeToTail(node *LRUNode) {
if c.tail == nil {
// 链表为空
c.head = node
c.tail = node
} else {
// 添加到尾部
node.prev = c.tail
c.tail.next = node
c.tail = node
}
}
// removeNode 从链表中移除指定节点
func (c *DNSCache) removeNode(node *LRUNode) {
if node.prev != nil {
node.prev.next = node.next
} else {
// 移除的是头节点
c.head = node.next
}
if node.next != nil {
node.next.prev = node.prev
} else {
// 移除的是尾节点
c.tail = node.prev
}
// 清空节点的前后指针
node.prev = nil
node.next = nil
}
// moveNodeToTail 将节点移动到链表尾部(表示最近使用)
func (c *DNSCache) moveNodeToTail(node *LRUNode) {
// 如果已经是尾节点,不需要移动
if node == c.tail {
return
}
// 从链表中移除节点
c.removeNode(node)
// 重新添加到尾部
c.addNodeToTail(node)
}
// cacheKey 生成缓存键
func cacheKey(qName string, qType uint16) string {
return qName + "|" + dns.TypeToString[qType]
}
// calculateItemSize 计算缓存项大小
func calculateItemSize(item *DNSCacheItem) int {
// 序列化响应以计算大小
data, err := json.Marshal(item)
if err != nil {
return 0
}
return len(data)
}
// hasDNSSECRecords 检查响应是否包含DNSSEC记录
func hasDNSSECRecords(response *dns.Msg) bool {
// 定义检查单个RR是否为DNSSEC记录的辅助函数
isDNSSECRecord := func(rr dns.RR) bool {
switch rr.(type) {
case *dns.DNSKEY, *dns.RRSIG, *dns.DS, *dns.NSEC, *dns.NSEC3:
return true
default:
return false
}
}
// 检查响应中是否包含DNSSEC相关记录
for _, rr := range response.Answer {
if isDNSSECRecord(rr) {
return true
}
}
for _, rr := range response.Ns {
if isDNSSECRecord(rr) {
return true
}
}
for _, rr := range response.Extra {
if isDNSSECRecord(rr) {
return true
}
}
return false
}
// Set 设置缓存项
func (c *DNSCache) Set(qName string, qType uint16, response *dns.Msg, ttl time.Duration) {
// 设置默认TTL
if ttl <= 0 {
ttl = c.ttl
}
// 应用maxCacheTTL和minCacheTTL约束
if c.maxCacheTTL > 0 && ttl > c.maxCacheTTL {
ttl = c.maxCacheTTL
}
if c.minCacheTTL > 0 && ttl < c.minCacheTTL {
ttl = c.minCacheTTL
}
key := cacheKey(qName, qType)
item := &DNSCacheItem{
Response: response.Copy(), // 复制响应以避免外部修改
Expiry: time.Now().Add(ttl),
HasDNSSEC: hasDNSSECRecords(response), // 检查并设置DNSSEC标志
}
// 计算缓存项大小
item.Size = calculateItemSize(item)
c.mutex.Lock()
defer c.mutex.Unlock()
// 如果条目已存在,先从链表和缓存中移除,并更新缓存大小
if existingNode, found := c.cache[key]; found {
c.cacheSize -= int64(existingNode.value.Size)
c.removeNode(existingNode)
delete(c.cache, key)
}
// 创建新的链表节点并添加到尾部
newNode := &LRUNode{
key: key,
value: item,
}
c.addNodeToTail(newNode)
c.cache[key] = newNode
c.cacheSize += int64(item.Size)
// 检查是否超过最大条目数限制,如果超过则移除最久未使用的条目
if len(c.cache) > c.maxSize {
// 最久未使用的条目是链表的头节点
if c.head != nil {
c.cacheSize -= int64(c.head.value.Size)
oldestKey := c.head.key
// 从缓存和链表中移除头节点
delete(c.cache, oldestKey)
c.removeNode(c.head)
}
}
// 检查是否超过最大缓存大小,如果超过则继续移除最久未使用的条目
for c.cacheSize > c.maxCacheSize && c.head != nil {
c.cacheSize -= int64(c.head.value.Size)
oldestKey := c.head.key
delete(c.cache, oldestKey)
c.removeNode(c.head)
}
}
// Get 获取缓存项
func (c *DNSCache) Get(qName string, qType uint16) (*dns.Msg, bool) {
key := cacheKey(qName, qType)
// 首先使用读锁检查缓存项是否存在和是否过期
c.mutex.RLock()
node, found := c.cache[key]
if !found {
c.mutex.RUnlock()
return nil, false
}
// 检查是否过期
if time.Now().After(node.value.Expiry) {
c.mutex.RUnlock()
// 需要删除过期条目,使用写锁
c.mutex.Lock()
// 再次检查,防止在读写锁切换期间被其他协程处理
if node, stillExists := c.cache[key]; stillExists && time.Now().After(node.value.Expiry) {
delete(c.cache, key)
c.removeNode(node)
}
c.mutex.Unlock()
return nil, false
}
// 返回前释放读锁,避免长时间持有锁
response := node.value.Response.Copy()
c.mutex.RUnlock()
// 标记为最近使用需要修改链表,使用写锁
c.mutex.Lock()
// 再次检查节点是否存在,防止在读写锁切换期间被删除
if node, stillExists := c.cache[key]; stillExists {
c.moveNodeToTail(node)
}
c.mutex.Unlock()
return response, true
}
// delete 删除缓存项
func (c *DNSCache) delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
// 从缓存和链表中删除
if node, found := c.cache[key]; found {
delete(c.cache, key)
c.removeNode(node)
}
}
// Clear 清空缓存
func (c *DNSCache) Clear() {
c.mutex.Lock()
c.cache = make(map[string]*LRUNode)
// 重置链表指针
c.head = nil
c.tail = nil
c.mutex.Unlock()
}
// Size 获取缓存大小
func (c *DNSCache) Size() int {
c.mutex.RLock()
defer c.mutex.RUnlock()
return len(c.cache)
}
// startCleanupLoop 启动定期清理过期缓存的协程
func (c *DNSCache) startCleanupLoop() {
ticker := time.NewTicker(time.Minute * 1) // 每1分钟清理一次减少内存占用
defer ticker.Stop()
for range ticker.C {
c.cleanupExpired()
}
}
// startSaveLoop 启动定期保存缓存的协程
func (c *DNSCache) startSaveLoop() {
c.saveLoopMutex.Lock()
// 如果已经在运行,直接返回
if c.saveRunning {
c.saveLoopMutex.Unlock()
return
}
// 重置停止通道
c.saveStopCh = make(chan struct{})
c.saveRunning = true
c.saveLoopMutex.Unlock()
go func() {
ticker := time.NewTicker(c.saveInterval) // 根据配置的间隔保存
defer ticker.Stop()
for {
select {
case <-ticker.C:
// 检查缓存模式如果不是file模式则不保存
c.mutex.RLock()
mode := c.cacheMode
c.mutex.RUnlock()
if mode == "file" {
c.SaveToFile()
}
case <-c.saveStopCh:
// 停止保存循环
c.saveLoopMutex.Lock()
c.saveRunning = false
c.saveLoopMutex.Unlock()
return
}
}
}()
}
// saveCacheToFile 保存缓存到文件的底层实现,不检查缓存模式
func (c *DNSCache) saveCacheToFile() {
c.saveMutex.Lock()
defer c.saveMutex.Unlock()
c.mutex.RLock()
defer c.mutex.RUnlock()
// 收集有效的缓存项
validItems := make(map[string]*SerializableDNSCacheItem)
now := time.Now()
for key, node := range c.cache {
// 只保存未过期的缓存项
if now.Before(node.value.Expiry) {
// 序列化DNS响应为二进制
responseBytes, err := node.value.Response.Pack()
if err != nil {
continue // 跳过无法序列化的响应
}
// 创建可序列化的缓存项
validItems[key] = &SerializableDNSCacheItem{
ResponseBytes: responseBytes,
Expiry: node.value.Expiry.UnixNano(),
HasDNSSEC: node.value.HasDNSSEC,
Size: node.value.Size,
}
}
}
// 创建可序列化的缓存结构
serializableCache := &SerializableDNSCache{
Items: validItems,
TTL: int64(c.ttl),
MaxSize: c.maxSize,
CacheMode: c.cacheMode,
CacheFilePath: c.cacheFilePath,
}
// 序列化到JSON
data, err := json.MarshalIndent(serializableCache, "", " ")
if err != nil {
return
}
// 确保目录存在
os.MkdirAll(cacheDir(), 0755)
// 保存到文件
err = ioutil.WriteFile(c.cacheFilePath, data, 0644)
if err != nil {
return
}
}
// SaveToFile 保存缓存到文件
func (c *DNSCache) SaveToFile() {
// 检查缓存模式如果不是file模式直接返回
c.mutex.RLock()
mode := c.cacheMode
c.mutex.RUnlock()
if mode != "file" {
return
}
// 调用底层保存逻辑
c.saveCacheToFile()
}
// LoadFromFile 从文件加载缓存
func (c *DNSCache) LoadFromFile() {
c.mutex.Lock()
defer c.mutex.Unlock()
// 检查文件是否存在
if _, err := os.Stat(c.cacheFilePath); os.IsNotExist(err) {
return // 文件不存在,跳过加载
}
// 读取文件内容
data, err := ioutil.ReadFile(c.cacheFilePath)
if err != nil {
return
}
// 反序列化JSON
var serializableCache SerializableDNSCache
err = json.Unmarshal(data, &serializableCache)
if err != nil {
return
}
// 加载缓存项
now := time.Now()
for key, serializableItem := range serializableCache.Items {
// 转换过期时间
expiry := time.Unix(0, serializableItem.Expiry)
// 只加载未过期的缓存项
if now.Before(expiry) {
// 反序列化二进制DNS响应
response := &dns.Msg{}
err := response.Unpack(serializableItem.ResponseBytes)
if err != nil {
continue // 跳过无法反序列化的响应
}
// 创建缓存项
item := &DNSCacheItem{
Response: response,
Expiry: expiry,
HasDNSSEC: serializableItem.HasDNSSEC,
Size: serializableItem.Size,
}
// 创建新的链表节点并添加到尾部
newNode := &LRUNode{
key: key,
value: item,
}
c.addNodeToTail(newNode)
c.cache[key] = newNode
c.cacheSize += int64(item.Size)
}
}
}
// cacheDir 返回缓存目录
func cacheDir() string {
return "data"
}
// SetMaxCacheTTL 设置最大缓存TTL
func (c *DNSCache) SetMaxCacheTTL(ttl time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.maxCacheTTL = ttl
}
// SetMinCacheTTL 设置最小缓存TTL
func (c *DNSCache) SetMinCacheTTL(ttl time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.minCacheTTL = ttl
}
// SetCacheMode 设置缓存模式
func (c *DNSCache) SetCacheMode(mode string) {
c.mutex.Lock()
oldMode := c.cacheMode
c.mutex.Unlock()
// 根据模式变化决定是否启动或停止保存循环
if oldMode != mode {
if oldMode == "file" {
// 从file模式切换到其他模式先保存当前缓存到文件
// 直接调用底层保存逻辑,不检查缓存模式
c.saveCacheToFile()
}
c.mutex.Lock()
c.cacheMode = mode
c.mutex.Unlock()
if mode == "file" {
// 切换到file模式启动保存循环
c.startSaveLoop()
} else {
// 切换到非file模式停止保存循环
c.saveLoopMutex.Lock()
if c.saveRunning {
close(c.saveStopCh)
c.saveRunning = false
}
c.saveLoopMutex.Unlock()
}
}
}
// SetMaxCacheSize 设置最大缓存大小
func (c *DNSCache) SetMaxCacheSize(size int64) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.maxCacheSize = size
}
// cleanupExpired 清理过期的缓存项
func (c *DNSCache) cleanupExpired() {
now := time.Now()
c.mutex.Lock()
defer c.mutex.Unlock()
// 收集所有过期的键
var expiredKeys []string
for key, node := range c.cache {
if now.After(node.value.Expiry) {
expiredKeys = append(expiredKeys, key)
}
}
// 删除过期的缓存项
for _, key := range expiredKeys {
if node, found := c.cache[key]; found {
delete(c.cache, key)
c.removeNode(node)
}
}
}