增加数据持久化功能
This commit is contained in:
@@ -3,10 +3,12 @@ package shield
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -23,6 +25,13 @@ type regexRule struct {
|
||||
original string
|
||||
}
|
||||
|
||||
// ShieldStatsData 用于持久化的Shield统计数据
|
||||
type ShieldStatsData struct {
|
||||
BlockedDomainsCount map[string]int `json:"blockedDomainsCount"`
|
||||
ResolvedDomainsCount map[string]int `json:"resolvedDomainsCount"`
|
||||
LastSaved time.Time `json:"lastSaved"`
|
||||
}
|
||||
|
||||
// ShieldManager 屏蔽管理器
|
||||
type ShieldManager struct {
|
||||
config *config.ShieldConfig
|
||||
@@ -42,7 +51,7 @@ type ShieldManager struct {
|
||||
// NewShieldManager 创建屏蔽管理器实例
|
||||
func NewShieldManager(config *config.ShieldConfig) *ShieldManager {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &ShieldManager{
|
||||
manager := &ShieldManager{
|
||||
config: config,
|
||||
domainRules: make(map[string]bool),
|
||||
domainExceptions: make(map[string]bool),
|
||||
@@ -54,6 +63,11 @@ func NewShieldManager(config *config.ShieldConfig) *ShieldManager {
|
||||
updateCtx: ctx,
|
||||
updateCancel: cancel,
|
||||
}
|
||||
|
||||
// 加载已保存的计数数据
|
||||
manager.loadStatsData()
|
||||
|
||||
return manager
|
||||
}
|
||||
|
||||
// LoadRules 加载屏蔽规则
|
||||
@@ -651,6 +665,9 @@ func (m *ShieldManager) StartAutoUpdate() {
|
||||
ticker := time.NewTicker(time.Duration(m.config.UpdateInterval) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
// 启动自动保存计数数据
|
||||
go m.startAutoSaveStats()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
@@ -661,16 +678,27 @@ func (m *ShieldManager) StartAutoUpdate() {
|
||||
logger.Info("自动更新规则成功")
|
||||
}
|
||||
case <-m.updateCtx.Done():
|
||||
// 保存计数数据
|
||||
m.saveStatsData()
|
||||
m.updateRunning = false
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Info("规则自动更新已启动", "interval", m.config.UpdateInterval)
|
||||
|
||||
// 如果是首次启动,先保存一次数据确保目录存在
|
||||
go m.saveStatsData()
|
||||
}
|
||||
|
||||
// StopAutoUpdate 停止自动更新
|
||||
func (m *ShieldManager) StopAutoUpdate() {
|
||||
m.updateRunning = false
|
||||
m.updateCancel()
|
||||
// 保存计数数据
|
||||
m.saveStatsData()
|
||||
logger.Info("规则自动更新已停止")
|
||||
}
|
||||
|
||||
// saveRulesToFile 保存规则到文件
|
||||
@@ -781,7 +809,112 @@ func (m *ShieldManager) GetStats() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
// GetRules 获取所有规则的详细列表
|
||||
// loadStatsData 从文件加载计数数据
|
||||
func (m *ShieldManager) loadStatsData() {
|
||||
if m.config.StatsFile == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查文件是否存在
|
||||
data, err := ioutil.ReadFile(m.config.StatsFile)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
logger.Error("读取Shield计数数据文件失败", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var statsData ShieldStatsData
|
||||
err = json.Unmarshal(data, &statsData)
|
||||
if err != nil {
|
||||
logger.Error("解析Shield计数数据失败", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 恢复计数数据
|
||||
m.rulesMutex.Lock()
|
||||
if statsData.BlockedDomainsCount != nil {
|
||||
m.blockedDomainsCount = statsData.BlockedDomainsCount
|
||||
}
|
||||
if statsData.ResolvedDomainsCount != nil {
|
||||
m.resolvedDomainsCount = statsData.ResolvedDomainsCount
|
||||
}
|
||||
m.rulesMutex.Unlock()
|
||||
|
||||
logger.Info("Shield计数数据加载成功")
|
||||
}
|
||||
|
||||
// saveStatsData 保存计数数据到文件
|
||||
func (m *ShieldManager) saveStatsData() {
|
||||
if m.config.StatsFile == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// 创建数据目录
|
||||
statsDir := filepath.Dir(m.config.StatsFile)
|
||||
err := os.MkdirAll(statsDir, 0755)
|
||||
if err != nil {
|
||||
logger.Error("创建Shield统计数据目录失败", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 收集计数数据
|
||||
m.rulesMutex.RLock()
|
||||
statsData := &ShieldStatsData{
|
||||
BlockedDomainsCount: make(map[string]int),
|
||||
ResolvedDomainsCount: make(map[string]int),
|
||||
LastSaved: time.Now(),
|
||||
}
|
||||
|
||||
// 复制数据
|
||||
for k, v := range m.blockedDomainsCount {
|
||||
statsData.BlockedDomainsCount[k] = v
|
||||
}
|
||||
for k, v := range m.resolvedDomainsCount {
|
||||
statsData.ResolvedDomainsCount[k] = v
|
||||
}
|
||||
m.rulesMutex.RUnlock()
|
||||
|
||||
// 序列化数据
|
||||
jsonData, err := json.MarshalIndent(statsData, "", " ")
|
||||
if err != nil {
|
||||
logger.Error("序列化Shield计数数据失败", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
err = ioutil.WriteFile(m.config.StatsFile, jsonData, 0644)
|
||||
if err != nil {
|
||||
logger.Error("保存Shield计数数据到文件失败", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Shield计数数据保存成功")
|
||||
}
|
||||
|
||||
// startAutoSaveStats 启动计数数据自动保存功能
|
||||
func (m *ShieldManager) startAutoSaveStats() {
|
||||
if m.config.StatsFile == "" || m.config.StatsSaveInterval <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(time.Duration(m.config.StatsSaveInterval) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
logger.Info("启动Shield计数数据自动保存功能", "interval", m.config.StatsSaveInterval, "file", m.config.StatsFile)
|
||||
|
||||
// 定期保存数据
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
m.saveStatsData()
|
||||
case <-m.updateCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetRules 获取所有规则
|
||||
func (m *ShieldManager) GetRules() map[string]interface{} {
|
||||
m.rulesMutex.RLock()
|
||||
defer m.rulesMutex.RUnlock()
|
||||
|
||||
Reference in New Issue
Block a user