whois
This commit is contained in:
+657
-54
@@ -19,9 +19,294 @@ import (
|
||||
"dns-server/logger"
|
||||
"dns-server/shield"
|
||||
|
||||
"gopkg.in/ini.v1"
|
||||
"dns-server/threat"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// CacheEntry 缓存条目
|
||||
type CacheEntry struct {
|
||||
data interface{} // 缓存的数据
|
||||
timestamp time.Time // 缓存创建时间
|
||||
hits int // 命中次数(用于 LRU)
|
||||
}
|
||||
|
||||
// QueryCache 查询结果缓存
|
||||
type QueryCache struct {
|
||||
data map[string]*CacheEntry // 缓存键 -> 缓存条目
|
||||
mutex sync.RWMutex // 读写锁
|
||||
maxSize int // 最大缓存条目数
|
||||
ttl time.Duration // 缓存有效期
|
||||
}
|
||||
|
||||
// StatsCache 统计数据缓存
|
||||
type StatsCache struct {
|
||||
data map[string]*CacheEntry // 缓存键 -> 缓存条目
|
||||
mutex sync.RWMutex // 读写锁
|
||||
maxSize int // 最大缓存条目数
|
||||
ttl time.Duration // 缓存有效期
|
||||
lastStats map[string]interface{} // 上次缓存的统计数据,用于增量更新
|
||||
}
|
||||
|
||||
// NewQueryCache 创建查询结果缓存
|
||||
func NewQueryCache(maxSize int, ttl time.Duration) *QueryCache {
|
||||
cache := &QueryCache{
|
||||
data: make(map[string]*CacheEntry),
|
||||
maxSize: maxSize,
|
||||
ttl: ttl,
|
||||
}
|
||||
// 启动缓存清理协程
|
||||
go cache.startCleanupLoop()
|
||||
return cache
|
||||
}
|
||||
|
||||
// NewStatsCache 创建统计数据缓存
|
||||
func NewStatsCache(maxSize int, ttl time.Duration) *StatsCache {
|
||||
cache := &StatsCache{
|
||||
data: make(map[string]*CacheEntry),
|
||||
maxSize: maxSize,
|
||||
ttl: ttl,
|
||||
lastStats: make(map[string]interface{}),
|
||||
}
|
||||
// 启动缓存清理协程
|
||||
go cache.startCleanupLoop()
|
||||
return cache
|
||||
}
|
||||
|
||||
// Get 获取缓存条目
|
||||
func (c *QueryCache) Get(key string) (interface{}, bool) {
|
||||
c.mutex.RLock()
|
||||
entry, found := c.data[key]
|
||||
c.mutex.RUnlock()
|
||||
|
||||
if !found {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if time.Since(entry.timestamp) > c.ttl {
|
||||
// 过期,删除
|
||||
c.mutex.Lock()
|
||||
delete(c.data, key)
|
||||
c.mutex.Unlock()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 更新命中次数
|
||||
c.mutex.Lock()
|
||||
entry.hits++
|
||||
c.mutex.Unlock()
|
||||
|
||||
return entry.data, true
|
||||
}
|
||||
|
||||
// Set 设置缓存条目
|
||||
func (c *QueryCache) Set(key string, data interface{}) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// 如果缓存已满,删除最少使用的条目
|
||||
if len(c.data) >= c.maxSize {
|
||||
c.evictLRU()
|
||||
}
|
||||
|
||||
c.data[key] = &CacheEntry{
|
||||
data: data,
|
||||
timestamp: time.Now(),
|
||||
hits: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Delete 删除缓存条目
|
||||
func (c *QueryCache) Delete(key string) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
delete(c.data, key)
|
||||
}
|
||||
|
||||
// Clear 清空缓存
|
||||
func (c *QueryCache) Clear() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.data = make(map[string]*CacheEntry)
|
||||
}
|
||||
|
||||
// evictLRU 淘汰最少使用的条目
|
||||
func (c *QueryCache) evictLRU() {
|
||||
var lruKey string
|
||||
minHits := int(^uint(0) >> 1) // 最大 int 值
|
||||
|
||||
for key, entry := range c.data {
|
||||
if entry.hits < minHits {
|
||||
minHits = entry.hits
|
||||
lruKey = key
|
||||
}
|
||||
}
|
||||
|
||||
if lruKey != "" {
|
||||
delete(c.data, lruKey)
|
||||
}
|
||||
}
|
||||
|
||||
// startCleanupLoop 启动清理协程
|
||||
func (c *QueryCache) startCleanupLoop() {
|
||||
ticker := time.NewTicker(c.ttl / 2)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
c.cleanupExpired()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpired 清理过期条目
|
||||
func (c *QueryCache) cleanupExpired() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, entry := range c.data {
|
||||
if now.Sub(entry.timestamp) > c.ttl {
|
||||
delete(c.data, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StatsCache 方法
|
||||
|
||||
// Get 获取统计数据缓存条目
|
||||
func (c *StatsCache) Get(key string) (map[string]interface{}, bool) {
|
||||
c.mutex.RLock()
|
||||
entry, found := c.data[key]
|
||||
c.mutex.RUnlock()
|
||||
|
||||
if !found {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if time.Since(entry.timestamp) > c.ttl {
|
||||
// 过期,删除
|
||||
c.mutex.Lock()
|
||||
delete(c.data, key)
|
||||
c.mutex.Unlock()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 更新命中次数
|
||||
c.mutex.Lock()
|
||||
entry.hits++
|
||||
c.mutex.Unlock()
|
||||
|
||||
if data, ok := entry.data.(map[string]interface{}); ok {
|
||||
return data, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Set 设置统计数据缓存条目
|
||||
func (c *StatsCache) Set(key string, data map[string]interface{}) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// 如果缓存已满,删除最少使用的条目
|
||||
if len(c.data) >= c.maxSize {
|
||||
c.evictLRU()
|
||||
}
|
||||
|
||||
c.data[key] = &CacheEntry{
|
||||
data: data,
|
||||
timestamp: time.Now(),
|
||||
hits: 0,
|
||||
}
|
||||
|
||||
// 保存最后统计数据
|
||||
c.lastStats = data
|
||||
}
|
||||
|
||||
// GetLastStats 获取上次缓存的统计数据
|
||||
func (c *StatsCache) GetLastStats() map[string]interface{} {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
return c.lastStats
|
||||
}
|
||||
|
||||
// Delete 删除统计数据缓存条目
|
||||
func (c *StatsCache) Delete(key string) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
delete(c.data, key)
|
||||
}
|
||||
|
||||
// Clear 清空统计数据缓存
|
||||
func (c *StatsCache) Clear() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.data = make(map[string]*CacheEntry)
|
||||
c.lastStats = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// evictLRU 淘汰最少使用的条目
|
||||
func (c *StatsCache) evictLRU() {
|
||||
var lruKey string
|
||||
minHits := int(^uint(0) >> 1) // 最大 int 值
|
||||
|
||||
for key, entry := range c.data {
|
||||
if entry.hits < minHits {
|
||||
minHits = entry.hits
|
||||
lruKey = key
|
||||
}
|
||||
}
|
||||
|
||||
if lruKey != "" {
|
||||
delete(c.data, lruKey)
|
||||
}
|
||||
}
|
||||
|
||||
// startCleanupLoop 启动清理协程
|
||||
func (c *StatsCache) startCleanupLoop() {
|
||||
ticker := time.NewTicker(c.ttl / 2)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
c.cleanupExpired()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpired 清理过期条目
|
||||
func (c *StatsCache) cleanupExpired() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, entry := range c.data {
|
||||
if now.Sub(entry.timestamp) > c.ttl {
|
||||
delete(c.data, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ClearQueryCache 清除查询缓存
|
||||
func (s *Server) ClearQueryCache() {
|
||||
if s.queryCache != nil {
|
||||
s.queryCache.Clear()
|
||||
}
|
||||
}
|
||||
|
||||
// ClearStatsCache 清除统计缓存
|
||||
func (s *Server) ClearStatsCache() {
|
||||
if s.statsCache != nil {
|
||||
s.statsCache.Clear()
|
||||
}
|
||||
}
|
||||
|
||||
// ClearAllCache 清除所有缓存
|
||||
func (s *Server) ClearAllCache() {
|
||||
s.ClearQueryCache()
|
||||
s.ClearStatsCache()
|
||||
}
|
||||
|
||||
// Server HTTP控制台服务器
|
||||
type Server struct {
|
||||
globalConfig *config.Config
|
||||
@@ -36,11 +321,18 @@ type Server struct {
|
||||
sessionsMutex sync.Mutex // 会话映射的互斥锁
|
||||
sessionTTL time.Duration // 会话过期时间
|
||||
|
||||
// WebSocket相关字段
|
||||
// WebSocket 相关字段
|
||||
upgrader websocket.Upgrader
|
||||
clients map[*websocket.Conn]bool
|
||||
clientsMutex sync.Mutex
|
||||
broadcastChan chan []byte
|
||||
|
||||
// 查询缓存相关字段
|
||||
queryCache *QueryCache // 查询结果缓存
|
||||
statsCache *StatsCache // 统计数据缓存
|
||||
cacheEnabled bool // 缓存是否启用
|
||||
cacheTTL time.Duration // 缓存过期时间
|
||||
cacheMaxSize int // 缓存最大条目数
|
||||
}
|
||||
|
||||
// NewServer 创建HTTP服务器实例
|
||||
@@ -63,7 +355,13 @@ func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager
|
||||
broadcastChan: make(chan []byte, 100),
|
||||
// 会话管理初始化
|
||||
sessions: make(map[string]time.Time),
|
||||
sessionTTL: 24 * time.Hour, // 会话有效期24小时
|
||||
sessionTTL: 24 * time.Hour, // 会话有效期 24 小时
|
||||
// 查询缓存初始化
|
||||
queryCache: NewQueryCache(100, 5*time.Second), // 最多 100 条,5 秒过期
|
||||
statsCache: NewStatsCache(10, 2*time.Second), // 最多 10 条,2 秒过期
|
||||
cacheEnabled: true, // 默认启用缓存
|
||||
cacheTTL: 5 * time.Second, // 默认缓存 5 秒
|
||||
cacheMaxSize: 100, // 默认最大 100 条
|
||||
}
|
||||
|
||||
// 启动广播协程
|
||||
@@ -150,6 +448,12 @@ func (s *Server) Start() error {
|
||||
mux.HandleFunc("/api/threat", s.loginRequired(s.handleThreatQuery))
|
||||
// 威胁批量查询接口
|
||||
mux.HandleFunc("/api/threat/batch", s.loginRequired(s.handleThreatBatch))
|
||||
// 威胁告警接口
|
||||
mux.HandleFunc("/api/alert", s.loginRequired(s.handleAlert))
|
||||
// 威胁告警解决接口
|
||||
mux.HandleFunc("/api/alert/resolve", s.loginRequired(s.handleAlertResolve))
|
||||
// 威胁域名管理接口
|
||||
mux.HandleFunc("/api/threat/domain", s.loginRequired(s.handleThreatDomain))
|
||||
// WebSocket 端点
|
||||
mux.HandleFunc("/ws/stats", s.loginRequired(s.handleWebSocketStats))
|
||||
|
||||
@@ -961,7 +1265,7 @@ func (s *Server) handleShieldBlacklists(w http.ResponseWriter, r *http.Request)
|
||||
// 更新全局配置中的黑名单
|
||||
s.globalConfig.Shield.Blacklists = blacklists
|
||||
// 保存配置到文件
|
||||
if err := saveConfigToFile(s.globalConfig, "config.json"); err != nil {
|
||||
if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
|
||||
logger.Error("保存配置文件失败", "error", err)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "保存配置失败"})
|
||||
return
|
||||
@@ -991,7 +1295,7 @@ func (s *Server) handleShieldBlacklists(w http.ResponseWriter, r *http.Request)
|
||||
// 更新全局配置中的黑名单
|
||||
s.globalConfig.Shield.Blacklists = newBlacklists
|
||||
// 保存配置到文件
|
||||
if err := saveConfigToFile(s.globalConfig, "config.json"); err != nil {
|
||||
if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
|
||||
logger.Error("保存配置文件失败", "error", err)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "保存配置失败"})
|
||||
return
|
||||
@@ -1052,7 +1356,7 @@ func (s *Server) handleShieldBlacklists(w http.ResponseWriter, r *http.Request)
|
||||
// 更新全局配置中的黑名单
|
||||
s.globalConfig.Shield.Blacklists = blacklists
|
||||
// 保存配置到文件
|
||||
if err := saveConfigToFile(s.globalConfig, "config.json"); err != nil {
|
||||
if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
|
||||
logger.Error("保存配置文件失败", "error", err)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "保存配置失败"})
|
||||
return
|
||||
@@ -1093,7 +1397,7 @@ func (s *Server) handleShieldBlacklists(w http.ResponseWriter, r *http.Request)
|
||||
// 更新全局配置中的黑名单
|
||||
s.globalConfig.Shield.Blacklists = newBlacklists
|
||||
// 保存配置到文件
|
||||
if err := saveConfigToFile(s.globalConfig, "config.json"); err != nil {
|
||||
if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
|
||||
logger.Error("保存配置文件失败", "error", err)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "保存配置失败"})
|
||||
return
|
||||
@@ -1280,11 +1584,78 @@ func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// saveConfigToFile 保存配置到文件
|
||||
func saveConfigToFile(config *config.Config, filePath string) error {
|
||||
data, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
// 创建新的INI文件
|
||||
cfg := ini.Empty()
|
||||
|
||||
// DNS配置
|
||||
dnsSection := cfg.Section("dns")
|
||||
dnsSection.Key("port").SetValue(fmt.Sprintf("%d", config.DNS.Port))
|
||||
dnsSection.Key("upstreamDNS").SetValue(strings.Join(config.DNS.UpstreamDNS, ", "))
|
||||
dnsSection.Key("dnssecUpstreamDNS").SetValue(strings.Join(config.DNS.DNSSECUpstreamDNS, ", "))
|
||||
dnsSection.Key("saveInterval").SetValue(fmt.Sprintf("%d", config.DNS.SaveInterval))
|
||||
dnsSection.Key("cacheTTL").SetValue(fmt.Sprintf("%d", config.DNS.CacheTTL))
|
||||
dnsSection.Key("enableDNSSEC").SetValue(fmt.Sprintf("%t", config.DNS.EnableDNSSEC))
|
||||
dnsSection.Key("queryMode").SetValue(config.DNS.QueryMode)
|
||||
dnsSection.Key("queryTimeout").SetValue(fmt.Sprintf("%d", config.DNS.QueryTimeout))
|
||||
dnsSection.Key("enableFastReturn").SetValue(fmt.Sprintf("%t", config.DNS.EnableFastReturn))
|
||||
dnsSection.Key("noDNSSECDomains").SetValue(strings.Join(config.DNS.NoDNSSECDomains, ", "))
|
||||
dnsSection.Key("enableIPv6").SetValue(fmt.Sprintf("%t", config.DNS.EnableIPv6))
|
||||
dnsSection.Key("cacheMode").SetValue(config.DNS.CacheMode)
|
||||
dnsSection.Key("cacheSize").SetValue(fmt.Sprintf("%d", config.DNS.CacheSize))
|
||||
dnsSection.Key("maxCacheTTL").SetValue(fmt.Sprintf("%d", config.DNS.MaxCacheTTL))
|
||||
dnsSection.Key("minCacheTTL").SetValue(fmt.Sprintf("%d", config.DNS.MinCacheTTL))
|
||||
|
||||
// 域名特定DNS服务器配置
|
||||
for domain, servers := range config.DNS.DomainSpecificDNS {
|
||||
dnsSection.Key(fmt.Sprintf("domain_%s", domain)).SetValue(strings.Join(servers, ", "))
|
||||
}
|
||||
return os.WriteFile(filePath, data, 0644)
|
||||
|
||||
// HTTP配置
|
||||
httpSection := cfg.Section("http")
|
||||
httpSection.Key("port").SetValue(fmt.Sprintf("%d", config.HTTP.Port))
|
||||
httpSection.Key("host").SetValue(config.HTTP.Host)
|
||||
httpSection.Key("enableAPI").SetValue(fmt.Sprintf("%t", config.HTTP.EnableAPI))
|
||||
httpSection.Key("username").SetValue(config.HTTP.Username)
|
||||
httpSection.Key("password").SetValue(config.HTTP.Password)
|
||||
|
||||
// Shield配置
|
||||
shieldSection := cfg.Section("shield")
|
||||
shieldSection.Key("updateInterval").SetValue(fmt.Sprintf("%d", config.Shield.UpdateInterval))
|
||||
shieldSection.Key("blockMethod").SetValue(config.Shield.BlockMethod)
|
||||
shieldSection.Key("customBlockIP").SetValue(config.Shield.CustomBlockIP)
|
||||
shieldSection.Key("statsSaveInterval").SetValue(fmt.Sprintf("%d", config.Shield.StatsSaveInterval))
|
||||
|
||||
// 黑名单配置
|
||||
for _, bl := range config.Shield.Blacklists {
|
||||
shieldSection.Key(fmt.Sprintf("blacklist_%s", bl.Name)).SetValue(fmt.Sprintf("%s,%t", bl.URL, bl.Enabled))
|
||||
}
|
||||
|
||||
// GFWList配置
|
||||
gfwListSection := cfg.Section("gfwList")
|
||||
gfwListSection.Key("ip").SetValue(config.GFWList.IP)
|
||||
gfwListSection.Key("content").SetValue(config.GFWList.Content)
|
||||
gfwListSection.Key("enabled").SetValue(fmt.Sprintf("%t", config.GFWList.Enabled))
|
||||
|
||||
// Log配置
|
||||
logSection := cfg.Section("log")
|
||||
logSection.Key("level").SetValue(config.Log.Level)
|
||||
logSection.Key("maxSize").SetValue(fmt.Sprintf("%d", config.Log.MaxSize))
|
||||
logSection.Key("maxBackups").SetValue(fmt.Sprintf("%d", config.Log.MaxBackups))
|
||||
logSection.Key("maxAge").SetValue(fmt.Sprintf("%d", config.Log.MaxAge))
|
||||
|
||||
// Threat配置
|
||||
threatSection := cfg.Section("threat")
|
||||
threatSection.Key("enabled").SetValue(fmt.Sprintf("%t", config.Threat.Enabled))
|
||||
threatSection.Key("queryRateThreshold").SetValue(fmt.Sprintf("%d", config.Threat.QueryRateThreshold))
|
||||
threatSection.Key("nxDomainThreshold").SetValue(fmt.Sprintf("%d", config.Threat.NXDomainThreshold))
|
||||
threatSection.Key("maxDomainLength").SetValue(fmt.Sprintf("%d", config.Threat.MaxDomainLength))
|
||||
threatSection.Key("suspiciousPatterns").SetValue(strings.Join(config.Threat.SuspiciousPatterns, ","))
|
||||
threatSection.Key("unusualQueryTypes").SetValue(strings.Join(config.Threat.UnusualQueryTypes, ","))
|
||||
threatSection.Key("alertRetentionDays").SetValue(fmt.Sprintf("%d", config.Threat.AlertRetentionDays))
|
||||
threatSection.Key("threatDatabasePath").SetValue(config.Threat.ThreatDatabasePath)
|
||||
|
||||
// 保存到文件
|
||||
return cfg.SaveTo(filePath)
|
||||
}
|
||||
|
||||
// handleConfig 处理配置请求
|
||||
@@ -1293,35 +1664,52 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
// 每次从配置文件重新读取最新配置
|
||||
cfg, err := config.LoadConfig("config.ini")
|
||||
if err != nil {
|
||||
logger.Error("加载配置文件失败", "error", err)
|
||||
// 如果加载失败,返回内存中的配置
|
||||
cfg = s.globalConfig
|
||||
}
|
||||
|
||||
// 返回当前配置(包括黑名单配置)
|
||||
// 注意:key 名必须与前端期望的一致
|
||||
config := map[string]interface{}{
|
||||
"Shield": map[string]interface{}{
|
||||
"blockMethod": s.globalConfig.Shield.BlockMethod,
|
||||
"customBlockIP": s.globalConfig.Shield.CustomBlockIP,
|
||||
"blacklists": s.globalConfig.Shield.Blacklists,
|
||||
"updateInterval": s.globalConfig.Shield.UpdateInterval,
|
||||
"blockMethod": cfg.Shield.BlockMethod,
|
||||
"customBlockIP": cfg.Shield.CustomBlockIP,
|
||||
"blacklists": cfg.Shield.Blacklists,
|
||||
"updateInterval": cfg.Shield.UpdateInterval,
|
||||
"statsSaveInterval": cfg.Shield.StatsSaveInterval,
|
||||
},
|
||||
"GFWList": map[string]interface{}{
|
||||
"ip": s.globalConfig.GFWList.IP,
|
||||
"content": s.globalConfig.GFWList.Content,
|
||||
"ip": cfg.GFWList.IP,
|
||||
"content": cfg.GFWList.Content,
|
||||
"enabled": cfg.GFWList.Enabled,
|
||||
},
|
||||
"DNSServer": map[string]interface{}{
|
||||
"port": s.globalConfig.DNS.Port,
|
||||
"QueryMode": s.globalConfig.DNS.QueryMode,
|
||||
"UpstreamServers": s.globalConfig.DNS.UpstreamDNS,
|
||||
"DNSSECUpstreamServers": s.globalConfig.DNS.DNSSECUpstreamDNS,
|
||||
"saveInterval": s.globalConfig.DNS.SaveInterval,
|
||||
"enableIPv6": s.globalConfig.DNS.EnableIPv6,
|
||||
"CacheMode": s.globalConfig.DNS.CacheMode,
|
||||
"CacheSize": s.globalConfig.DNS.CacheSize,
|
||||
"MaxCacheTTL": s.globalConfig.DNS.MaxCacheTTL,
|
||||
"MinCacheTTL": s.globalConfig.DNS.MinCacheTTL,
|
||||
"enableFastReturn": s.globalConfig.DNS.EnableFastReturn,
|
||||
"domainSpecificDNS": s.globalConfig.DNS.DomainSpecificDNS,
|
||||
"noDNSSECDomains": s.globalConfig.DNS.NoDNSSECDomains,
|
||||
"port": cfg.DNS.Port,
|
||||
"QueryMode": cfg.DNS.QueryMode,
|
||||
"UpstreamServers": cfg.DNS.UpstreamDNS,
|
||||
"DNSSECUpstreamServers": cfg.DNS.DNSSECUpstreamDNS,
|
||||
"saveInterval": cfg.DNS.SaveInterval,
|
||||
"queryTimeout": cfg.DNS.QueryTimeout,
|
||||
"enableIPv6": cfg.DNS.EnableIPv6,
|
||||
"enableDNSSEC": cfg.DNS.EnableDNSSEC,
|
||||
"enableFastReturn": cfg.DNS.EnableFastReturn,
|
||||
"noDNSSECDomains": cfg.DNS.NoDNSSECDomains,
|
||||
"CacheMode": cfg.DNS.CacheMode,
|
||||
"CacheSize": cfg.DNS.CacheSize,
|
||||
"MaxCacheTTL": cfg.DNS.MaxCacheTTL,
|
||||
"MinCacheTTL": cfg.DNS.MinCacheTTL,
|
||||
"domainSpecificDNS": cfg.DNS.DomainSpecificDNS,
|
||||
},
|
||||
"HTTPServer": map[string]interface{}{
|
||||
"port": s.globalConfig.HTTP.Port,
|
||||
"port": cfg.HTTP.Port,
|
||||
"host": cfg.HTTP.Host,
|
||||
"enableAPI": cfg.HTTP.EnableAPI,
|
||||
"username": cfg.HTTP.Username,
|
||||
"password": cfg.HTTP.Password,
|
||||
},
|
||||
}
|
||||
json.NewEncoder(w).Encode(config)
|
||||
@@ -1337,26 +1725,33 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
|
||||
Timeout int `json:"timeout"`
|
||||
SaveInterval int `json:"saveInterval"`
|
||||
EnableIPv6 bool `json:"enableIPv6"`
|
||||
EnableDNSSEC bool `json:"enableDNSSEC"`
|
||||
EnableFastReturn *bool `json:"enableFastReturn"`
|
||||
NoDNSSECDomains []string `json:"noDNSSECDomains"`
|
||||
CacheMode string `json:"cacheMode"`
|
||||
CacheSize int `json:"cacheSize"`
|
||||
MaxCacheTTL int `json:"maxCacheTTL"`
|
||||
MinCacheTTL int `json:"minCacheTTL"`
|
||||
EnableFastReturn *bool `json:"enableFastReturn"`
|
||||
DomainSpecificDNS map[string][]string `json:"domainSpecificDNS"`
|
||||
NoDNSSECDomains []string `json:"noDNSSECDomains"`
|
||||
} `json:"dnsserver"`
|
||||
HTTPServer struct {
|
||||
Port int `json:"port"`
|
||||
Port int `json:"port"`
|
||||
Host string `json:"host"`
|
||||
EnableAPI bool `json:"enableAPI"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
} `json:"httpserver"`
|
||||
Shield struct {
|
||||
BlockMethod string `json:"blockMethod"`
|
||||
CustomBlockIP string `json:"customBlockIP"`
|
||||
Blacklists []config.BlacklistEntry `json:"blacklists"`
|
||||
UpdateInterval int `json:"updateInterval"`
|
||||
BlockMethod string `json:"blockMethod"`
|
||||
CustomBlockIP string `json:"customBlockIP"`
|
||||
Blacklists []config.BlacklistEntry `json:"blacklists"`
|
||||
UpdateInterval int `json:"updateInterval"`
|
||||
StatsSaveInterval int `json:"statsSaveInterval"`
|
||||
} `json:"shield"`
|
||||
GFWList struct {
|
||||
IP string `json:"ip"`
|
||||
Content string `json:"content"`
|
||||
Enabled bool `json:"enabled"`
|
||||
} `json:"gfwList"`
|
||||
}
|
||||
|
||||
@@ -1378,7 +1773,11 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
|
||||
if req.DNSServer.SaveInterval > 0 {
|
||||
s.globalConfig.DNS.SaveInterval = req.DNSServer.SaveInterval
|
||||
}
|
||||
if req.DNSServer.Timeout > 0 {
|
||||
s.globalConfig.DNS.QueryTimeout = req.DNSServer.Timeout
|
||||
}
|
||||
s.globalConfig.DNS.EnableIPv6 = req.DNSServer.EnableIPv6
|
||||
s.globalConfig.DNS.EnableDNSSEC = req.DNSServer.EnableDNSSEC
|
||||
// 更新查询模式
|
||||
if req.DNSServer.QueryMode != "" {
|
||||
s.globalConfig.DNS.QueryMode = req.DNSServer.QueryMode
|
||||
@@ -1400,19 +1799,29 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
|
||||
if req.DNSServer.EnableFastReturn != nil {
|
||||
s.globalConfig.DNS.EnableFastReturn = *req.DNSServer.EnableFastReturn
|
||||
}
|
||||
// 更新domainSpecificDNS
|
||||
if req.DNSServer.DomainSpecificDNS != nil {
|
||||
s.globalConfig.DNS.DomainSpecificDNS = req.DNSServer.DomainSpecificDNS
|
||||
}
|
||||
// 更新noDNSSECDomains
|
||||
if len(req.DNSServer.NoDNSSECDomains) > 0 {
|
||||
s.globalConfig.DNS.NoDNSSECDomains = req.DNSServer.NoDNSSECDomains
|
||||
}
|
||||
// 更新domainSpecificDNS
|
||||
if req.DNSServer.DomainSpecificDNS != nil {
|
||||
s.globalConfig.DNS.DomainSpecificDNS = req.DNSServer.DomainSpecificDNS
|
||||
}
|
||||
|
||||
// 更新HTTP配置
|
||||
if req.HTTPServer.Port > 0 {
|
||||
s.globalConfig.HTTP.Port = req.HTTPServer.Port
|
||||
}
|
||||
if req.HTTPServer.Host != "" {
|
||||
s.globalConfig.HTTP.Host = req.HTTPServer.Host
|
||||
}
|
||||
s.globalConfig.HTTP.EnableAPI = req.HTTPServer.EnableAPI
|
||||
if req.HTTPServer.Username != "" {
|
||||
s.globalConfig.HTTP.Username = req.HTTPServer.Username
|
||||
}
|
||||
if req.HTTPServer.Password != "" {
|
||||
s.globalConfig.HTTP.Password = req.HTTPServer.Password
|
||||
}
|
||||
|
||||
// 更新屏蔽配置
|
||||
if req.Shield.BlockMethod != "" {
|
||||
@@ -1449,9 +1858,23 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
|
||||
s.globalConfig.Shield.CustomBlockIP = req.Shield.CustomBlockIP
|
||||
}
|
||||
|
||||
// 更新更新间隔
|
||||
if req.Shield.UpdateInterval > 0 {
|
||||
s.globalConfig.Shield.UpdateInterval = req.Shield.UpdateInterval
|
||||
// 重新启动自动更新
|
||||
s.shieldManager.StopAutoUpdate()
|
||||
s.shieldManager.StartAutoUpdate()
|
||||
}
|
||||
|
||||
// 更新统计保存间隔
|
||||
if req.Shield.StatsSaveInterval > 0 {
|
||||
s.globalConfig.Shield.StatsSaveInterval = req.Shield.StatsSaveInterval
|
||||
}
|
||||
|
||||
// 更新GFWList配置
|
||||
s.globalConfig.GFWList.IP = req.GFWList.IP
|
||||
s.globalConfig.GFWList.Content = req.GFWList.Content
|
||||
s.globalConfig.GFWList.Enabled = req.GFWList.Enabled
|
||||
|
||||
// 重新加载GFWList规则
|
||||
if s.gfwManager != nil {
|
||||
@@ -1481,14 +1904,6 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// 更新更新间隔
|
||||
if req.Shield.UpdateInterval > 0 {
|
||||
s.globalConfig.Shield.UpdateInterval = req.Shield.UpdateInterval
|
||||
// 重新启动自动更新
|
||||
s.shieldManager.StopAutoUpdate()
|
||||
s.shieldManager.StartAutoUpdate()
|
||||
}
|
||||
|
||||
// 更新现有的DNSCache实例配置
|
||||
// 最大和最小TTL(秒)
|
||||
maxCacheTTL := time.Duration(s.globalConfig.DNS.MaxCacheTTL) * time.Second
|
||||
@@ -1503,7 +1918,7 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
|
||||
s.dnsServer.DnsCache.SetMaxCacheSize(maxCacheSize)
|
||||
|
||||
// 保存配置到文件
|
||||
if err := saveConfigToFile(s.globalConfig, "./config.json"); err != nil {
|
||||
if err := saveConfigToFile(s.globalConfig, "./config.ini"); err != nil {
|
||||
logger.Error("保存配置到文件失败", "error", err)
|
||||
// 不返回错误,只记录日志,因为配置已经在内存中更新成功
|
||||
}
|
||||
@@ -1568,9 +1983,27 @@ func (s *Server) handleLogsStats(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取日志统计数据
|
||||
// 构建缓存键
|
||||
cacheKey := "logs_stats"
|
||||
|
||||
// 如果启用缓存,先尝试从缓存获取
|
||||
if s.cacheEnabled {
|
||||
if cachedStats, found := s.statsCache.Get(cacheKey); found {
|
||||
// 缓存命中,直接返回
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(cachedStats)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 缓存未命中,获取最新统计数据
|
||||
logStats := s.dnsServer.GetQueryStats()
|
||||
|
||||
// 存入缓存
|
||||
if s.cacheEnabled {
|
||||
s.statsCache.Set(cacheKey, logStats)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(logStats)
|
||||
}
|
||||
@@ -1583,7 +2016,7 @@ func (s *Server) handleLogsQuery(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// 获取查询参数
|
||||
limit := 100 // 默认返回100条日志
|
||||
limit := 100 // 默认返回 100 条日志
|
||||
offset := 0
|
||||
sortField := r.URL.Query().Get("sort")
|
||||
sortDirection := r.URL.Query().Get("direction")
|
||||
@@ -1599,9 +2032,18 @@ func (s *Server) handleLogsQuery(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Sscanf(offsetStr, "%d", &offset)
|
||||
}
|
||||
|
||||
// 获取日志数据
|
||||
// 构建缓存键,包含所有查询参数
|
||||
// 已禁用缓存,每次都从数据库获取最新数据
|
||||
// cacheKey := fmt.Sprintf("logs_%d_%d_%s_%s_%s_%s_%s", limit, offset, sortField, sortDirection, resultFilter, searchTerm, queryType)
|
||||
|
||||
// 缓存未命中,获取日志数据(已禁用缓存)
|
||||
logs := s.dnsServer.GetQueryLogs(limit, offset, sortField, sortDirection, resultFilter, searchTerm, queryType)
|
||||
|
||||
// 存入缓存(已禁用)
|
||||
// if s.cacheEnabled {
|
||||
// s.queryCache.Set(cacheKey, logs)
|
||||
// }
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(logs)
|
||||
}
|
||||
@@ -1618,9 +2060,17 @@ func (s *Server) handleLogsCount(w http.ResponseWriter, r *http.Request) {
|
||||
searchTerm := r.URL.Query().Get("search")
|
||||
queryType := r.URL.Query().Get("queryType")
|
||||
|
||||
// 获取带过滤条件的日志总数
|
||||
// 构建缓存键(已禁用)
|
||||
// cacheKey := fmt.Sprintf("logs_count_%s_%s_%s", resultFilter, searchTerm, queryType)
|
||||
|
||||
// 缓存未命中,获取带过滤条件的日志总数(已禁用缓存)
|
||||
count := s.dnsServer.GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType)
|
||||
|
||||
// 存入缓存(已禁用)
|
||||
// if s.cacheEnabled {
|
||||
// s.queryCache.Set(cacheKey, count)
|
||||
// }
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]int{"count": count})
|
||||
}
|
||||
@@ -1828,6 +2278,159 @@ func isService(obj map[string]interface{}) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// handleAlert 处理威胁告警请求
|
||||
func (s *Server) handleAlert(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
// 获取告警列表
|
||||
limit := 100
|
||||
offset := 0
|
||||
level := r.URL.Query().Get("level")
|
||||
|
||||
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
|
||||
fmt.Sscanf(limitStr, "%d", &limit)
|
||||
}
|
||||
|
||||
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
|
||||
fmt.Sscanf(offsetStr, "%d", &offset)
|
||||
}
|
||||
|
||||
// 获取告警列表
|
||||
alerts := s.dnsServer.GetAlerts(limit, offset, level)
|
||||
|
||||
// 构建响应
|
||||
response := map[string]interface{}{
|
||||
"alerts": alerts,
|
||||
"total": s.dnsServer.GetAlertCount(level),
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(response)
|
||||
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// handleAlertResolve 处理威胁告警解决请求
|
||||
func (s *Server) handleAlertResolve(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求体
|
||||
var req struct {
|
||||
AlertID string `json:"alertId"`
|
||||
Action string `json:"action"` // blocked, allowed
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.AlertID == "" || req.Action == "" {
|
||||
http.Error(w, "AlertID and Action are required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证动作
|
||||
if req.Action != threat.ActionBlocked && req.Action != threat.ActionAllowed {
|
||||
http.Error(w, "Invalid action", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 解决告警
|
||||
success := s.dnsServer.ResolveAlert(req.AlertID, req.Action)
|
||||
|
||||
if success {
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||||
} else {
|
||||
http.Error(w, "Failed to resolve alert", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// handleThreatDomain 处理威胁域名管理请求
|
||||
func (s *Server) handleThreatDomain(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
// 获取所有威胁域名
|
||||
threats := s.dnsServer.GetThreatDomains()
|
||||
json.NewEncoder(w).Encode(threats)
|
||||
|
||||
case http.MethodPost:
|
||||
// 添加威胁域名
|
||||
var req struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
RiskLevel int `json:"riskLevel"`
|
||||
Domain string `json:"domain"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Domain == "" {
|
||||
http.Error(w, "Domain is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if req.Type == "" {
|
||||
req.Type = "未知"
|
||||
}
|
||||
if req.Name == "" {
|
||||
req.Name = "未知"
|
||||
}
|
||||
if req.RiskLevel == 0 {
|
||||
req.RiskLevel = 1
|
||||
}
|
||||
|
||||
err := s.dnsServer.AddThreatDomain(req.Type, req.Name, req.RiskLevel, req.Domain)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||||
|
||||
case http.MethodDelete:
|
||||
// 删除威胁域名
|
||||
var req struct {
|
||||
Domain string `json:"domain"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Domain == "" {
|
||||
http.Error(w, "Domain is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
err := s.dnsServer.RemoveThreatDomain(req.Domain)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||||
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// processServiceItem 递归处理服务或分组
|
||||
func processServiceItem(
|
||||
serviceName string,
|
||||
|
||||
Reference in New Issue
Block a user