增加数据持久化功能

This commit is contained in:
Alex Yang
2025-11-23 19:07:59 +08:00
parent 63a95f7463
commit f5911af449
8 changed files with 519 additions and 16 deletions

View File

@@ -2,8 +2,12 @@ package dns
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"os"
"path/filepath"
"sort"
"sync"
"time"
@@ -23,6 +27,15 @@ type BlockedDomain struct {
LastSeen time.Time
}
// StatsData 用于持久化的统计数据结构
type StatsData struct {
Stats *Stats `json:"stats"`
BlockedDomains map[string]*BlockedDomain `json:"blockedDomains"`
ResolvedDomains map[string]*BlockedDomain `json:"resolvedDomains"`
HourlyStats map[string]int64 `json:"hourlyStats"`
LastSaved time.Time `json:"lastSaved"`
}
// Server DNS服务器
type Server struct {
config *config.DNSConfig
@@ -40,6 +53,8 @@ type Server struct {
resolvedDomains map[string]*BlockedDomain // 用于记录解析的域名
hourlyStatsMutex sync.RWMutex
hourlyStats map[string]int64 // 按小时统计屏蔽数量
saveTicker *time.Ticker // 用于定时保存数据
saveDone chan struct{} // 用于通知保存协程停止
}
// Stats DNS服务器统计信息
@@ -54,7 +69,7 @@ type Stats struct {
// NewServer 创建DNS服务器实例
func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager) *Server {
ctx, cancel := context.WithCancel(context.Background())
return &Server{
server := &Server{
config: config,
shieldConfig: shieldConfig,
shieldManager: shieldManager,
@@ -73,7 +88,14 @@ func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shie
blockedDomains: make(map[string]*BlockedDomain),
resolvedDomains: make(map[string]*BlockedDomain),
hourlyStats: make(map[string]int64),
saveDone: make(chan struct{}),
}
// 加载已保存的统计数据
server.loadStatsData()
return server
}
// Start 启动DNS服务器
@@ -116,10 +138,17 @@ func (s *Server) Start() error {
// Stop 停止DNS服务器
func (s *Server) Stop() {
// 发送停止信号给保存协程
close(s.saveDone)
// 最后保存一次数据
s.saveStatsData()
// 停止服务器
s.cancel()
if s.server != nil {
s.server.Shutdown()
}
s.cancel()
logger.Info("DNS服务器已停止")
}
@@ -452,3 +481,135 @@ func (s *Server) GetHourlyStats() map[string]int64 {
}
return result
}
// loadStatsData 从文件加载统计数据
func (s *Server) loadStatsData() {
if s.config.StatsFile == "" {
return
}
// 检查文件是否存在
data, err := ioutil.ReadFile(s.config.StatsFile)
if err != nil {
if !os.IsNotExist(err) {
logger.Error("读取统计数据文件失败", "error", err)
}
return
}
var statsData StatsData
err = json.Unmarshal(data, &statsData)
if err != nil {
logger.Error("解析统计数据失败", "error", err)
return
}
// 恢复统计数据
s.statsMutex.Lock()
if statsData.Stats != nil {
s.stats = statsData.Stats
}
s.statsMutex.Unlock()
s.blockedDomainsMutex.Lock()
if statsData.BlockedDomains != nil {
s.blockedDomains = statsData.BlockedDomains
}
s.blockedDomainsMutex.Unlock()
s.resolvedDomainsMutex.Lock()
if statsData.ResolvedDomains != nil {
s.resolvedDomains = statsData.ResolvedDomains
}
s.resolvedDomainsMutex.Unlock()
s.hourlyStatsMutex.Lock()
if statsData.HourlyStats != nil {
s.hourlyStats = statsData.HourlyStats
}
s.hourlyStatsMutex.Unlock()
logger.Info("统计数据加载成功")
}
// saveStatsData 保存统计数据到文件
func (s *Server) saveStatsData() {
if s.config.StatsFile == "" {
return
}
// 创建数据目录
statsDir := filepath.Dir(s.config.StatsFile)
err := os.MkdirAll(statsDir, 0755)
if err != nil {
logger.Error("创建统计数据目录失败", "error", err)
return
}
// 收集所有统计数据
statsData := &StatsData{
Stats: s.GetStats(),
LastSaved: time.Now(),
}
// 复制域名数据
s.blockedDomainsMutex.RLock()
statsData.BlockedDomains = make(map[string]*BlockedDomain)
for k, v := range s.blockedDomains {
statsData.BlockedDomains[k] = v
}
s.blockedDomainsMutex.RUnlock()
s.resolvedDomainsMutex.RLock()
statsData.ResolvedDomains = make(map[string]*BlockedDomain)
for k, v := range s.resolvedDomains {
statsData.ResolvedDomains[k] = v
}
s.resolvedDomainsMutex.RUnlock()
s.hourlyStatsMutex.RLock()
statsData.HourlyStats = make(map[string]int64)
for k, v := range s.hourlyStats {
statsData.HourlyStats[k] = v
}
s.hourlyStatsMutex.RUnlock()
// 序列化数据
jsonData, err := json.MarshalIndent(statsData, "", " ")
if err != nil {
logger.Error("序列化统计数据失败", "error", err)
return
}
// 写入文件
err = ioutil.WriteFile(s.config.StatsFile, jsonData, 0644)
if err != nil {
logger.Error("保存统计数据到文件失败", "error", err)
return
}
logger.Info("统计数据保存成功")
}
// startAutoSave 启动自动保存功能
func (s *Server) startAutoSave() {
if s.config.StatsFile == "" || s.config.SaveInterval <= 0 {
return
}
// 设置定时器
s.saveTicker = time.NewTicker(time.Duration(s.config.SaveInterval) * time.Second)
defer s.saveTicker.Stop()
logger.Info("启动统计数据自动保存功能", "interval", s.config.SaveInterval, "file", s.config.StatsFile)
// 定期保存数据
for {
select {
case <-s.saveTicker.C:
s.saveStatsData()
case <-s.saveDone:
return
}
}
}