增加数据持久化功能

This commit is contained in:
Alex Yang
2025-11-23 19:07:59 +08:00
parent 63a95f7463
commit f5911af449
8 changed files with 519 additions and 16 deletions

View File

@@ -3,10 +3,12 @@ package shield
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
@@ -23,6 +25,13 @@ type regexRule struct {
original string
}
// ShieldStatsData 用于持久化的Shield统计数据
type ShieldStatsData struct {
BlockedDomainsCount map[string]int `json:"blockedDomainsCount"`
ResolvedDomainsCount map[string]int `json:"resolvedDomainsCount"`
LastSaved time.Time `json:"lastSaved"`
}
// ShieldManager 屏蔽管理器
type ShieldManager struct {
config *config.ShieldConfig
@@ -42,7 +51,7 @@ type ShieldManager struct {
// NewShieldManager 创建屏蔽管理器实例
func NewShieldManager(config *config.ShieldConfig) *ShieldManager {
ctx, cancel := context.WithCancel(context.Background())
return &ShieldManager{
manager := &ShieldManager{
config: config,
domainRules: make(map[string]bool),
domainExceptions: make(map[string]bool),
@@ -54,6 +63,11 @@ func NewShieldManager(config *config.ShieldConfig) *ShieldManager {
updateCtx: ctx,
updateCancel: cancel,
}
// 加载已保存的计数数据
manager.loadStatsData()
return manager
}
// LoadRules 加载屏蔽规则
@@ -651,6 +665,9 @@ func (m *ShieldManager) StartAutoUpdate() {
ticker := time.NewTicker(time.Duration(m.config.UpdateInterval) * time.Second)
defer ticker.Stop()
// 启动自动保存计数数据
go m.startAutoSaveStats()
for {
select {
case <-ticker.C:
@@ -661,16 +678,27 @@ func (m *ShieldManager) StartAutoUpdate() {
logger.Info("自动更新规则成功")
}
case <-m.updateCtx.Done():
// 保存计数数据
m.saveStatsData()
m.updateRunning = false
return
}
}
}()
logger.Info("规则自动更新已启动", "interval", m.config.UpdateInterval)
// 如果是首次启动,先保存一次数据确保目录存在
go m.saveStatsData()
}
// StopAutoUpdate 停止自动更新
func (m *ShieldManager) StopAutoUpdate() {
m.updateRunning = false
m.updateCancel()
// 保存计数数据
m.saveStatsData()
logger.Info("规则自动更新已停止")
}
// saveRulesToFile 保存规则到文件
@@ -781,7 +809,112 @@ func (m *ShieldManager) GetStats() map[string]interface{} {
}
}
// GetRules 获取所有规则的详细列表
// loadStatsData 从文件加载计数数据
func (m *ShieldManager) loadStatsData() {
if m.config.StatsFile == "" {
return
}
// 检查文件是否存在
data, err := ioutil.ReadFile(m.config.StatsFile)
if err != nil {
if !os.IsNotExist(err) {
logger.Error("读取Shield计数数据文件失败", "error", err)
}
return
}
var statsData ShieldStatsData
err = json.Unmarshal(data, &statsData)
if err != nil {
logger.Error("解析Shield计数数据失败", "error", err)
return
}
// 恢复计数数据
m.rulesMutex.Lock()
if statsData.BlockedDomainsCount != nil {
m.blockedDomainsCount = statsData.BlockedDomainsCount
}
if statsData.ResolvedDomainsCount != nil {
m.resolvedDomainsCount = statsData.ResolvedDomainsCount
}
m.rulesMutex.Unlock()
logger.Info("Shield计数数据加载成功")
}
// saveStatsData 保存计数数据到文件
func (m *ShieldManager) saveStatsData() {
if m.config.StatsFile == "" {
return
}
// 创建数据目录
statsDir := filepath.Dir(m.config.StatsFile)
err := os.MkdirAll(statsDir, 0755)
if err != nil {
logger.Error("创建Shield统计数据目录失败", "error", err)
return
}
// 收集计数数据
m.rulesMutex.RLock()
statsData := &ShieldStatsData{
BlockedDomainsCount: make(map[string]int),
ResolvedDomainsCount: make(map[string]int),
LastSaved: time.Now(),
}
// 复制数据
for k, v := range m.blockedDomainsCount {
statsData.BlockedDomainsCount[k] = v
}
for k, v := range m.resolvedDomainsCount {
statsData.ResolvedDomainsCount[k] = v
}
m.rulesMutex.RUnlock()
// 序列化数据
jsonData, err := json.MarshalIndent(statsData, "", " ")
if err != nil {
logger.Error("序列化Shield计数数据失败", "error", err)
return
}
// 写入文件
err = ioutil.WriteFile(m.config.StatsFile, jsonData, 0644)
if err != nil {
logger.Error("保存Shield计数数据到文件失败", "error", err)
return
}
logger.Info("Shield计数数据保存成功")
}
// startAutoSaveStats 启动计数数据自动保存功能
func (m *ShieldManager) startAutoSaveStats() {
if m.config.StatsFile == "" || m.config.StatsSaveInterval <= 0 {
return
}
ticker := time.NewTicker(time.Duration(m.config.StatsSaveInterval) * time.Second)
defer ticker.Stop()
logger.Info("启动Shield计数数据自动保存功能", "interval", m.config.StatsSaveInterval, "file", m.config.StatsFile)
// 定期保存数据
for {
select {
case <-ticker.C:
m.saveStatsData()
case <-m.updateCtx.Done():
return
}
}
}
// GetRules 获取所有规则
func (m *ShieldManager) GetRules() map[string]interface{} {
m.rulesMutex.RLock()
defer m.rulesMutex.RUnlock()