增加数据持久化功能
This commit is contained in:
165
dns/server.go
165
dns/server.go
@@ -2,8 +2,12 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -23,6 +27,15 @@ type BlockedDomain struct {
|
||||
LastSeen time.Time
|
||||
}
|
||||
|
||||
// StatsData 用于持久化的统计数据结构
|
||||
type StatsData struct {
|
||||
Stats *Stats `json:"stats"`
|
||||
BlockedDomains map[string]*BlockedDomain `json:"blockedDomains"`
|
||||
ResolvedDomains map[string]*BlockedDomain `json:"resolvedDomains"`
|
||||
HourlyStats map[string]int64 `json:"hourlyStats"`
|
||||
LastSaved time.Time `json:"lastSaved"`
|
||||
}
|
||||
|
||||
// Server DNS服务器
|
||||
type Server struct {
|
||||
config *config.DNSConfig
|
||||
@@ -40,6 +53,8 @@ type Server struct {
|
||||
resolvedDomains map[string]*BlockedDomain // 用于记录解析的域名
|
||||
hourlyStatsMutex sync.RWMutex
|
||||
hourlyStats map[string]int64 // 按小时统计屏蔽数量
|
||||
saveTicker *time.Ticker // 用于定时保存数据
|
||||
saveDone chan struct{} // 用于通知保存协程停止
|
||||
}
|
||||
|
||||
// Stats DNS服务器统计信息
|
||||
@@ -54,7 +69,7 @@ type Stats struct {
|
||||
// NewServer 创建DNS服务器实例
|
||||
func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shieldManager *shield.ShieldManager) *Server {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Server{
|
||||
server := &Server{
|
||||
config: config,
|
||||
shieldConfig: shieldConfig,
|
||||
shieldManager: shieldManager,
|
||||
@@ -73,7 +88,14 @@ func NewServer(config *config.DNSConfig, shieldConfig *config.ShieldConfig, shie
|
||||
blockedDomains: make(map[string]*BlockedDomain),
|
||||
resolvedDomains: make(map[string]*BlockedDomain),
|
||||
hourlyStats: make(map[string]int64),
|
||||
saveDone: make(chan struct{}),
|
||||
}
|
||||
|
||||
// 加载已保存的统计数据
|
||||
server.loadStatsData()
|
||||
|
||||
return server
|
||||
|
||||
}
|
||||
|
||||
// Start 启动DNS服务器
|
||||
@@ -116,10 +138,17 @@ func (s *Server) Start() error {
|
||||
|
||||
// Stop 停止DNS服务器
|
||||
func (s *Server) Stop() {
|
||||
// 发送停止信号给保存协程
|
||||
close(s.saveDone)
|
||||
|
||||
// 最后保存一次数据
|
||||
s.saveStatsData()
|
||||
|
||||
// 停止服务器
|
||||
s.cancel()
|
||||
if s.server != nil {
|
||||
s.server.Shutdown()
|
||||
}
|
||||
s.cancel()
|
||||
logger.Info("DNS服务器已停止")
|
||||
}
|
||||
|
||||
@@ -452,3 +481,135 @@ func (s *Server) GetHourlyStats() map[string]int64 {
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// loadStatsData 从文件加载统计数据
|
||||
func (s *Server) loadStatsData() {
|
||||
if s.config.StatsFile == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查文件是否存在
|
||||
data, err := ioutil.ReadFile(s.config.StatsFile)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
logger.Error("读取统计数据文件失败", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var statsData StatsData
|
||||
err = json.Unmarshal(data, &statsData)
|
||||
if err != nil {
|
||||
logger.Error("解析统计数据失败", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 恢复统计数据
|
||||
s.statsMutex.Lock()
|
||||
if statsData.Stats != nil {
|
||||
s.stats = statsData.Stats
|
||||
}
|
||||
s.statsMutex.Unlock()
|
||||
|
||||
s.blockedDomainsMutex.Lock()
|
||||
if statsData.BlockedDomains != nil {
|
||||
s.blockedDomains = statsData.BlockedDomains
|
||||
}
|
||||
s.blockedDomainsMutex.Unlock()
|
||||
|
||||
s.resolvedDomainsMutex.Lock()
|
||||
if statsData.ResolvedDomains != nil {
|
||||
s.resolvedDomains = statsData.ResolvedDomains
|
||||
}
|
||||
s.resolvedDomainsMutex.Unlock()
|
||||
|
||||
s.hourlyStatsMutex.Lock()
|
||||
if statsData.HourlyStats != nil {
|
||||
s.hourlyStats = statsData.HourlyStats
|
||||
}
|
||||
s.hourlyStatsMutex.Unlock()
|
||||
|
||||
logger.Info("统计数据加载成功")
|
||||
}
|
||||
|
||||
// saveStatsData 保存统计数据到文件
|
||||
func (s *Server) saveStatsData() {
|
||||
if s.config.StatsFile == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// 创建数据目录
|
||||
statsDir := filepath.Dir(s.config.StatsFile)
|
||||
err := os.MkdirAll(statsDir, 0755)
|
||||
if err != nil {
|
||||
logger.Error("创建统计数据目录失败", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 收集所有统计数据
|
||||
statsData := &StatsData{
|
||||
Stats: s.GetStats(),
|
||||
LastSaved: time.Now(),
|
||||
}
|
||||
|
||||
// 复制域名数据
|
||||
s.blockedDomainsMutex.RLock()
|
||||
statsData.BlockedDomains = make(map[string]*BlockedDomain)
|
||||
for k, v := range s.blockedDomains {
|
||||
statsData.BlockedDomains[k] = v
|
||||
}
|
||||
s.blockedDomainsMutex.RUnlock()
|
||||
|
||||
s.resolvedDomainsMutex.RLock()
|
||||
statsData.ResolvedDomains = make(map[string]*BlockedDomain)
|
||||
for k, v := range s.resolvedDomains {
|
||||
statsData.ResolvedDomains[k] = v
|
||||
}
|
||||
s.resolvedDomainsMutex.RUnlock()
|
||||
|
||||
s.hourlyStatsMutex.RLock()
|
||||
statsData.HourlyStats = make(map[string]int64)
|
||||
for k, v := range s.hourlyStats {
|
||||
statsData.HourlyStats[k] = v
|
||||
}
|
||||
s.hourlyStatsMutex.RUnlock()
|
||||
|
||||
// 序列化数据
|
||||
jsonData, err := json.MarshalIndent(statsData, "", " ")
|
||||
if err != nil {
|
||||
logger.Error("序列化统计数据失败", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
err = ioutil.WriteFile(s.config.StatsFile, jsonData, 0644)
|
||||
if err != nil {
|
||||
logger.Error("保存统计数据到文件失败", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("统计数据保存成功")
|
||||
}
|
||||
|
||||
// startAutoSave 启动自动保存功能
|
||||
func (s *Server) startAutoSave() {
|
||||
if s.config.StatsFile == "" || s.config.SaveInterval <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 设置定时器
|
||||
s.saveTicker = time.NewTicker(time.Duration(s.config.SaveInterval) * time.Second)
|
||||
defer s.saveTicker.Stop()
|
||||
|
||||
logger.Info("启动统计数据自动保存功能", "interval", s.config.SaveInterval, "file", s.config.StatsFile)
|
||||
|
||||
// 定期保存数据
|
||||
for {
|
||||
select {
|
||||
case <-s.saveTicker.C:
|
||||
s.saveStatsData()
|
||||
case <-s.saveDone:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user