This commit is contained in:
Alex Yang
2026-04-01 12:22:55 +08:00
parent 61789061ce
commit efebce3c39
46 changed files with 4797716 additions and 462145 deletions
+657 -54
View File
@@ -19,9 +19,294 @@ import (
"dns-server/logger"
"dns-server/shield"
"gopkg.in/ini.v1"
"dns-server/threat"
"github.com/gorilla/websocket"
)
// CacheEntry 缓存条目
type CacheEntry struct {
data interface{} // 缓存的数据
timestamp time.Time // 缓存创建时间
hits int // 命中次数(用于 LRU
}
// QueryCache 查询结果缓存
type QueryCache struct {
data map[string]*CacheEntry // 缓存键 -> 缓存条目
mutex sync.RWMutex // 读写锁
maxSize int // 最大缓存条目数
ttl time.Duration // 缓存有效期
}
// StatsCache 统计数据缓存
type StatsCache struct {
data map[string]*CacheEntry // 缓存键 -> 缓存条目
mutex sync.RWMutex // 读写锁
maxSize int // 最大缓存条目数
ttl time.Duration // 缓存有效期
lastStats map[string]interface{} // 上次缓存的统计数据,用于增量更新
}
// NewQueryCache 创建查询结果缓存
func NewQueryCache(maxSize int, ttl time.Duration) *QueryCache {
cache := &QueryCache{
data: make(map[string]*CacheEntry),
maxSize: maxSize,
ttl: ttl,
}
// 启动缓存清理协程
go cache.startCleanupLoop()
return cache
}
// NewStatsCache 创建统计数据缓存
func NewStatsCache(maxSize int, ttl time.Duration) *StatsCache {
cache := &StatsCache{
data: make(map[string]*CacheEntry),
maxSize: maxSize,
ttl: ttl,
lastStats: make(map[string]interface{}),
}
// 启动缓存清理协程
go cache.startCleanupLoop()
return cache
}
// Get 获取缓存条目
func (c *QueryCache) Get(key string) (interface{}, bool) {
c.mutex.RLock()
entry, found := c.data[key]
c.mutex.RUnlock()
if !found {
return nil, false
}
// 检查是否过期
if time.Since(entry.timestamp) > c.ttl {
// 过期,删除
c.mutex.Lock()
delete(c.data, key)
c.mutex.Unlock()
return nil, false
}
// 更新命中次数
c.mutex.Lock()
entry.hits++
c.mutex.Unlock()
return entry.data, true
}
// Set 设置缓存条目
func (c *QueryCache) Set(key string, data interface{}) {
c.mutex.Lock()
defer c.mutex.Unlock()
// 如果缓存已满,删除最少使用的条目
if len(c.data) >= c.maxSize {
c.evictLRU()
}
c.data[key] = &CacheEntry{
data: data,
timestamp: time.Now(),
hits: 0,
}
}
// Delete 删除缓存条目
func (c *QueryCache) Delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.data, key)
}
// Clear 清空缓存
func (c *QueryCache) Clear() {
c.mutex.Lock()
defer c.mutex.Unlock()
c.data = make(map[string]*CacheEntry)
}
// evictLRU 淘汰最少使用的条目
func (c *QueryCache) evictLRU() {
var lruKey string
minHits := int(^uint(0) >> 1) // 最大 int 值
for key, entry := range c.data {
if entry.hits < minHits {
minHits = entry.hits
lruKey = key
}
}
if lruKey != "" {
delete(c.data, lruKey)
}
}
// startCleanupLoop 启动清理协程
func (c *QueryCache) startCleanupLoop() {
ticker := time.NewTicker(c.ttl / 2)
defer ticker.Stop()
for range ticker.C {
c.cleanupExpired()
}
}
// cleanupExpired 清理过期条目
func (c *QueryCache) cleanupExpired() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
for key, entry := range c.data {
if now.Sub(entry.timestamp) > c.ttl {
delete(c.data, key)
}
}
}
// StatsCache 方法
// Get 获取统计数据缓存条目
func (c *StatsCache) Get(key string) (map[string]interface{}, bool) {
c.mutex.RLock()
entry, found := c.data[key]
c.mutex.RUnlock()
if !found {
return nil, false
}
// 检查是否过期
if time.Since(entry.timestamp) > c.ttl {
// 过期,删除
c.mutex.Lock()
delete(c.data, key)
c.mutex.Unlock()
return nil, false
}
// 更新命中次数
c.mutex.Lock()
entry.hits++
c.mutex.Unlock()
if data, ok := entry.data.(map[string]interface{}); ok {
return data, true
}
return nil, false
}
// Set 设置统计数据缓存条目
func (c *StatsCache) Set(key string, data map[string]interface{}) {
c.mutex.Lock()
defer c.mutex.Unlock()
// 如果缓存已满,删除最少使用的条目
if len(c.data) >= c.maxSize {
c.evictLRU()
}
c.data[key] = &CacheEntry{
data: data,
timestamp: time.Now(),
hits: 0,
}
// 保存最后统计数据
c.lastStats = data
}
// GetLastStats 获取上次缓存的统计数据
func (c *StatsCache) GetLastStats() map[string]interface{} {
c.mutex.RLock()
defer c.mutex.RUnlock()
return c.lastStats
}
// Delete 删除统计数据缓存条目
func (c *StatsCache) Delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.data, key)
}
// Clear 清空统计数据缓存
func (c *StatsCache) Clear() {
c.mutex.Lock()
defer c.mutex.Unlock()
c.data = make(map[string]*CacheEntry)
c.lastStats = make(map[string]interface{})
}
// evictLRU 淘汰最少使用的条目
func (c *StatsCache) evictLRU() {
var lruKey string
minHits := int(^uint(0) >> 1) // 最大 int 值
for key, entry := range c.data {
if entry.hits < minHits {
minHits = entry.hits
lruKey = key
}
}
if lruKey != "" {
delete(c.data, lruKey)
}
}
// startCleanupLoop 启动清理协程
func (c *StatsCache) startCleanupLoop() {
ticker := time.NewTicker(c.ttl / 2)
defer ticker.Stop()
for range ticker.C {
c.cleanupExpired()
}
}
// cleanupExpired 清理过期条目
func (c *StatsCache) cleanupExpired() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
for key, entry := range c.data {
if now.Sub(entry.timestamp) > c.ttl {
delete(c.data, key)
}
}
}
// ClearQueryCache 清除查询缓存
func (s *Server) ClearQueryCache() {
if s.queryCache != nil {
s.queryCache.Clear()
}
}
// ClearStatsCache 清除统计缓存
func (s *Server) ClearStatsCache() {
if s.statsCache != nil {
s.statsCache.Clear()
}
}
// ClearAllCache 清除所有缓存
func (s *Server) ClearAllCache() {
s.ClearQueryCache()
s.ClearStatsCache()
}
// Server HTTP控制台服务器
type Server struct {
globalConfig *config.Config
@@ -36,11 +321,18 @@ type Server struct {
sessionsMutex sync.Mutex // 会话映射的互斥锁
sessionTTL time.Duration // 会话过期时间
// WebSocket相关字段
// WebSocket 相关字段
upgrader websocket.Upgrader
clients map[*websocket.Conn]bool
clientsMutex sync.Mutex
broadcastChan chan []byte
// 查询缓存相关字段
queryCache *QueryCache // 查询结果缓存
statsCache *StatsCache // 统计数据缓存
cacheEnabled bool // 缓存是否启用
cacheTTL time.Duration // 缓存过期时间
cacheMaxSize int // 缓存最大条目数
}
// NewServer 创建HTTP服务器实例
@@ -63,7 +355,13 @@ func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager
broadcastChan: make(chan []byte, 100),
// 会话管理初始化
sessions: make(map[string]time.Time),
sessionTTL: 24 * time.Hour, // 会话有效期24小时
sessionTTL: 24 * time.Hour, // 会话有效期 24 小时
// 查询缓存初始化
queryCache: NewQueryCache(100, 5*time.Second), // 最多 100 条,5 秒过期
statsCache: NewStatsCache(10, 2*time.Second), // 最多 10 条,2 秒过期
cacheEnabled: true, // 默认启用缓存
cacheTTL: 5 * time.Second, // 默认缓存 5 秒
cacheMaxSize: 100, // 默认最大 100 条
}
// 启动广播协程
@@ -150,6 +448,12 @@ func (s *Server) Start() error {
mux.HandleFunc("/api/threat", s.loginRequired(s.handleThreatQuery))
// 威胁批量查询接口
mux.HandleFunc("/api/threat/batch", s.loginRequired(s.handleThreatBatch))
// 威胁告警接口
mux.HandleFunc("/api/alert", s.loginRequired(s.handleAlert))
// 威胁告警解决接口
mux.HandleFunc("/api/alert/resolve", s.loginRequired(s.handleAlertResolve))
// 威胁域名管理接口
mux.HandleFunc("/api/threat/domain", s.loginRequired(s.handleThreatDomain))
// WebSocket 端点
mux.HandleFunc("/ws/stats", s.loginRequired(s.handleWebSocketStats))
@@ -961,7 +1265,7 @@ func (s *Server) handleShieldBlacklists(w http.ResponseWriter, r *http.Request)
// 更新全局配置中的黑名单
s.globalConfig.Shield.Blacklists = blacklists
// 保存配置到文件
if err := saveConfigToFile(s.globalConfig, "config.json"); err != nil {
if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
logger.Error("保存配置文件失败", "error", err)
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "保存配置失败"})
return
@@ -991,7 +1295,7 @@ func (s *Server) handleShieldBlacklists(w http.ResponseWriter, r *http.Request)
// 更新全局配置中的黑名单
s.globalConfig.Shield.Blacklists = newBlacklists
// 保存配置到文件
if err := saveConfigToFile(s.globalConfig, "config.json"); err != nil {
if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
logger.Error("保存配置文件失败", "error", err)
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "保存配置失败"})
return
@@ -1052,7 +1356,7 @@ func (s *Server) handleShieldBlacklists(w http.ResponseWriter, r *http.Request)
// 更新全局配置中的黑名单
s.globalConfig.Shield.Blacklists = blacklists
// 保存配置到文件
if err := saveConfigToFile(s.globalConfig, "config.json"); err != nil {
if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
logger.Error("保存配置文件失败", "error", err)
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "保存配置失败"})
return
@@ -1093,7 +1397,7 @@ func (s *Server) handleShieldBlacklists(w http.ResponseWriter, r *http.Request)
// 更新全局配置中的黑名单
s.globalConfig.Shield.Blacklists = newBlacklists
// 保存配置到文件
if err := saveConfigToFile(s.globalConfig, "config.json"); err != nil {
if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
logger.Error("保存配置文件失败", "error", err)
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "保存配置失败"})
return
@@ -1280,11 +1584,78 @@ func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
// saveConfigToFile 保存配置到文件
func saveConfigToFile(config *config.Config, filePath string) error {
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
// 创建新的INI文件
cfg := ini.Empty()
// DNS配置
dnsSection := cfg.Section("dns")
dnsSection.Key("port").SetValue(fmt.Sprintf("%d", config.DNS.Port))
dnsSection.Key("upstreamDNS").SetValue(strings.Join(config.DNS.UpstreamDNS, ", "))
dnsSection.Key("dnssecUpstreamDNS").SetValue(strings.Join(config.DNS.DNSSECUpstreamDNS, ", "))
dnsSection.Key("saveInterval").SetValue(fmt.Sprintf("%d", config.DNS.SaveInterval))
dnsSection.Key("cacheTTL").SetValue(fmt.Sprintf("%d", config.DNS.CacheTTL))
dnsSection.Key("enableDNSSEC").SetValue(fmt.Sprintf("%t", config.DNS.EnableDNSSEC))
dnsSection.Key("queryMode").SetValue(config.DNS.QueryMode)
dnsSection.Key("queryTimeout").SetValue(fmt.Sprintf("%d", config.DNS.QueryTimeout))
dnsSection.Key("enableFastReturn").SetValue(fmt.Sprintf("%t", config.DNS.EnableFastReturn))
dnsSection.Key("noDNSSECDomains").SetValue(strings.Join(config.DNS.NoDNSSECDomains, ", "))
dnsSection.Key("enableIPv6").SetValue(fmt.Sprintf("%t", config.DNS.EnableIPv6))
dnsSection.Key("cacheMode").SetValue(config.DNS.CacheMode)
dnsSection.Key("cacheSize").SetValue(fmt.Sprintf("%d", config.DNS.CacheSize))
dnsSection.Key("maxCacheTTL").SetValue(fmt.Sprintf("%d", config.DNS.MaxCacheTTL))
dnsSection.Key("minCacheTTL").SetValue(fmt.Sprintf("%d", config.DNS.MinCacheTTL))
// 域名特定DNS服务器配置
for domain, servers := range config.DNS.DomainSpecificDNS {
dnsSection.Key(fmt.Sprintf("domain_%s", domain)).SetValue(strings.Join(servers, ", "))
}
return os.WriteFile(filePath, data, 0644)
// HTTP配置
httpSection := cfg.Section("http")
httpSection.Key("port").SetValue(fmt.Sprintf("%d", config.HTTP.Port))
httpSection.Key("host").SetValue(config.HTTP.Host)
httpSection.Key("enableAPI").SetValue(fmt.Sprintf("%t", config.HTTP.EnableAPI))
httpSection.Key("username").SetValue(config.HTTP.Username)
httpSection.Key("password").SetValue(config.HTTP.Password)
// Shield配置
shieldSection := cfg.Section("shield")
shieldSection.Key("updateInterval").SetValue(fmt.Sprintf("%d", config.Shield.UpdateInterval))
shieldSection.Key("blockMethod").SetValue(config.Shield.BlockMethod)
shieldSection.Key("customBlockIP").SetValue(config.Shield.CustomBlockIP)
shieldSection.Key("statsSaveInterval").SetValue(fmt.Sprintf("%d", config.Shield.StatsSaveInterval))
// 黑名单配置
for _, bl := range config.Shield.Blacklists {
shieldSection.Key(fmt.Sprintf("blacklist_%s", bl.Name)).SetValue(fmt.Sprintf("%s,%t", bl.URL, bl.Enabled))
}
// GFWList配置
gfwListSection := cfg.Section("gfwList")
gfwListSection.Key("ip").SetValue(config.GFWList.IP)
gfwListSection.Key("content").SetValue(config.GFWList.Content)
gfwListSection.Key("enabled").SetValue(fmt.Sprintf("%t", config.GFWList.Enabled))
// Log配置
logSection := cfg.Section("log")
logSection.Key("level").SetValue(config.Log.Level)
logSection.Key("maxSize").SetValue(fmt.Sprintf("%d", config.Log.MaxSize))
logSection.Key("maxBackups").SetValue(fmt.Sprintf("%d", config.Log.MaxBackups))
logSection.Key("maxAge").SetValue(fmt.Sprintf("%d", config.Log.MaxAge))
// Threat配置
threatSection := cfg.Section("threat")
threatSection.Key("enabled").SetValue(fmt.Sprintf("%t", config.Threat.Enabled))
threatSection.Key("queryRateThreshold").SetValue(fmt.Sprintf("%d", config.Threat.QueryRateThreshold))
threatSection.Key("nxDomainThreshold").SetValue(fmt.Sprintf("%d", config.Threat.NXDomainThreshold))
threatSection.Key("maxDomainLength").SetValue(fmt.Sprintf("%d", config.Threat.MaxDomainLength))
threatSection.Key("suspiciousPatterns").SetValue(strings.Join(config.Threat.SuspiciousPatterns, ","))
threatSection.Key("unusualQueryTypes").SetValue(strings.Join(config.Threat.UnusualQueryTypes, ","))
threatSection.Key("alertRetentionDays").SetValue(fmt.Sprintf("%d", config.Threat.AlertRetentionDays))
threatSection.Key("threatDatabasePath").SetValue(config.Threat.ThreatDatabasePath)
// 保存到文件
return cfg.SaveTo(filePath)
}
// handleConfig 处理配置请求
@@ -1293,35 +1664,52 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
// 每次从配置文件重新读取最新配置
cfg, err := config.LoadConfig("config.ini")
if err != nil {
logger.Error("加载配置文件失败", "error", err)
// 如果加载失败,返回内存中的配置
cfg = s.globalConfig
}
// 返回当前配置(包括黑名单配置)
// 注意:key 名必须与前端期望的一致
config := map[string]interface{}{
"Shield": map[string]interface{}{
"blockMethod": s.globalConfig.Shield.BlockMethod,
"customBlockIP": s.globalConfig.Shield.CustomBlockIP,
"blacklists": s.globalConfig.Shield.Blacklists,
"updateInterval": s.globalConfig.Shield.UpdateInterval,
"blockMethod": cfg.Shield.BlockMethod,
"customBlockIP": cfg.Shield.CustomBlockIP,
"blacklists": cfg.Shield.Blacklists,
"updateInterval": cfg.Shield.UpdateInterval,
"statsSaveInterval": cfg.Shield.StatsSaveInterval,
},
"GFWList": map[string]interface{}{
"ip": s.globalConfig.GFWList.IP,
"content": s.globalConfig.GFWList.Content,
"ip": cfg.GFWList.IP,
"content": cfg.GFWList.Content,
"enabled": cfg.GFWList.Enabled,
},
"DNSServer": map[string]interface{}{
"port": s.globalConfig.DNS.Port,
"QueryMode": s.globalConfig.DNS.QueryMode,
"UpstreamServers": s.globalConfig.DNS.UpstreamDNS,
"DNSSECUpstreamServers": s.globalConfig.DNS.DNSSECUpstreamDNS,
"saveInterval": s.globalConfig.DNS.SaveInterval,
"enableIPv6": s.globalConfig.DNS.EnableIPv6,
"CacheMode": s.globalConfig.DNS.CacheMode,
"CacheSize": s.globalConfig.DNS.CacheSize,
"MaxCacheTTL": s.globalConfig.DNS.MaxCacheTTL,
"MinCacheTTL": s.globalConfig.DNS.MinCacheTTL,
"enableFastReturn": s.globalConfig.DNS.EnableFastReturn,
"domainSpecificDNS": s.globalConfig.DNS.DomainSpecificDNS,
"noDNSSECDomains": s.globalConfig.DNS.NoDNSSECDomains,
"port": cfg.DNS.Port,
"QueryMode": cfg.DNS.QueryMode,
"UpstreamServers": cfg.DNS.UpstreamDNS,
"DNSSECUpstreamServers": cfg.DNS.DNSSECUpstreamDNS,
"saveInterval": cfg.DNS.SaveInterval,
"queryTimeout": cfg.DNS.QueryTimeout,
"enableIPv6": cfg.DNS.EnableIPv6,
"enableDNSSEC": cfg.DNS.EnableDNSSEC,
"enableFastReturn": cfg.DNS.EnableFastReturn,
"noDNSSECDomains": cfg.DNS.NoDNSSECDomains,
"CacheMode": cfg.DNS.CacheMode,
"CacheSize": cfg.DNS.CacheSize,
"MaxCacheTTL": cfg.DNS.MaxCacheTTL,
"MinCacheTTL": cfg.DNS.MinCacheTTL,
"domainSpecificDNS": cfg.DNS.DomainSpecificDNS,
},
"HTTPServer": map[string]interface{}{
"port": s.globalConfig.HTTP.Port,
"port": cfg.HTTP.Port,
"host": cfg.HTTP.Host,
"enableAPI": cfg.HTTP.EnableAPI,
"username": cfg.HTTP.Username,
"password": cfg.HTTP.Password,
},
}
json.NewEncoder(w).Encode(config)
@@ -1337,26 +1725,33 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
Timeout int `json:"timeout"`
SaveInterval int `json:"saveInterval"`
EnableIPv6 bool `json:"enableIPv6"`
EnableDNSSEC bool `json:"enableDNSSEC"`
EnableFastReturn *bool `json:"enableFastReturn"`
NoDNSSECDomains []string `json:"noDNSSECDomains"`
CacheMode string `json:"cacheMode"`
CacheSize int `json:"cacheSize"`
MaxCacheTTL int `json:"maxCacheTTL"`
MinCacheTTL int `json:"minCacheTTL"`
EnableFastReturn *bool `json:"enableFastReturn"`
DomainSpecificDNS map[string][]string `json:"domainSpecificDNS"`
NoDNSSECDomains []string `json:"noDNSSECDomains"`
} `json:"dnsserver"`
HTTPServer struct {
Port int `json:"port"`
Port int `json:"port"`
Host string `json:"host"`
EnableAPI bool `json:"enableAPI"`
Username string `json:"username"`
Password string `json:"password"`
} `json:"httpserver"`
Shield struct {
BlockMethod string `json:"blockMethod"`
CustomBlockIP string `json:"customBlockIP"`
Blacklists []config.BlacklistEntry `json:"blacklists"`
UpdateInterval int `json:"updateInterval"`
BlockMethod string `json:"blockMethod"`
CustomBlockIP string `json:"customBlockIP"`
Blacklists []config.BlacklistEntry `json:"blacklists"`
UpdateInterval int `json:"updateInterval"`
StatsSaveInterval int `json:"statsSaveInterval"`
} `json:"shield"`
GFWList struct {
IP string `json:"ip"`
Content string `json:"content"`
Enabled bool `json:"enabled"`
} `json:"gfwList"`
}
@@ -1378,7 +1773,11 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
if req.DNSServer.SaveInterval > 0 {
s.globalConfig.DNS.SaveInterval = req.DNSServer.SaveInterval
}
if req.DNSServer.Timeout > 0 {
s.globalConfig.DNS.QueryTimeout = req.DNSServer.Timeout
}
s.globalConfig.DNS.EnableIPv6 = req.DNSServer.EnableIPv6
s.globalConfig.DNS.EnableDNSSEC = req.DNSServer.EnableDNSSEC
// 更新查询模式
if req.DNSServer.QueryMode != "" {
s.globalConfig.DNS.QueryMode = req.DNSServer.QueryMode
@@ -1400,19 +1799,29 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
if req.DNSServer.EnableFastReturn != nil {
s.globalConfig.DNS.EnableFastReturn = *req.DNSServer.EnableFastReturn
}
// 更新domainSpecificDNS
if req.DNSServer.DomainSpecificDNS != nil {
s.globalConfig.DNS.DomainSpecificDNS = req.DNSServer.DomainSpecificDNS
}
// 更新noDNSSECDomains
if len(req.DNSServer.NoDNSSECDomains) > 0 {
s.globalConfig.DNS.NoDNSSECDomains = req.DNSServer.NoDNSSECDomains
}
// 更新domainSpecificDNS
if req.DNSServer.DomainSpecificDNS != nil {
s.globalConfig.DNS.DomainSpecificDNS = req.DNSServer.DomainSpecificDNS
}
// 更新HTTP配置
if req.HTTPServer.Port > 0 {
s.globalConfig.HTTP.Port = req.HTTPServer.Port
}
if req.HTTPServer.Host != "" {
s.globalConfig.HTTP.Host = req.HTTPServer.Host
}
s.globalConfig.HTTP.EnableAPI = req.HTTPServer.EnableAPI
if req.HTTPServer.Username != "" {
s.globalConfig.HTTP.Username = req.HTTPServer.Username
}
if req.HTTPServer.Password != "" {
s.globalConfig.HTTP.Password = req.HTTPServer.Password
}
// 更新屏蔽配置
if req.Shield.BlockMethod != "" {
@@ -1449,9 +1858,23 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
s.globalConfig.Shield.CustomBlockIP = req.Shield.CustomBlockIP
}
// 更新更新间隔
if req.Shield.UpdateInterval > 0 {
s.globalConfig.Shield.UpdateInterval = req.Shield.UpdateInterval
// 重新启动自动更新
s.shieldManager.StopAutoUpdate()
s.shieldManager.StartAutoUpdate()
}
// 更新统计保存间隔
if req.Shield.StatsSaveInterval > 0 {
s.globalConfig.Shield.StatsSaveInterval = req.Shield.StatsSaveInterval
}
// 更新GFWList配置
s.globalConfig.GFWList.IP = req.GFWList.IP
s.globalConfig.GFWList.Content = req.GFWList.Content
s.globalConfig.GFWList.Enabled = req.GFWList.Enabled
// 重新加载GFWList规则
if s.gfwManager != nil {
@@ -1481,14 +1904,6 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
}
}
// 更新更新间隔
if req.Shield.UpdateInterval > 0 {
s.globalConfig.Shield.UpdateInterval = req.Shield.UpdateInterval
// 重新启动自动更新
s.shieldManager.StopAutoUpdate()
s.shieldManager.StartAutoUpdate()
}
// 更新现有的DNSCache实例配置
// 最大和最小TTL(秒)
maxCacheTTL := time.Duration(s.globalConfig.DNS.MaxCacheTTL) * time.Second
@@ -1503,7 +1918,7 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
s.dnsServer.DnsCache.SetMaxCacheSize(maxCacheSize)
// 保存配置到文件
if err := saveConfigToFile(s.globalConfig, "./config.json"); err != nil {
if err := saveConfigToFile(s.globalConfig, "./config.ini"); err != nil {
logger.Error("保存配置到文件失败", "error", err)
// 不返回错误,只记录日志,因为配置已经在内存中更新成功
}
@@ -1568,9 +1983,27 @@ func (s *Server) handleLogsStats(w http.ResponseWriter, r *http.Request) {
return
}
// 获取日志统计数据
// 构建缓存键
cacheKey := "logs_stats"
// 如果启用缓存,先尝试从缓存获取
if s.cacheEnabled {
if cachedStats, found := s.statsCache.Get(cacheKey); found {
// 缓存命中,直接返回
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(cachedStats)
return
}
}
// 缓存未命中,获取最新统计数据
logStats := s.dnsServer.GetQueryStats()
// 存入缓存
if s.cacheEnabled {
s.statsCache.Set(cacheKey, logStats)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(logStats)
}
@@ -1583,7 +2016,7 @@ func (s *Server) handleLogsQuery(w http.ResponseWriter, r *http.Request) {
}
// 获取查询参数
limit := 100 // 默认返回100条日志
limit := 100 // 默认返回 100 条日志
offset := 0
sortField := r.URL.Query().Get("sort")
sortDirection := r.URL.Query().Get("direction")
@@ -1599,9 +2032,18 @@ func (s *Server) handleLogsQuery(w http.ResponseWriter, r *http.Request) {
fmt.Sscanf(offsetStr, "%d", &offset)
}
// 获取日志数据
// 构建缓存键,包含所有查询参数
// 已禁用缓存,每次都从数据库获取最新数据
// cacheKey := fmt.Sprintf("logs_%d_%d_%s_%s_%s_%s_%s", limit, offset, sortField, sortDirection, resultFilter, searchTerm, queryType)
// 缓存未命中,获取日志数据(已禁用缓存)
logs := s.dnsServer.GetQueryLogs(limit, offset, sortField, sortDirection, resultFilter, searchTerm, queryType)
// 存入缓存(已禁用)
// if s.cacheEnabled {
// s.queryCache.Set(cacheKey, logs)
// }
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(logs)
}
@@ -1618,9 +2060,17 @@ func (s *Server) handleLogsCount(w http.ResponseWriter, r *http.Request) {
searchTerm := r.URL.Query().Get("search")
queryType := r.URL.Query().Get("queryType")
// 获取带过滤条件的日志总数
// 构建缓存键(已禁用)
// cacheKey := fmt.Sprintf("logs_count_%s_%s_%s", resultFilter, searchTerm, queryType)
// 缓存未命中,获取带过滤条件的日志总数(已禁用缓存)
count := s.dnsServer.GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType)
// 存入缓存(已禁用)
// if s.cacheEnabled {
// s.queryCache.Set(cacheKey, count)
// }
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]int{"count": count})
}
@@ -1828,6 +2278,159 @@ func isService(obj map[string]interface{}) bool {
return false
}
// handleAlert 处理威胁告警请求
func (s *Server) handleAlert(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.Method {
case http.MethodGet:
// 获取告警列表
limit := 100
offset := 0
level := r.URL.Query().Get("level")
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
fmt.Sscanf(limitStr, "%d", &limit)
}
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
fmt.Sscanf(offsetStr, "%d", &offset)
}
// 获取告警列表
alerts := s.dnsServer.GetAlerts(limit, offset, level)
// 构建响应
response := map[string]interface{}{
"alerts": alerts,
"total": s.dnsServer.GetAlertCount(level),
}
json.NewEncoder(w).Encode(response)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
}
// handleAlertResolve 处理威胁告警解决请求
func (s *Server) handleAlertResolve(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 解析请求体
var req struct {
AlertID string `json:"alertId"`
Action string `json:"action"` // blocked, allowed
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
if req.AlertID == "" || req.Action == "" {
http.Error(w, "AlertID and Action are required", http.StatusBadRequest)
return
}
// 验证动作
if req.Action != threat.ActionBlocked && req.Action != threat.ActionAllowed {
http.Error(w, "Invalid action", http.StatusBadRequest)
return
}
// 解决告警
success := s.dnsServer.ResolveAlert(req.AlertID, req.Action)
if success {
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
} else {
http.Error(w, "Failed to resolve alert", http.StatusInternalServerError)
}
}
// handleThreatDomain 处理威胁域名管理请求
func (s *Server) handleThreatDomain(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.Method {
case http.MethodGet:
// 获取所有威胁域名
threats := s.dnsServer.GetThreatDomains()
json.NewEncoder(w).Encode(threats)
case http.MethodPost:
// 添加威胁域名
var req struct {
Type string `json:"type"`
Name string `json:"name"`
RiskLevel int `json:"riskLevel"`
Domain string `json:"domain"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
if req.Domain == "" {
http.Error(w, "Domain is required", http.StatusBadRequest)
return
}
// 设置默认值
if req.Type == "" {
req.Type = "未知"
}
if req.Name == "" {
req.Name = "未知"
}
if req.RiskLevel == 0 {
req.RiskLevel = 1
}
err := s.dnsServer.AddThreatDomain(req.Type, req.Name, req.RiskLevel, req.Domain)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
case http.MethodDelete:
// 删除威胁域名
var req struct {
Domain string `json:"domain"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
if req.Domain == "" {
http.Error(w, "Domain is required", http.StatusBadRequest)
return
}
err := s.dnsServer.RemoveThreatDomain(req.Domain)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
}
// processServiceItem 递归处理服务或分组
func processServiceItem(
serviceName string,