2994 lines
90 KiB
Go
2994 lines
90 KiB
Go
package http
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/csv"
|
||
"encoding/json"
|
||
"fmt"
|
||
"net/http"
|
||
"os"
|
||
"sort"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"dns-server/config"
|
||
"dns-server/dns"
|
||
"dns-server/domain"
|
||
"dns-server/gfw"
|
||
"dns-server/logger"
|
||
"dns-server/shield"
|
||
|
||
"gopkg.in/ini.v1"
|
||
"dns-server/threat"
|
||
|
||
"github.com/gorilla/websocket"
|
||
)
|
||
|
||
// CacheEntry is one cached value plus the bookkeeping used for expiry and
// eviction.
type CacheEntry struct {
	data      interface{} // cached payload
	timestamp time.Time   // when the entry was stored; drives TTL expiry
	hits      int         // successful Get calls; drives eviction
}

// QueryCache is a size-bounded, TTL-based cache for query results. All
// methods take the mutex, so it is safe for concurrent use.
type QueryCache struct {
	data    map[string]*CacheEntry // cache key -> entry
	mutex   sync.RWMutex           // guards data and the entries' hits counters
	maxSize int                    // maximum number of entries
	ttl     time.Duration          // entry lifetime
}

// StatsCache is a size-bounded, TTL-based cache for statistics maps that also
// remembers the most recently stored map. All methods take the mutex, so it is
// safe for concurrent use.
type StatsCache struct {
	data      map[string]*CacheEntry // cache key -> entry
	mutex     sync.RWMutex           // guards data, lastStats and hits counters
	maxSize   int                    // maximum number of entries
	ttl       time.Duration          // entry lifetime
	lastStats map[string]interface{} // map passed to the latest Set (kept for incremental updates)
}

// NewQueryCache creates a QueryCache holding at most maxSize entries, each
// valid for ttl, and starts its background expiry sweep.
//
// NOTE(review): the sweep goroutine has no stop mechanism and lives for the
// rest of the process.
func NewQueryCache(maxSize int, ttl time.Duration) *QueryCache {
	cache := &QueryCache{
		data:    make(map[string]*CacheEntry),
		maxSize: maxSize,
		ttl:     ttl,
	}
	go cache.startCleanupLoop()
	return cache
}

// NewStatsCache creates a StatsCache holding at most maxSize entries, each
// valid for ttl, and starts its background expiry sweep.
//
// NOTE(review): the sweep goroutine has no stop mechanism and lives for the
// rest of the process.
func NewStatsCache(maxSize int, ttl time.Duration) *StatsCache {
	cache := &StatsCache{
		data:      make(map[string]*CacheEntry),
		maxSize:   maxSize,
		ttl:       ttl,
		lastStats: make(map[string]interface{}),
	}
	go cache.startCleanupLoop()
	return cache
}

// Helpers shared by QueryCache and StatsCache. Every helper that touches the
// map expects the caller to hold the cache's write lock.

// cacheCleanupInterval returns the period of the background expiry sweep for
// a cache with the given ttl. time.NewTicker panics on a non-positive
// duration, so zero or sub-2ns TTLs fall back to one second; Get still
// enforces the TTL exactly on every read.
func cacheCleanupInterval(ttl time.Duration) time.Duration {
	if interval := ttl / 2; interval > 0 {
		return interval
	}
	return time.Second
}

// lookupCacheEntry returns the live entry for key and counts the hit. An
// expired entry is deleted and reported as missing.
func lookupCacheEntry(m map[string]*CacheEntry, key string, ttl time.Duration) (*CacheEntry, bool) {
	entry, found := m[key]
	if !found {
		return nil, false
	}
	if time.Since(entry.timestamp) > ttl {
		delete(m, key)
		return nil, false
	}
	entry.hits++
	return entry, true
}

// storeCacheEntry stores data under key with a fresh timestamp and zero hits.
// Another entry is evicted first only when key is new and the map already
// holds maxSize entries; overwriting an existing key never evicts anything.
func storeCacheEntry(m map[string]*CacheEntry, maxSize int, key string, data interface{}) {
	if _, exists := m[key]; !exists && len(m) >= maxSize {
		evictLeastHitEntry(m)
	}
	m[key] = &CacheEntry{data: data, timestamp: time.Now()}
}

// evictLeastHitEntry removes the entry with the fewest hits, preferring the
// oldest entry when hit counts tie. It is a no-op on an empty map.
func evictLeastHitEntry(m map[string]*CacheEntry) {
	var (
		victim string
		chosen *CacheEntry
	)
	for key, entry := range m {
		if chosen == nil ||
			entry.hits < chosen.hits ||
			(entry.hits == chosen.hits && entry.timestamp.Before(chosen.timestamp)) {
			victim, chosen = key, entry
		}
	}
	// Test presence via chosen (not victim != "") so the empty key can be evicted.
	if chosen != nil {
		delete(m, victim)
	}
}

// purgeExpiredEntries deletes every entry older than ttl as of now.
func purgeExpiredEntries(m map[string]*CacheEntry, ttl time.Duration, now time.Time) {
	for key, entry := range m {
		if now.Sub(entry.timestamp) > ttl {
			delete(m, key)
		}
	}
}

// Get returns the cached value for key, or (nil, false) if it is absent or
// expired. Expired entries are removed.
func (c *QueryCache) Get(key string) (interface{}, bool) {
	// One write lock covers lookup, expiry deletion and the hit counter, so a
	// stale expiry check can no longer delete an entry that a concurrent Set
	// has just stored under the same key.
	c.mutex.Lock()
	defer c.mutex.Unlock()

	entry, ok := lookupCacheEntry(c.data, key, c.ttl)
	if !ok {
		return nil, false
	}
	return entry.data, true
}

// Set stores data under key, evicting the least-hit entry if the cache is
// full and key is not already present.
func (c *QueryCache) Set(key string, data interface{}) {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	storeCacheEntry(c.data, c.maxSize, key, data)
}

// Delete removes key from the cache.
func (c *QueryCache) Delete(key string) {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	delete(c.data, key)
}

// Clear removes every entry.
func (c *QueryCache) Clear() {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	c.data = make(map[string]*CacheEntry)
}

// evictLRU evicts the entry with the fewest hits (oldest first on ties).
// Despite the name this is least-frequently-used eviction. Caller holds the
// write lock.
func (c *QueryCache) evictLRU() {
	evictLeastHitEntry(c.data)
}

// startCleanupLoop periodically purges expired entries; it never returns.
func (c *QueryCache) startCleanupLoop() {
	ticker := time.NewTicker(cacheCleanupInterval(c.ttl))
	defer ticker.Stop()

	for range ticker.C {
		c.cleanupExpired()
	}
}

// cleanupExpired removes every expired entry.
func (c *QueryCache) cleanupExpired() {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	purgeExpiredEntries(c.data, c.ttl, time.Now())
}

// Get returns the cached statistics map for key, or (nil, false) if it is
// absent, expired, or not a map[string]interface{}. Expired entries are
// removed.
func (c *StatsCache) Get(key string) (map[string]interface{}, bool) {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	entry, ok := lookupCacheEntry(c.data, key, c.ttl)
	if !ok {
		return nil, false
	}
	data, ok := entry.data.(map[string]interface{})
	if !ok {
		return nil, false
	}
	return data, true
}

// Set stores data under key (evicting the least-hit entry if the cache is
// full and key is new) and records data as the latest statistics.
func (c *StatsCache) Set(key string, data map[string]interface{}) {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	storeCacheEntry(c.data, c.maxSize, key, data)
	c.lastStats = data
}

// GetLastStats returns the map passed to the most recent Set (an empty map
// after construction or Clear). The map is shared, not copied.
func (c *StatsCache) GetLastStats() map[string]interface{} {
	c.mutex.RLock()
	defer c.mutex.RUnlock()
	return c.lastStats
}

// Delete removes key from the cache.
func (c *StatsCache) Delete(key string) {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	delete(c.data, key)
}

// Clear removes every entry and resets the latest statistics to an empty map.
func (c *StatsCache) Clear() {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	c.data = make(map[string]*CacheEntry)
	c.lastStats = make(map[string]interface{})
}

// evictLRU evicts the entry with the fewest hits (oldest first on ties).
// Despite the name this is least-frequently-used eviction. Caller holds the
// write lock.
func (c *StatsCache) evictLRU() {
	evictLeastHitEntry(c.data)
}

// startCleanupLoop periodically purges expired entries; it never returns.
func (c *StatsCache) startCleanupLoop() {
	ticker := time.NewTicker(cacheCleanupInterval(c.ttl))
	defer ticker.Stop()

	for range ticker.C {
		c.cleanupExpired()
	}
}

// cleanupExpired removes every expired entry.
func (c *StatsCache) cleanupExpired() {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	purgeExpiredEntries(c.data, c.ttl, time.Now())
}
|
||
|
||
// ClearQueryCache 清除查询缓存
|
||
func (s *Server) ClearQueryCache() {
|
||
if s.queryCache != nil {
|
||
s.queryCache.Clear()
|
||
}
|
||
}
|
||
|
||
// ClearStatsCache 清除统计缓存
|
||
func (s *Server) ClearStatsCache() {
|
||
if s.statsCache != nil {
|
||
s.statsCache.Clear()
|
||
}
|
||
}
|
||
|
||
// ClearAllCache 清除所有缓存
|
||
func (s *Server) ClearAllCache() {
|
||
s.ClearQueryCache()
|
||
s.ClearStatsCache()
|
||
}
|
||
|
||
// Server HTTP控制台服务器
|
||
type Server struct {
|
||
globalConfig *config.Config
|
||
config *config.HTTPConfig
|
||
dnsServer *dns.Server
|
||
shieldManager *shield.ShieldManager
|
||
gfwManager *gfw.GFWListManager
|
||
server *http.Server
|
||
|
||
// 会话管理相关字段
|
||
sessions map[string]time.Time // 会话ID到过期时间的映射
|
||
sessionsMutex sync.Mutex // 会话映射的互斥锁
|
||
sessionTTL time.Duration // 会话过期时间
|
||
|
||
// WebSocket 相关字段
|
||
upgrader websocket.Upgrader
|
||
clients map[*websocket.Conn]bool
|
||
clientsMutex sync.Mutex
|
||
broadcastChan chan []byte
|
||
|
||
// 查询缓存相关字段
|
||
queryCache *QueryCache // 查询结果缓存
|
||
statsCache *StatsCache // 统计数据缓存
|
||
cacheEnabled bool // 缓存是否启用
|
||
cacheTTL time.Duration // 缓存过期时间
|
||
cacheMaxSize int // 缓存最大条目数
|
||
}
|
||
|
||
// NewServer 创建HTTP服务器实例
|
||
func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager *shield.ShieldManager, gfwManager *gfw.GFWListManager) *Server {
|
||
server := &Server{
|
||
globalConfig: globalConfig,
|
||
config: &globalConfig.HTTP,
|
||
dnsServer: dnsServer,
|
||
shieldManager: shieldManager,
|
||
gfwManager: gfwManager,
|
||
upgrader: websocket.Upgrader{
|
||
ReadBufferSize: 1024,
|
||
WriteBufferSize: 1024,
|
||
// 允许所有CORS请求
|
||
CheckOrigin: func(r *http.Request) bool {
|
||
return true
|
||
},
|
||
},
|
||
clients: make(map[*websocket.Conn]bool),
|
||
broadcastChan: make(chan []byte, 100),
|
||
// 会话管理初始化
|
||
sessions: make(map[string]time.Time),
|
||
sessionTTL: 24 * time.Hour, // 会话有效期 24 小时
|
||
// 查询缓存初始化
|
||
queryCache: NewQueryCache(100, 5*time.Second), // 最多 100 条,5 秒过期
|
||
statsCache: NewStatsCache(10, 2*time.Second), // 最多 10 条,2 秒过期
|
||
cacheEnabled: true, // 默认启用缓存
|
||
cacheTTL: 5 * time.Second, // 默认缓存 5 秒
|
||
cacheMaxSize: 100, // 默认最大 100 条
|
||
}
|
||
|
||
// 启动广播协程
|
||
go server.startBroadcastLoop()
|
||
// 启动会话清理协程
|
||
go server.cleanupSessionsLoop()
|
||
|
||
return server
|
||
}
|
||
|
||
// Start 启动HTTP服务器
|
||
func (s *Server) Start() error {
|
||
mux := http.NewServeMux()
|
||
|
||
// 登录路由,不需要认证
|
||
mux.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
|
||
// 重定向到登录页面HTML
|
||
http.Redirect(w, r, "/login.html", http.StatusFound)
|
||
})
|
||
|
||
// API路由
|
||
if s.config.EnableAPI {
|
||
// 登录API端点,不需要认证
|
||
mux.HandleFunc("/api/login", s.handleLogin)
|
||
// 注销API端点,不需要认证
|
||
mux.HandleFunc("/api/logout", s.handleLogout)
|
||
// 修改密码API端点,需要认证
|
||
mux.HandleFunc("/api/change-password", s.loginRequired(s.handleChangePassword))
|
||
|
||
// 重定向/api到Swagger UI页面
|
||
mux.HandleFunc("/api", s.loginRequired(func(w http.ResponseWriter, r *http.Request) {
|
||
http.Redirect(w, r, "/api/index.html", http.StatusMovedPermanently)
|
||
}))
|
||
|
||
// 注册所有API端点,应用登录中间件
|
||
mux.HandleFunc("/api/stats", s.loginRequired(s.handleStats))
|
||
mux.HandleFunc("/api/shield", s.loginRequired(s.handleShield))
|
||
mux.HandleFunc("/api/shield/localrules", s.loginRequired(func(w http.ResponseWriter, r *http.Request) {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
if r.Method == http.MethodGet {
|
||
localRules := s.shieldManager.GetLocalRules()
|
||
json.NewEncoder(w).Encode(localRules)
|
||
return
|
||
}
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
}))
|
||
mux.HandleFunc("/api/shield/remoterules", s.loginRequired(func(w http.ResponseWriter, r *http.Request) {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
if r.Method == http.MethodGet {
|
||
remoteRules := s.shieldManager.GetRemoteRules()
|
||
json.NewEncoder(w).Encode(remoteRules)
|
||
return
|
||
}
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
}))
|
||
mux.HandleFunc("/api/shield/hosts", s.loginRequired(s.handleShieldHosts))
|
||
mux.HandleFunc("/api/shield/blacklists", s.loginRequired(s.handleShieldBlacklists))
|
||
// 传统查询接口(保持向后兼容)
|
||
mux.HandleFunc("/api/query", s.loginRequired(s.handleQuery))
|
||
// RESTful 域名查询接口
|
||
mux.HandleFunc("/api/domains/", s.loginRequired(s.handleDomainQuery))
|
||
mux.HandleFunc("/api/status", s.loginRequired(s.handleStatus))
|
||
mux.HandleFunc("/api/config", s.loginRequired(s.handleConfig))
|
||
mux.HandleFunc("/api/config/restart", s.loginRequired(s.handleRestart))
|
||
// 添加统计相关接口
|
||
mux.HandleFunc("/api/top-blocked", s.loginRequired(s.handleTopBlockedDomains))
|
||
mux.HandleFunc("/api/top-resolved", s.loginRequired(s.handleTopResolvedDomains))
|
||
mux.HandleFunc("/api/top-clients", s.loginRequired(s.handleTopClients))
|
||
mux.HandleFunc("/api/top-domains", s.loginRequired(s.handleTopDomains))
|
||
mux.HandleFunc("/api/recent-blocked", s.loginRequired(s.handleRecentBlockedDomains))
|
||
mux.HandleFunc("/api/hourly-stats", s.loginRequired(s.handleHourlyStats))
|
||
mux.HandleFunc("/api/daily-stats", s.loginRequired(s.handleDailyStats))
|
||
mux.HandleFunc("/api/monthly-stats", s.loginRequired(s.handleMonthlyStats))
|
||
mux.HandleFunc("/api/query/type", s.loginRequired(s.handleQueryTypeStats))
|
||
// 日志统计相关接口
|
||
mux.HandleFunc("/api/logs/stats", s.loginRequired(s.handleLogsStats))
|
||
mux.HandleFunc("/api/logs/query", s.loginRequired(s.handleLogsQuery))
|
||
mux.HandleFunc("/api/logs/count", s.loginRequired(s.handleLogsCount))
|
||
// 域名查询相关接口
|
||
mux.HandleFunc("/api/domain/info", s.loginRequired(s.handleDomainInfo))
|
||
// 域名信息列表接口
|
||
mux.HandleFunc("/api/domain-info", s.loginRequired(s.handleDomainInfoList))
|
||
// 威胁查询接口
|
||
mux.HandleFunc("/api/threat", s.loginRequired(s.handleThreatQuery))
|
||
// 威胁批量查询接口
|
||
mux.HandleFunc("/api/threat/batch", s.loginRequired(s.handleThreatBatch))
|
||
// 威胁告警接口
|
||
mux.HandleFunc("/api/alert", s.loginRequired(s.handleAlert))
|
||
// 威胁告警解决接口
|
||
mux.HandleFunc("/api/alert/resolve", s.loginRequired(s.handleAlertResolve))
|
||
// 威胁域名管理接口
|
||
mux.HandleFunc("/api/threat/domain", s.loginRequired(s.handleThreatDomain))
|
||
// WebSocket 端点
|
||
mux.HandleFunc("/ws/stats", s.loginRequired(s.handleWebSocketStats))
|
||
|
||
// 将/api/下的静态文件服务指向static/api目录,放在最后以避免覆盖API端点
|
||
apiFileServer := http.FileServer(http.Dir("./static/api"))
|
||
mux.Handle("/api/", s.loginRequired(http.StripPrefix("/api", apiFileServer).ServeHTTP))
|
||
}
|
||
|
||
// 自定义静态文件服务处理器,用于禁用浏览器缓存,放在API路由之后
|
||
fileServer := http.FileServer(http.Dir("./static"))
|
||
|
||
// 单独处理login.html,不需要登录
|
||
mux.HandleFunc("/login.html", func(w http.ResponseWriter, r *http.Request) {
|
||
// 添加Cache-Control头,禁用浏览器缓存
|
||
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||
w.Header().Set("Pragma", "no-cache")
|
||
w.Header().Set("Expires", "Thu, 01 Jan 1970 00:00:00 GMT")
|
||
// 直接提供login.html文件
|
||
http.ServeFile(w, r, "./static/login.html")
|
||
})
|
||
|
||
// Tracker目录静态文件服务
|
||
trackerFileServer := http.FileServer(http.Dir("./tracker"))
|
||
mux.HandleFunc("/tracker/", s.loginRequired(func(w http.ResponseWriter, r *http.Request) {
|
||
// 添加Cache-Control头,禁用浏览器缓存
|
||
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||
w.Header().Set("Pragma", "no-cache")
|
||
w.Header().Set("Expires", "Thu, 01 Jan 1970 00:00:00 GMT")
|
||
// 使用StripPrefix处理路径
|
||
http.StripPrefix("/tracker", trackerFileServer).ServeHTTP(w, r)
|
||
}))
|
||
|
||
// 其他静态文件需要登录
|
||
mux.HandleFunc("/", s.loginRequired(func(w http.ResponseWriter, r *http.Request) {
|
||
// 添加Cache-Control头,禁用浏览器缓存
|
||
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
|
||
w.Header().Set("Pragma", "no-cache")
|
||
w.Header().Set("Expires", "Thu, 01 Jan 1970 00:00:00 GMT")
|
||
// 使用StripPrefix处理路径
|
||
http.StripPrefix("/", fileServer).ServeHTTP(w, r)
|
||
}))
|
||
|
||
s.server = &http.Server{
|
||
Addr: fmt.Sprintf("%s:%d", s.config.Host, s.config.Port),
|
||
Handler: mux,
|
||
ReadTimeout: 10 * time.Second,
|
||
WriteTimeout: 10 * time.Second,
|
||
}
|
||
|
||
logger.Info(fmt.Sprintf("HTTP控制台服务器启动,监听地址: %s:%d", s.config.Host, s.config.Port))
|
||
return s.server.ListenAndServe()
|
||
}
|
||
|
||
// Stop 停止HTTP服务器
|
||
func (s *Server) Stop() {
|
||
if s.server != nil {
|
||
s.server.Close()
|
||
}
|
||
logger.Info("HTTP控制台服务器已停止")
|
||
}
|
||
|
||
// handleStats 处理统计信息请求
|
||
// @Summary 获取系统统计信息
|
||
// @Description 获取DNS服务器和Shield的统计信息
|
||
// @Tags stats
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Success 200 {object} map[string]interface{} "统计信息"
|
||
// @Failure 500 {object} map[string]string "服务器内部错误"
|
||
// @Router /api/stats [get]
|
||
func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
dnsStats := s.dnsServer.GetStats()
|
||
shieldStats := s.shieldManager.GetStats()
|
||
|
||
// 获取最常用查询类型(如果有)
|
||
topQueryType := "-"
|
||
maxCount := int64(0)
|
||
if len(dnsStats.QueryTypes) > 0 {
|
||
for queryType, count := range dnsStats.QueryTypes {
|
||
if count > maxCount {
|
||
maxCount = count
|
||
topQueryType = queryType
|
||
}
|
||
}
|
||
}
|
||
|
||
// 获取活跃来源IP数量
|
||
activeIPCount := len(dnsStats.SourceIPs)
|
||
|
||
// 格式化平均响应时间为两位小数
|
||
formattedResponseTime := float64(int(dnsStats.AvgResponseTime*100)) / 100
|
||
|
||
// 计算DNSSEC使用率
|
||
dnssecUsage := float64(0)
|
||
if dnsStats.Queries > 0 {
|
||
dnssecUsage = float64(dnsStats.DNSSECQueries) / float64(dnsStats.Queries) * 100
|
||
}
|
||
|
||
// 构建响应数据,确保所有字段都反映服务器的真实状态
|
||
stats := map[string]interface{}{
|
||
"dns": map[string]interface{}{
|
||
"Queries": dnsStats.Queries,
|
||
"Blocked": dnsStats.Blocked,
|
||
"Allowed": dnsStats.Allowed,
|
||
"Errors": dnsStats.Errors,
|
||
"LastQuery": dnsStats.LastQuery,
|
||
"AvgResponseTime": formattedResponseTime,
|
||
"TotalResponseTime": dnsStats.TotalResponseTime,
|
||
"QueryTypes": dnsStats.QueryTypes,
|
||
"SourceIPs": dnsStats.SourceIPs,
|
||
"CpuUsage": dnsStats.CpuUsage,
|
||
"DNSSECQueries": dnsStats.DNSSECQueries,
|
||
"DNSSECSuccess": dnsStats.DNSSECSuccess,
|
||
"DNSSECFailed": dnsStats.DNSSECFailed,
|
||
"DNSSECEnabled": dnsStats.DNSSECEnabled,
|
||
},
|
||
"shield": shieldStats,
|
||
"topQueryType": topQueryType,
|
||
"activeIPs": activeIPCount,
|
||
"avgResponseTime": formattedResponseTime,
|
||
"cpuUsage": dnsStats.CpuUsage,
|
||
"dnssecEnabled": dnsStats.DNSSECEnabled,
|
||
"dnssecQueries": dnsStats.DNSSECQueries,
|
||
"dnssecSuccess": dnsStats.DNSSECSuccess,
|
||
"dnssecFailed": dnsStats.DNSSECFailed,
|
||
"dnssecUsage": float64(int(dnssecUsage*100)) / 100, // 保留两位小数
|
||
"time": time.Now(),
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(stats)
|
||
}
|
||
|
||
// WebSocket相关方法
|
||
|
||
// handleWebSocketStats 处理WebSocket连接,用于实时推送统计数据
|
||
func (s *Server) handleWebSocketStats(w http.ResponseWriter, r *http.Request) {
|
||
// 升级HTTP连接为WebSocket
|
||
conn, err := s.upgrader.Upgrade(w, r, nil)
|
||
if err != nil {
|
||
logger.Error(fmt.Sprintf("WebSocket升级失败: %v", err))
|
||
return
|
||
}
|
||
defer conn.Close()
|
||
|
||
// 将新客户端添加到客户端列表
|
||
s.clientsMutex.Lock()
|
||
s.clients[conn] = true
|
||
clientCount := len(s.clients)
|
||
s.clientsMutex.Unlock()
|
||
|
||
logger.Info(fmt.Sprintf("新WebSocket客户端连接,当前连接数: %d", clientCount))
|
||
|
||
// 发送初始数据
|
||
if err := s.sendInitialStats(conn); err != nil {
|
||
logger.Error(fmt.Sprintf("发送初始数据失败: %v", err))
|
||
return
|
||
}
|
||
|
||
// 定期发送更新数据
|
||
ticker := time.NewTicker(500 * time.Millisecond) // 每500ms检查一次数据变化
|
||
defer ticker.Stop()
|
||
|
||
// 最后一次发送的数据快照,用于检测变化
|
||
var lastStats map[string]interface{}
|
||
|
||
// 保持连接并定期发送数据
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
// 获取最新统计数据
|
||
currentStats := s.buildStatsData()
|
||
|
||
// 检查数据是否有变化
|
||
if !s.areStatsEqual(lastStats, currentStats) {
|
||
// 数据有变化,发送更新
|
||
data, err := json.Marshal(map[string]interface{}{
|
||
"type": "stats_update",
|
||
"data": currentStats,
|
||
"time": time.Now(),
|
||
})
|
||
if err != nil {
|
||
logger.Error(fmt.Sprintf("序列化统计数据失败: %v", err))
|
||
continue
|
||
}
|
||
|
||
if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||
logger.Error(fmt.Sprintf("发送WebSocket消息失败: %v", err))
|
||
return
|
||
}
|
||
|
||
// 更新最后发送的数据
|
||
lastStats = currentStats
|
||
}
|
||
case <-r.Context().Done():
|
||
// 客户端断开连接
|
||
s.clientsMutex.Lock()
|
||
delete(s.clients, conn)
|
||
clientCount := len(s.clients)
|
||
s.clientsMutex.Unlock()
|
||
logger.Info(fmt.Sprintf("WebSocket客户端断开连接,当前连接数: %d", clientCount))
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// sendInitialStats 发送初始统计数据
|
||
func (s *Server) sendInitialStats(conn *websocket.Conn) error {
|
||
stats := s.buildStatsData()
|
||
data, err := json.Marshal(map[string]interface{}{
|
||
"type": "initial_data",
|
||
"data": stats,
|
||
"time": time.Now(),
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return conn.WriteMessage(websocket.TextMessage, data)
|
||
}
|
||
|
||
// buildStatsData 构建统计数据
|
||
func (s *Server) buildStatsData() map[string]interface{} {
|
||
dnsStats := s.dnsServer.GetStats()
|
||
shieldStats := s.shieldManager.GetStats()
|
||
|
||
// 获取最常用查询类型
|
||
topQueryType := "-"
|
||
maxCount := int64(0)
|
||
if len(dnsStats.QueryTypes) > 0 {
|
||
for queryType, count := range dnsStats.QueryTypes {
|
||
if count > maxCount {
|
||
maxCount = count
|
||
topQueryType = queryType
|
||
}
|
||
}
|
||
}
|
||
|
||
// 获取活跃来源IP数量
|
||
activeIPCount := len(dnsStats.SourceIPs)
|
||
|
||
// 格式化平均响应时间
|
||
formattedResponseTime := float64(int(dnsStats.AvgResponseTime*100)) / 100
|
||
|
||
// 计算DNSSEC使用率
|
||
dnssecUsage := float64(0)
|
||
if dnsStats.Queries > 0 {
|
||
dnssecUsage = float64(dnsStats.DNSSECQueries) / float64(dnsStats.Queries) * 100
|
||
}
|
||
|
||
return map[string]interface{}{
|
||
"dns": map[string]interface{}{
|
||
"Queries": dnsStats.Queries,
|
||
"Blocked": dnsStats.Blocked,
|
||
"Allowed": dnsStats.Allowed,
|
||
"Errors": dnsStats.Errors,
|
||
"LastQuery": dnsStats.LastQuery,
|
||
"AvgResponseTime": formattedResponseTime,
|
||
"TotalResponseTime": dnsStats.TotalResponseTime,
|
||
"QueryTypes": dnsStats.QueryTypes,
|
||
"SourceIPs": dnsStats.SourceIPs,
|
||
"CpuUsage": dnsStats.CpuUsage,
|
||
"DNSSECQueries": dnsStats.DNSSECQueries,
|
||
"DNSSECSuccess": dnsStats.DNSSECSuccess,
|
||
"DNSSECFailed": dnsStats.DNSSECFailed,
|
||
"DNSSECEnabled": dnsStats.DNSSECEnabled,
|
||
},
|
||
"shield": shieldStats,
|
||
"topQueryType": topQueryType,
|
||
"activeIPs": activeIPCount,
|
||
"avgResponseTime": formattedResponseTime,
|
||
"cpuUsage": dnsStats.CpuUsage,
|
||
"dnssecEnabled": dnsStats.DNSSECEnabled,
|
||
"dnssecQueries": dnsStats.DNSSECQueries,
|
||
"dnssecSuccess": dnsStats.DNSSECSuccess,
|
||
"dnssecFailed": dnsStats.DNSSECFailed,
|
||
"dnssecUsage": float64(int(dnssecUsage*100)) / 100, // 保留两位小数
|
||
}
|
||
}
|
||
|
||
// areStatsEqual 检查两次统计数据是否相等(用于检测变化)
|
||
func (s *Server) areStatsEqual(stats1, stats2 map[string]interface{}) bool {
|
||
if stats1 == nil || stats2 == nil {
|
||
return false
|
||
}
|
||
|
||
// 只比较关键数值,避免频繁更新
|
||
if dns1, ok1 := stats1["dns"].(map[string]interface{}); ok1 {
|
||
if dns2, ok2 := stats2["dns"].(map[string]interface{}); ok2 {
|
||
// 检查主要计数器
|
||
if dns1["Queries"] != dns2["Queries"] ||
|
||
dns1["Blocked"] != dns2["Blocked"] ||
|
||
dns1["Allowed"] != dns2["Allowed"] ||
|
||
dns1["Errors"] != dns2["Errors"] ||
|
||
dns1["AvgResponseTime"] != dns2["AvgResponseTime"] {
|
||
return false
|
||
}
|
||
}
|
||
}
|
||
|
||
return true
|
||
}
|
||
|
||
// startBroadcastLoop 启动广播循环
|
||
func (s *Server) startBroadcastLoop() {
|
||
for message := range s.broadcastChan {
|
||
s.clientsMutex.Lock()
|
||
for client := range s.clients {
|
||
if err := client.WriteMessage(websocket.TextMessage, message); err != nil {
|
||
logger.Error(fmt.Sprintf("广播消息失败: %v", err))
|
||
client.Close()
|
||
delete(s.clients, client)
|
||
}
|
||
}
|
||
s.clientsMutex.Unlock()
|
||
}
|
||
}
|
||
|
||
// cleanupSessionsLoop 定期清理过期会话
|
||
func (s *Server) cleanupSessionsLoop() {
|
||
for {
|
||
time.Sleep(1 * time.Hour) // 每小时清理一次
|
||
s.sessionsMutex.Lock()
|
||
now := time.Now()
|
||
for sessionID, expiryTime := range s.sessions {
|
||
if now.After(expiryTime) {
|
||
delete(s.sessions, sessionID)
|
||
}
|
||
}
|
||
s.sessionsMutex.Unlock()
|
||
}
|
||
}
|
||
|
||
// isAuthenticated 检查用户是否已认证
|
||
func (s *Server) isAuthenticated(r *http.Request) bool {
|
||
// 从Cookie中获取会话ID
|
||
cookie, err := r.Cookie("session_id")
|
||
if err != nil {
|
||
return false
|
||
}
|
||
|
||
sessionID := cookie.Value
|
||
s.sessionsMutex.Lock()
|
||
defer s.sessionsMutex.Unlock()
|
||
|
||
// 检查会话是否存在且未过期
|
||
expiryTime, exists := s.sessions[sessionID]
|
||
if !exists {
|
||
return false
|
||
}
|
||
|
||
if time.Now().After(expiryTime) {
|
||
// 会话已过期,删除它
|
||
delete(s.sessions, sessionID)
|
||
return false
|
||
}
|
||
|
||
// 延长会话有效期
|
||
s.sessions[sessionID] = time.Now().Add(s.sessionTTL)
|
||
return true
|
||
}
|
||
|
||
// loginRequired 登录中间件
|
||
func (s *Server) loginRequired(next http.HandlerFunc) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
// 检查是否为登录页面或登录API,允许直接访问
|
||
if r.URL.Path == "/login" || r.URL.Path == "/api/login" {
|
||
next.ServeHTTP(w, r)
|
||
return
|
||
}
|
||
|
||
// 检查是否已认证
|
||
if !s.isAuthenticated(r) {
|
||
// 如果是API请求,返回401错误
|
||
if strings.HasPrefix(r.URL.Path, "/api/") || strings.HasPrefix(r.URL.Path, "/ws/") {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusUnauthorized)
|
||
json.NewEncoder(w).Encode(map[string]string{"error": "未授权访问"})
|
||
return
|
||
}
|
||
// 否则重定向到登录页面
|
||
http.Redirect(w, r, "/login", http.StatusFound)
|
||
return
|
||
}
|
||
|
||
// 已认证,继续处理请求
|
||
next.ServeHTTP(w, r)
|
||
}
|
||
}
|
||
|
||
// handleTopBlockedDomains 处理TOP屏蔽域名请求
|
||
func (s *Server) handleTopBlockedDomains(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 返回最近30天的所有域名,设置合理上限50
|
||
domains := s.dnsServer.GetTopBlockedDomains(50)
|
||
|
||
// 转换为前端需要的格式
|
||
result := make([]map[string]interface{}, len(domains))
|
||
for i, domain := range domains {
|
||
result[i] = map[string]interface{}{
|
||
"domain": domain.Domain,
|
||
"count": domain.Count,
|
||
}
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(result)
|
||
}
|
||
|
||
// handleTopResolvedDomains 处理获取最常解析的域名请求
|
||
func (s *Server) handleTopResolvedDomains(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
domains := s.dnsServer.GetTopResolvedDomains(10)
|
||
|
||
// 转换为前端需要的格式
|
||
result := make([]map[string]interface{}, len(domains))
|
||
for i, domain := range domains {
|
||
result[i] = map[string]interface{}{
|
||
"domain": domain.Domain,
|
||
"count": domain.Count,
|
||
}
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(result)
|
||
}
|
||
|
||
// handleRecentBlockedDomains 处理最近屏蔽域名请求
|
||
func (s *Server) handleRecentBlockedDomains(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
domains := s.dnsServer.GetRecentBlockedDomains(10)
|
||
|
||
// 转换为前端需要的格式
|
||
result := make([]map[string]interface{}, len(domains))
|
||
for i, domain := range domains {
|
||
result[i] = map[string]interface{}{
|
||
"domain": domain.Domain,
|
||
"time": time.Unix(domain.LastSeen, 0).Format("15:04:05"),
|
||
}
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(result)
|
||
}
|
||
|
||
// handleHourlyStats 处理24小时统计请求
|
||
func (s *Server) handleHourlyStats(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
hourlyStats := s.dnsServer.GetHourlyStats()
|
||
|
||
// 生成最近24小时的数据
|
||
now := time.Now()
|
||
labels := make([]string, 24)
|
||
data := make([]int64, 24)
|
||
|
||
for i := 23; i >= 0; i-- {
|
||
hour := now.Add(time.Duration(-i) * time.Hour)
|
||
hourKey := hour.Format("2006-01-02-15")
|
||
labels[23-i] = hour.Format("15:00")
|
||
data[23-i] = hourlyStats[hourKey]
|
||
}
|
||
|
||
result := map[string]interface{}{
|
||
"labels": labels,
|
||
"data": data,
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(result)
|
||
}
|
||
|
||
// handleDailyStats 处理每日统计数据请求
|
||
func (s *Server) handleDailyStats(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 获取每日统计数据
|
||
dailyStats := s.dnsServer.GetDailyStats()
|
||
|
||
// 生成过去7天的时间标签
|
||
labels := make([]string, 7)
|
||
data := make([]int64, 7)
|
||
now := time.Now()
|
||
|
||
for i := 6; i >= 0; i-- {
|
||
t := now.AddDate(0, 0, -i)
|
||
key := t.Format("2006-01-02")
|
||
labels[6-i] = t.Format("01-02")
|
||
data[6-i] = dailyStats[key]
|
||
}
|
||
|
||
result := map[string]interface{}{
|
||
"labels": labels,
|
||
"data": data,
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(result)
|
||
}
|
||
|
||
// handleMonthlyStats 处理每月统计数据请求
|
||
func (s *Server) handleMonthlyStats(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 获取每日统计数据(用于30天视图)
|
||
dailyStats := s.dnsServer.GetDailyStats()
|
||
|
||
// 生成过去30天的时间标签
|
||
labels := make([]string, 30)
|
||
data := make([]int64, 30)
|
||
now := time.Now()
|
||
|
||
for i := 29; i >= 0; i-- {
|
||
t := now.AddDate(0, 0, -i)
|
||
key := t.Format("2006-01-02")
|
||
labels[29-i] = t.Format("01-02")
|
||
data[29-i] = dailyStats[key]
|
||
}
|
||
|
||
result := map[string]interface{}{
|
||
"labels": labels,
|
||
"data": data,
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(result)
|
||
}
|
||
|
||
// handleQueryTypeStats 处理查询类型统计请求
|
||
func (s *Server) handleQueryTypeStats(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 获取DNS统计数据
|
||
dnsStats := s.dnsServer.GetStats()
|
||
|
||
// 转换为前端需要的格式
|
||
result := make([]map[string]interface{}, 0, len(dnsStats.QueryTypes))
|
||
for queryType, count := range dnsStats.QueryTypes {
|
||
result = append(result, map[string]interface{}{
|
||
"type": queryType,
|
||
"count": count,
|
||
})
|
||
}
|
||
|
||
// 按计数降序排序
|
||
sort.Slice(result, func(i, j int) bool {
|
||
return result[i]["count"].(int64) > result[j]["count"].(int64)
|
||
})
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(result)
|
||
}
|
||
|
||
// handleTopClients 处理TOP客户端请求
|
||
func (s *Server) handleTopClients(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 获取TOP客户端列表
|
||
clients := s.dnsServer.GetTopClients(10)
|
||
|
||
// 转换为前端需要的格式
|
||
result := make([]map[string]interface{}, len(clients))
|
||
for i, client := range clients {
|
||
result[i] = map[string]interface{}{
|
||
"ip": client.IP,
|
||
"count": client.Count,
|
||
"lastSeen": client.LastSeen,
|
||
}
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(result)
|
||
}
|
||
|
||
// handleTopDomains 处理TOP域名请求
|
||
func (s *Server) handleTopDomains(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 获取TOP被屏蔽域名,返回最近30天的数据,设置合理上限50
|
||
blockedDomains := s.dnsServer.GetTopBlockedDomains(50)
|
||
// 获取TOP已解析域名,返回最近30天的数据,设置合理上限50
|
||
resolvedDomains := s.dnsServer.GetTopResolvedDomains(50)
|
||
|
||
// 合并并去重域名统计
|
||
domainMap := make(map[string]int64)
|
||
dnssecStatusMap := make(map[string]bool)
|
||
|
||
for _, domain := range blockedDomains {
|
||
domainMap[domain.Domain] += domain.Count
|
||
dnssecStatusMap[domain.Domain] = domain.DNSSEC
|
||
}
|
||
for _, domain := range resolvedDomains {
|
||
domainMap[domain.Domain] += domain.Count
|
||
dnssecStatusMap[domain.Domain] = domain.DNSSEC
|
||
}
|
||
|
||
// 转换为切片并排序
|
||
domainList := make([]map[string]interface{}, 0, len(domainMap))
|
||
for domain, count := range domainMap {
|
||
dnssec, hasDNSSEC := dnssecStatusMap[domain]
|
||
domainList = append(domainList, map[string]interface{}{
|
||
"domain": domain,
|
||
"count": count,
|
||
"dnssec": hasDNSSEC && dnssec,
|
||
})
|
||
}
|
||
|
||
// 按计数降序排序
|
||
sort.Slice(domainList, func(i, j int) bool {
|
||
return domainList[i]["count"].(int64) > domainList[j]["count"].(int64)
|
||
})
|
||
|
||
// 返回所有合并后的域名,设置合理上限50
|
||
if len(domainList) > 50 {
|
||
domainList = domainList[:50]
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(domainList)
|
||
}
|
||
|
||
// handleShield 处理Shield相关操作
|
||
// @Summary 管理Shield配置
|
||
// @Description 获取或更新Shield的配置信息
|
||
// @Tags shield
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param config body map[string]interface{} false "Shield配置信息"
|
||
// @Success 200 {object} map[string]interface{} "配置信息"
|
||
// @Failure 400 {object} map[string]string "请求参数错误"
|
||
// @Failure 500 {object} map[string]string "服务器内部错误"
|
||
// @Router /api/shield [get]
|
||
// @Router /api/shield [post]
|
||
func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
|
||
// 默认处理逻辑
|
||
switch r.Method {
|
||
case http.MethodGet:
|
||
// 检查是否需要返回完整规则列表
|
||
if r.URL.Query().Get("all") == "true" {
|
||
// 返回完整规则数据
|
||
rules := s.shieldManager.GetRules()
|
||
json.NewEncoder(w).Encode(rules)
|
||
return
|
||
}
|
||
// 获取规则统计信息
|
||
stats := s.shieldManager.GetStats()
|
||
shieldInfo := map[string]interface{}{
|
||
"updateInterval": s.globalConfig.Shield.UpdateInterval,
|
||
"blockMethod": s.globalConfig.Shield.BlockMethod,
|
||
"blacklistCount": len(s.globalConfig.Shield.Blacklists),
|
||
"domainRulesCount": stats["domainRules"],
|
||
"domainExceptionsCount": stats["domainExceptions"],
|
||
"regexRulesCount": stats["regexRules"],
|
||
"regexExceptionsCount": stats["regexExceptions"],
|
||
"hostsRulesCount": stats["hostsRules"],
|
||
}
|
||
json.NewEncoder(w).Encode(shieldInfo)
|
||
return
|
||
}
|
||
switch r.Method {
|
||
case http.MethodPost:
|
||
// 添加屏蔽规则
|
||
var req struct {
|
||
Rule string `json:"rule"`
|
||
}
|
||
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
if err := s.shieldManager.AddRule(req.Rule); err != nil {
|
||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
// 清空DNS缓存,使新规则立即生效
|
||
s.dnsServer.DnsCache.Clear()
|
||
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||
return
|
||
case http.MethodDelete:
|
||
// 删除屏蔽规则
|
||
var req struct {
|
||
Rule string `json:"rule"`
|
||
}
|
||
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
if err := s.shieldManager.RemoveRule(req.Rule); err != nil {
|
||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
// 清空DNS缓存,使规则变更立即生效
|
||
s.dnsServer.DnsCache.Clear()
|
||
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||
return
|
||
case http.MethodPut:
|
||
// 重新加载规则
|
||
if err := s.shieldManager.LoadRules(); err != nil {
|
||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||
return
|
||
}
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "规则重新加载成功"})
|
||
return
|
||
default:
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
}
|
||
|
||
// handleShieldBlacklists 处理黑名单相关操作
|
||
// @Summary 管理黑名单
|
||
// @Description 处理黑名单的CRUD操作,包括获取列表、添加、更新和删除黑名单
|
||
// @Tags shield
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param name path string false "黑名单名称(用于删除操作)"
|
||
// @Param blacklist body map[string]interface{} false "黑名单信息(用于添加/更新操作)"
|
||
// @Success 200 {object} map[string]interface{} "操作成功"
|
||
// @Failure 400 {object} map[string]string "请求参数错误"
|
||
// @Failure 404 {object} map[string]string "黑名单不存在"
|
||
// @Failure 500 {object} map[string]string "服务器内部错误"
|
||
// @Router /api/shield/blacklists [get]
|
||
// @Router /api/shield/blacklists [post]
|
||
// @Router /api/shield/blacklists [put]
|
||
// @Router /api/shield/blacklists/{name} [delete]
|
||
func (s *Server) handleShieldBlacklists(w http.ResponseWriter, r *http.Request) {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
|
||
// 处理更新单个黑名单
|
||
if strings.Contains(r.URL.Path, "/update") {
|
||
if r.Method == http.MethodPost {
|
||
// 提取黑名单URL或Name
|
||
parts := strings.Split(r.URL.Path, "/")
|
||
var targetURLOrName string
|
||
for i, part := range parts {
|
||
if part == "blacklists" && i+1 < len(parts) && parts[i+1] != "update" {
|
||
targetURLOrName = parts[i+1]
|
||
break
|
||
}
|
||
}
|
||
|
||
if targetURLOrName == "" {
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "黑名单标识不能为空"})
|
||
return
|
||
}
|
||
|
||
// 获取黑名单列表
|
||
blacklists := s.shieldManager.GetBlacklists()
|
||
var targetIndex = -1
|
||
for i, list := range blacklists {
|
||
if list.URL == targetURLOrName || list.Name == targetURLOrName {
|
||
targetIndex = i
|
||
break
|
||
}
|
||
}
|
||
|
||
if targetIndex == -1 {
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "黑名单不存在"})
|
||
return
|
||
}
|
||
|
||
// 更新时间戳
|
||
blacklists[targetIndex].LastUpdateTime = time.Now().Format(time.RFC3339)
|
||
// 保存更新后的黑名单列表
|
||
s.shieldManager.UpdateBlacklist(blacklists)
|
||
// 更新全局配置中的黑名单
|
||
s.globalConfig.Shield.Blacklists = blacklists
|
||
// 保存配置到文件
|
||
if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
|
||
logger.Error("保存配置文件失败", "error", err)
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "保存配置失败"})
|
||
return
|
||
}
|
||
// 重新加载规则以获取最新的远程规则
|
||
s.shieldManager.LoadRules()
|
||
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||
return
|
||
}
|
||
}
|
||
|
||
// 处理删除黑名单
|
||
parts := strings.Split(r.URL.Path, "/")
|
||
if len(parts) > 4 && parts[3] == "blacklists" && parts[4] != "" && r.Method == http.MethodDelete {
|
||
id := parts[4]
|
||
blacklists := s.shieldManager.GetBlacklists()
|
||
var newBlacklists []config.BlacklistEntry
|
||
|
||
for _, list := range blacklists {
|
||
if list.URL != id && list.Name != id {
|
||
newBlacklists = append(newBlacklists, list)
|
||
}
|
||
}
|
||
|
||
s.shieldManager.UpdateBlacklist(newBlacklists)
|
||
// 更新全局配置中的黑名单
|
||
s.globalConfig.Shield.Blacklists = newBlacklists
|
||
// 保存配置到文件
|
||
if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
|
||
logger.Error("保存配置文件失败", "error", err)
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "保存配置失败"})
|
||
return
|
||
}
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||
return
|
||
}
|
||
|
||
switch r.Method {
|
||
case http.MethodGet:
|
||
// 获取远程黑名单列表
|
||
blacklists := s.shieldManager.GetBlacklists()
|
||
json.NewEncoder(w).Encode(blacklists)
|
||
|
||
case http.MethodPost:
|
||
// 添加远程黑名单
|
||
var req struct {
|
||
Name string `json:"name"`
|
||
URL string `json:"url"`
|
||
}
|
||
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "无效的请求体"})
|
||
return
|
||
}
|
||
|
||
if req.Name == "" || req.URL == "" {
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "名称和URL不能为空"})
|
||
return
|
||
}
|
||
|
||
// 获取现有黑名单
|
||
blacklists := s.shieldManager.GetBlacklists()
|
||
|
||
// 检查是否已存在
|
||
for _, list := range blacklists {
|
||
if list.URL == req.URL {
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "黑名单URL已存在"})
|
||
return
|
||
}
|
||
}
|
||
|
||
// 检查URL是否存在且可访问
|
||
if !checkURLExists(req.URL) {
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "URL不存在或无法访问"})
|
||
return
|
||
}
|
||
|
||
// 添加新黑名单
|
||
newEntry := config.BlacklistEntry{
|
||
Name: req.Name,
|
||
URL: req.URL,
|
||
Enabled: true,
|
||
}
|
||
|
||
blacklists = append(blacklists, newEntry)
|
||
s.shieldManager.UpdateBlacklist(blacklists)
|
||
// 更新全局配置中的黑名单
|
||
s.globalConfig.Shield.Blacklists = blacklists
|
||
// 保存配置到文件
|
||
if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
|
||
logger.Error("保存配置文件失败", "error", err)
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "保存配置失败"})
|
||
return
|
||
}
|
||
|
||
// 重新加载规则以获取新添加的远程规则
|
||
s.shieldManager.LoadRules()
|
||
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||
|
||
case http.MethodPut:
|
||
// 更新黑名单列表(包括启用/禁用状态)
|
||
var updatedBlacklists []struct {
|
||
Name string `json:"Name" json:"name"`
|
||
URL string `json:"URL" json:"url"`
|
||
Enabled bool `json:"Enabled" json:"enabled"`
|
||
LastUpdateTime string `json:"LastUpdateTime" json:"lastUpdateTime"`
|
||
}
|
||
|
||
if err := json.NewDecoder(r.Body).Decode(&updatedBlacklists); err != nil {
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "无效的请求体"})
|
||
return
|
||
}
|
||
|
||
// 转换为config.BlacklistEntry类型
|
||
var newBlacklists []config.BlacklistEntry
|
||
for _, entry := range updatedBlacklists {
|
||
newBlacklists = append(newBlacklists, config.BlacklistEntry{
|
||
Name: entry.Name,
|
||
URL: entry.URL,
|
||
Enabled: entry.Enabled,
|
||
LastUpdateTime: entry.LastUpdateTime,
|
||
})
|
||
}
|
||
|
||
// 更新黑名单
|
||
s.shieldManager.UpdateBlacklist(newBlacklists)
|
||
// 更新全局配置中的黑名单
|
||
s.globalConfig.Shield.Blacklists = newBlacklists
|
||
// 保存配置到文件
|
||
if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
|
||
logger.Error("保存配置文件失败", "error", err)
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "保存配置失败"})
|
||
return
|
||
}
|
||
// 重新加载规则
|
||
s.shieldManager.LoadRules()
|
||
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||
|
||
default:
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "Method not allowed"})
|
||
}
|
||
}
|
||
|
||
// handleShieldHosts 处理hosts管理请求
|
||
func (s *Server) handleShieldHosts(w http.ResponseWriter, r *http.Request) {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
|
||
switch r.Method {
|
||
case http.MethodPost:
|
||
// 添加hosts条目
|
||
var req struct {
|
||
IP string `json:"ip"`
|
||
Domain string `json:"domain"`
|
||
}
|
||
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
if req.IP == "" || req.Domain == "" {
|
||
http.Error(w, "IP and Domain are required", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
if err := s.shieldManager.AddHostsEntry(req.IP, req.Domain); err != nil {
|
||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||
|
||
case http.MethodDelete:
|
||
// 删除hosts条目
|
||
var req struct {
|
||
Domain string `json:"domain"`
|
||
}
|
||
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
if err := s.shieldManager.RemoveHostsEntry(req.Domain); err != nil {
|
||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||
|
||
case http.MethodGet:
|
||
// 获取hosts条目列表
|
||
hosts := s.shieldManager.GetAllHosts()
|
||
hostsCount := s.shieldManager.GetHostsCount()
|
||
|
||
// 转换为数组格式,便于前端展示
|
||
hostsList := make([]map[string]string, 0, len(hosts))
|
||
for domain, ip := range hosts {
|
||
hostsList = append(hostsList, map[string]string{
|
||
"domain": domain,
|
||
"ip": ip,
|
||
})
|
||
}
|
||
|
||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||
"hosts": hostsList,
|
||
"hostsCount": hostsCount,
|
||
})
|
||
|
||
default:
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
}
|
||
}
|
||
|
||
// handleQuery 处理DNS查询请求(传统接口,保持向后兼容)
|
||
func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
domain := r.URL.Query().Get("domain")
|
||
if domain == "" {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusBadRequest)
|
||
json.NewEncoder(w).Encode(map[string]string{"error": "需要提供domain参数"})
|
||
return
|
||
}
|
||
|
||
// 获取域名屏蔽的详细信息
|
||
blockDetails := s.shieldManager.CheckDomainBlockDetails(domain)
|
||
|
||
// 添加时间戳
|
||
blockDetails["timestamp"] = time.Now()
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(blockDetails)
|
||
}
|
||
|
||
// handleDomainQuery 处理RESTful风格的域名查询请求
|
||
func (s *Server) handleDomainQuery(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 从URL路径中提取域名参数
|
||
// 路径格式: /api/domains/{domain}
|
||
path := r.URL.Path
|
||
parts := strings.Split(path, "/")
|
||
if len(parts) < 4 {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusBadRequest)
|
||
json.NewEncoder(w).Encode(map[string]string{"error": "需要提供domain参数"})
|
||
return
|
||
}
|
||
|
||
domain := parts[3]
|
||
if domain == "" {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusBadRequest)
|
||
json.NewEncoder(w).Encode(map[string]string{"error": "需要提供domain参数"})
|
||
return
|
||
}
|
||
|
||
// 获取域名屏蔽的详细信息
|
||
blockDetails := s.shieldManager.CheckDomainBlockDetails(domain)
|
||
|
||
// 构建RESTful风格的响应
|
||
response := map[string]interface{}{
|
||
"domain": domain,
|
||
"status": blockDetails["blocked"],
|
||
"timestamp": time.Now(),
|
||
"details": blockDetails,
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(response)
|
||
}
|
||
|
||
// handleStatus 处理系统状态请求
|
||
func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
stats := s.dnsServer.GetStats()
|
||
|
||
// 使用服务器的实际启动时间计算准确的运行时间
|
||
serverStartTime := s.dnsServer.GetStartTime()
|
||
uptime := time.Since(serverStartTime)
|
||
|
||
// 构建包含所有真实服务器统计数据的响应
|
||
status := map[string]interface{}{
|
||
"status": "running",
|
||
"queries": stats.Queries,
|
||
"blocked": stats.Blocked,
|
||
"allowed": stats.Allowed,
|
||
"errors": stats.Errors,
|
||
"lastQuery": stats.LastQuery,
|
||
"avgResponseTime": stats.AvgResponseTime,
|
||
"activeIPs": len(stats.SourceIPs),
|
||
"startTime": serverStartTime,
|
||
"uptime": uptime.Milliseconds(), // 转换为毫秒数,方便前端处理
|
||
"cpuUsage": stats.CpuUsage,
|
||
"timestamp": time.Now(),
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(status)
|
||
}
|
||
|
||
// saveConfigToFile 保存配置到文件
|
||
func saveConfigToFile(config *config.Config, filePath string) error {
|
||
// 创建新的INI文件
|
||
cfg := ini.Empty()
|
||
|
||
// DNS配置
|
||
dnsSection := cfg.Section("dns")
|
||
dnsSection.Key("port").SetValue(fmt.Sprintf("%d", config.DNS.Port))
|
||
dnsSection.Key("upstreamDNS").SetValue(strings.Join(config.DNS.UpstreamDNS, ", "))
|
||
dnsSection.Key("dnssecUpstreamDNS").SetValue(strings.Join(config.DNS.DNSSECUpstreamDNS, ", "))
|
||
dnsSection.Key("saveInterval").SetValue(fmt.Sprintf("%d", config.DNS.SaveInterval))
|
||
dnsSection.Key("cacheTTL").SetValue(fmt.Sprintf("%d", config.DNS.CacheTTL))
|
||
dnsSection.Key("enableDNSSEC").SetValue(fmt.Sprintf("%t", config.DNS.EnableDNSSEC))
|
||
dnsSection.Key("queryMode").SetValue(config.DNS.QueryMode)
|
||
dnsSection.Key("queryTimeout").SetValue(fmt.Sprintf("%d", config.DNS.QueryTimeout))
|
||
dnsSection.Key("enableFastReturn").SetValue(fmt.Sprintf("%t", config.DNS.EnableFastReturn))
|
||
dnsSection.Key("noDNSSECDomains").SetValue(strings.Join(config.DNS.NoDNSSECDomains, ", "))
|
||
dnsSection.Key("enableIPv6").SetValue(fmt.Sprintf("%t", config.DNS.EnableIPv6))
|
||
dnsSection.Key("cacheMode").SetValue(config.DNS.CacheMode)
|
||
dnsSection.Key("cacheSize").SetValue(fmt.Sprintf("%d", config.DNS.CacheSize))
|
||
dnsSection.Key("maxCacheTTL").SetValue(fmt.Sprintf("%d", config.DNS.MaxCacheTTL))
|
||
dnsSection.Key("minCacheTTL").SetValue(fmt.Sprintf("%d", config.DNS.MinCacheTTL))
|
||
|
||
// 域名特定DNS服务器配置
|
||
for domain, servers := range config.DNS.DomainSpecificDNS {
|
||
dnsSection.Key(fmt.Sprintf("domain_%s", domain)).SetValue(strings.Join(servers, ", "))
|
||
}
|
||
|
||
// HTTP配置
|
||
httpSection := cfg.Section("http")
|
||
httpSection.Key("port").SetValue(fmt.Sprintf("%d", config.HTTP.Port))
|
||
httpSection.Key("host").SetValue(config.HTTP.Host)
|
||
httpSection.Key("enableAPI").SetValue(fmt.Sprintf("%t", config.HTTP.EnableAPI))
|
||
httpSection.Key("username").SetValue(config.HTTP.Username)
|
||
httpSection.Key("password").SetValue(config.HTTP.Password)
|
||
|
||
// Shield配置
|
||
shieldSection := cfg.Section("shield")
|
||
shieldSection.Key("updateInterval").SetValue(fmt.Sprintf("%d", config.Shield.UpdateInterval))
|
||
shieldSection.Key("blockMethod").SetValue(config.Shield.BlockMethod)
|
||
shieldSection.Key("customBlockIP").SetValue(config.Shield.CustomBlockIP)
|
||
shieldSection.Key("statsSaveInterval").SetValue(fmt.Sprintf("%d", config.Shield.StatsSaveInterval))
|
||
|
||
// 黑名单配置
|
||
for _, bl := range config.Shield.Blacklists {
|
||
shieldSection.Key(fmt.Sprintf("blacklist_%s", bl.Name)).SetValue(fmt.Sprintf("%s,%t", bl.URL, bl.Enabled))
|
||
}
|
||
|
||
// GFWList配置
|
||
gfwListSection := cfg.Section("gfwList")
|
||
gfwListSection.Key("ip").SetValue(config.GFWList.IP)
|
||
gfwListSection.Key("content").SetValue(config.GFWList.Content)
|
||
gfwListSection.Key("enabled").SetValue(fmt.Sprintf("%t", config.GFWList.Enabled))
|
||
|
||
// Log配置
|
||
logSection := cfg.Section("log")
|
||
logSection.Key("level").SetValue(config.Log.Level)
|
||
logSection.Key("maxSize").SetValue(fmt.Sprintf("%d", config.Log.MaxSize))
|
||
logSection.Key("maxBackups").SetValue(fmt.Sprintf("%d", config.Log.MaxBackups))
|
||
logSection.Key("maxAge").SetValue(fmt.Sprintf("%d", config.Log.MaxAge))
|
||
|
||
// Threat配置
|
||
threatSection := cfg.Section("threat")
|
||
threatSection.Key("enabled").SetValue(fmt.Sprintf("%t", config.Threat.Enabled))
|
||
threatSection.Key("queryRateThreshold").SetValue(fmt.Sprintf("%d", config.Threat.QueryRateThreshold))
|
||
threatSection.Key("nxDomainThreshold").SetValue(fmt.Sprintf("%d", config.Threat.NXDomainThreshold))
|
||
threatSection.Key("maxDomainLength").SetValue(fmt.Sprintf("%d", config.Threat.MaxDomainLength))
|
||
threatSection.Key("suspiciousPatterns").SetValue(strings.Join(config.Threat.SuspiciousPatterns, ","))
|
||
threatSection.Key("unusualQueryTypes").SetValue(strings.Join(config.Threat.UnusualQueryTypes, ","))
|
||
threatSection.Key("alertRetentionDays").SetValue(fmt.Sprintf("%d", config.Threat.AlertRetentionDays))
|
||
threatSection.Key("threatDatabasePath").SetValue(config.Threat.ThreatDatabasePath)
|
||
|
||
// 保存到文件
|
||
return cfg.SaveTo(filePath)
|
||
}
|
||
|
||
// handleConfig 处理配置请求
|
||
func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
|
||
switch r.Method {
|
||
case http.MethodGet:
|
||
// 每次从配置文件重新读取最新配置
|
||
cfg, err := config.LoadConfig("config.ini")
|
||
if err != nil {
|
||
logger.Error("加载配置文件失败", "error", err)
|
||
// 如果加载失败,返回内存中的配置
|
||
cfg = s.globalConfig
|
||
}
|
||
|
||
// 返回当前配置(包括黑名单配置)
|
||
// 注意:key 名必须与前端期望的一致
|
||
config := map[string]interface{}{
|
||
"Shield": map[string]interface{}{
|
||
"blockMethod": cfg.Shield.BlockMethod,
|
||
"customBlockIP": cfg.Shield.CustomBlockIP,
|
||
"blacklists": cfg.Shield.Blacklists,
|
||
"updateInterval": cfg.Shield.UpdateInterval,
|
||
"statsSaveInterval": cfg.Shield.StatsSaveInterval,
|
||
},
|
||
"GFWList": map[string]interface{}{
|
||
"ip": cfg.GFWList.IP,
|
||
"content": cfg.GFWList.Content,
|
||
"enabled": cfg.GFWList.Enabled,
|
||
},
|
||
"DNSServer": map[string]interface{}{
|
||
"port": cfg.DNS.Port,
|
||
"QueryMode": cfg.DNS.QueryMode,
|
||
"UpstreamServers": cfg.DNS.UpstreamDNS,
|
||
"DNSSECUpstreamServers": cfg.DNS.DNSSECUpstreamDNS,
|
||
"saveInterval": cfg.DNS.SaveInterval,
|
||
"queryTimeout": cfg.DNS.QueryTimeout,
|
||
"enableIPv6": cfg.DNS.EnableIPv6,
|
||
"enableDNSSEC": cfg.DNS.EnableDNSSEC,
|
||
"enableFastReturn": cfg.DNS.EnableFastReturn,
|
||
"noDNSSECDomains": cfg.DNS.NoDNSSECDomains,
|
||
"CacheMode": cfg.DNS.CacheMode,
|
||
"CacheSize": cfg.DNS.CacheSize,
|
||
"MaxCacheTTL": cfg.DNS.MaxCacheTTL,
|
||
"MinCacheTTL": cfg.DNS.MinCacheTTL,
|
||
"domainSpecificDNS": cfg.DNS.DomainSpecificDNS,
|
||
},
|
||
"HTTPServer": map[string]interface{}{
|
||
"port": cfg.HTTP.Port,
|
||
"host": cfg.HTTP.Host,
|
||
"enableAPI": cfg.HTTP.EnableAPI,
|
||
"username": cfg.HTTP.Username,
|
||
"password": cfg.HTTP.Password,
|
||
},
|
||
}
|
||
json.NewEncoder(w).Encode(config)
|
||
|
||
case http.MethodPost:
|
||
// 更新配置
|
||
var req struct {
|
||
DNSServer struct {
|
||
Port int `json:"port"`
|
||
QueryMode string `json:"queryMode"`
|
||
UpstreamServers []string `json:"upstreamServers"`
|
||
DnssecUpstreamServers []string `json:"dnssecUpstreamServers"`
|
||
Timeout int `json:"timeout"`
|
||
SaveInterval int `json:"saveInterval"`
|
||
EnableIPv6 bool `json:"enableIPv6"`
|
||
EnableDNSSEC bool `json:"enableDNSSEC"`
|
||
EnableFastReturn *bool `json:"enableFastReturn"`
|
||
NoDNSSECDomains []string `json:"noDNSSECDomains"`
|
||
CacheMode string `json:"cacheMode"`
|
||
CacheSize int `json:"cacheSize"`
|
||
MaxCacheTTL int `json:"maxCacheTTL"`
|
||
MinCacheTTL int `json:"minCacheTTL"`
|
||
DomainSpecificDNS map[string][]string `json:"domainSpecificDNS"`
|
||
} `json:"dnsserver"`
|
||
HTTPServer struct {
|
||
Port int `json:"port"`
|
||
Host string `json:"host"`
|
||
EnableAPI bool `json:"enableAPI"`
|
||
Username string `json:"username"`
|
||
Password string `json:"password"`
|
||
} `json:"httpserver"`
|
||
Shield struct {
|
||
BlockMethod string `json:"blockMethod"`
|
||
CustomBlockIP string `json:"customBlockIP"`
|
||
Blacklists []config.BlacklistEntry `json:"blacklists"`
|
||
UpdateInterval int `json:"updateInterval"`
|
||
StatsSaveInterval int `json:"statsSaveInterval"`
|
||
} `json:"shield"`
|
||
GFWList struct {
|
||
IP string `json:"ip"`
|
||
Content string `json:"content"`
|
||
Enabled bool `json:"enabled"`
|
||
} `json:"gfwList"`
|
||
}
|
||
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
http.Error(w, "无效的请求体", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
// 更新DNS配置
|
||
if req.DNSServer.Port > 0 {
|
||
s.globalConfig.DNS.Port = req.DNSServer.Port
|
||
}
|
||
if len(req.DNSServer.UpstreamServers) > 0 {
|
||
s.globalConfig.DNS.UpstreamDNS = req.DNSServer.UpstreamServers
|
||
}
|
||
if len(req.DNSServer.DnssecUpstreamServers) > 0 {
|
||
s.globalConfig.DNS.DNSSECUpstreamDNS = req.DNSServer.DnssecUpstreamServers
|
||
}
|
||
if req.DNSServer.SaveInterval > 0 {
|
||
s.globalConfig.DNS.SaveInterval = req.DNSServer.SaveInterval
|
||
}
|
||
if req.DNSServer.Timeout > 0 {
|
||
s.globalConfig.DNS.QueryTimeout = req.DNSServer.Timeout
|
||
}
|
||
s.globalConfig.DNS.EnableIPv6 = req.DNSServer.EnableIPv6
|
||
s.globalConfig.DNS.EnableDNSSEC = req.DNSServer.EnableDNSSEC
|
||
// 更新查询模式
|
||
if req.DNSServer.QueryMode != "" {
|
||
s.globalConfig.DNS.QueryMode = req.DNSServer.QueryMode
|
||
}
|
||
// 更新缓存配置
|
||
if req.DNSServer.CacheMode != "" {
|
||
s.globalConfig.DNS.CacheMode = req.DNSServer.CacheMode
|
||
}
|
||
if req.DNSServer.CacheSize > 0 {
|
||
s.globalConfig.DNS.CacheSize = req.DNSServer.CacheSize
|
||
}
|
||
if req.DNSServer.MaxCacheTTL > 0 {
|
||
s.globalConfig.DNS.MaxCacheTTL = req.DNSServer.MaxCacheTTL
|
||
}
|
||
if req.DNSServer.MinCacheTTL > 0 {
|
||
s.globalConfig.DNS.MinCacheTTL = req.DNSServer.MinCacheTTL
|
||
}
|
||
// 更新enableFastReturn
|
||
if req.DNSServer.EnableFastReturn != nil {
|
||
s.globalConfig.DNS.EnableFastReturn = *req.DNSServer.EnableFastReturn
|
||
}
|
||
// 更新noDNSSECDomains
|
||
if len(req.DNSServer.NoDNSSECDomains) > 0 {
|
||
s.globalConfig.DNS.NoDNSSECDomains = req.DNSServer.NoDNSSECDomains
|
||
}
|
||
// 更新domainSpecificDNS
|
||
if req.DNSServer.DomainSpecificDNS != nil {
|
||
s.globalConfig.DNS.DomainSpecificDNS = req.DNSServer.DomainSpecificDNS
|
||
}
|
||
|
||
// 更新HTTP配置
|
||
if req.HTTPServer.Port > 0 {
|
||
s.globalConfig.HTTP.Port = req.HTTPServer.Port
|
||
}
|
||
if req.HTTPServer.Host != "" {
|
||
s.globalConfig.HTTP.Host = req.HTTPServer.Host
|
||
}
|
||
s.globalConfig.HTTP.EnableAPI = req.HTTPServer.EnableAPI
|
||
if req.HTTPServer.Username != "" {
|
||
s.globalConfig.HTTP.Username = req.HTTPServer.Username
|
||
}
|
||
if req.HTTPServer.Password != "" {
|
||
s.globalConfig.HTTP.Password = req.HTTPServer.Password
|
||
}
|
||
|
||
// 更新屏蔽配置
|
||
if req.Shield.BlockMethod != "" {
|
||
// 验证屏蔽方法是否有效
|
||
validMethods := map[string]bool{
|
||
"NXDOMAIN": true,
|
||
"refused": true,
|
||
"emptyIP": true,
|
||
"customIP": true,
|
||
}
|
||
|
||
if !validMethods[req.Shield.BlockMethod] {
|
||
http.Error(w, "无效的屏蔽方法", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
s.globalConfig.Shield.BlockMethod = req.Shield.BlockMethod
|
||
|
||
// 如果选择了customIP,验证IP地址
|
||
if req.Shield.BlockMethod == "customIP" {
|
||
if req.Shield.CustomBlockIP == "" {
|
||
http.Error(w, "自定义IP不能为空", http.StatusBadRequest)
|
||
return
|
||
}
|
||
// 简单的IP地址验证
|
||
if !isValidIP(req.Shield.CustomBlockIP) {
|
||
http.Error(w, "无效的IP地址", http.StatusBadRequest)
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
if req.Shield.CustomBlockIP != "" {
|
||
s.globalConfig.Shield.CustomBlockIP = req.Shield.CustomBlockIP
|
||
}
|
||
|
||
// 更新更新间隔
|
||
if req.Shield.UpdateInterval > 0 {
|
||
s.globalConfig.Shield.UpdateInterval = req.Shield.UpdateInterval
|
||
// 重新启动自动更新
|
||
s.shieldManager.StopAutoUpdate()
|
||
s.shieldManager.StartAutoUpdate()
|
||
}
|
||
|
||
// 更新统计保存间隔
|
||
if req.Shield.StatsSaveInterval > 0 {
|
||
s.globalConfig.Shield.StatsSaveInterval = req.Shield.StatsSaveInterval
|
||
}
|
||
|
||
// 更新GFWList配置
|
||
s.globalConfig.GFWList.IP = req.GFWList.IP
|
||
s.globalConfig.GFWList.Content = req.GFWList.Content
|
||
s.globalConfig.GFWList.Enabled = req.GFWList.Enabled
|
||
|
||
// 重新加载GFWList规则
|
||
if s.gfwManager != nil {
|
||
if err := s.gfwManager.LoadRules(); err != nil {
|
||
logger.Error("重新加载GFWList规则失败", "error", err)
|
||
}
|
||
}
|
||
|
||
// 更新黑名单配置
|
||
if req.Shield.Blacklists != nil {
|
||
// 验证黑名单配置
|
||
for i, bl := range req.Shield.Blacklists {
|
||
if bl.URL == "" {
|
||
http.Error(w, fmt.Sprintf("黑名单URL不能为空,索引: %d", i), http.StatusBadRequest)
|
||
return
|
||
}
|
||
if !strings.HasPrefix(bl.URL, "http://") && !strings.HasPrefix(bl.URL, "https://") {
|
||
http.Error(w, fmt.Sprintf("黑名单URL必须以http://或https://开头,索引: %d", i), http.StatusBadRequest)
|
||
return
|
||
}
|
||
}
|
||
s.globalConfig.Shield.Blacklists = req.Shield.Blacklists
|
||
s.shieldManager.UpdateBlacklist(req.Shield.Blacklists)
|
||
// 重新加载规则
|
||
if err := s.shieldManager.LoadRules(); err != nil {
|
||
logger.Error("重新加载规则失败", "error", err)
|
||
}
|
||
}
|
||
|
||
// 更新现有的DNSCache实例配置
|
||
// 最大和最小TTL(秒)
|
||
maxCacheTTL := time.Duration(s.globalConfig.DNS.MaxCacheTTL) * time.Second
|
||
minCacheTTL := time.Duration(s.globalConfig.DNS.MinCacheTTL) * time.Second
|
||
// 最大缓存大小(字节)
|
||
maxCacheSize := int64(s.globalConfig.DNS.CacheSize) * 1024 * 1024
|
||
|
||
// 更新缓存配置
|
||
s.dnsServer.DnsCache.SetMaxCacheTTL(maxCacheTTL)
|
||
s.dnsServer.DnsCache.SetMinCacheTTL(minCacheTTL)
|
||
s.dnsServer.DnsCache.SetCacheMode(s.globalConfig.DNS.CacheMode)
|
||
s.dnsServer.DnsCache.SetMaxCacheSize(maxCacheSize)
|
||
|
||
// 保存配置到文件
|
||
if err := saveConfigToFile(s.globalConfig, "./config.ini"); err != nil {
|
||
logger.Error("保存配置到文件失败", "error", err)
|
||
// 不返回错误,只记录日志,因为配置已经在内存中更新成功
|
||
}
|
||
|
||
// 返回成功响应
|
||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||
"success": true,
|
||
"message": "配置已更新",
|
||
})
|
||
|
||
default:
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
}
|
||
}
|
||
|
||
// isValidIP reports whether ip is a dotted-decimal IPv4 address such as
// "192.168.1.1": exactly four non-empty groups of ASCII digits, each in
// 0-255. IPv6 addresses are not accepted.
//
// Octets with a leading zero (e.g. "010") are rejected. Such values are
// ambiguous, since some parsers read them as octal, and Go's net.ParseIP has
// refused them since Go 1.17. Accepting them here would let a value pass
// validation that net.ParseIP-based code elsewhere would reject.
func isValidIP(ip string) bool {
	octets := strings.Split(ip, ".")
	if len(octets) != 4 {
		return false
	}

	for _, octet := range octets {
		// At most three digits also keeps the accumulation below from overflowing.
		if len(octet) == 0 || len(octet) > 3 {
			return false
		}
		if len(octet) > 1 && octet[0] == '0' {
			return false
		}

		value := 0
		for i := 0; i < len(octet); i++ {
			c := octet[i]
			if c < '0' || c > '9' {
				return false
			}
			value = value*10 + int(c-'0')
		}
		if value > 255 {
			return false
		}
	}
	return true
}
|
||
|
||
// checkURLExists reports whether url answers with a 2xx or 3xx status.
//
// A HEAD request is tried first. Some servers do not implement HEAD and reply
// 405 Method Not Allowed or 501 Not Implemented even though the resource
// exists. In that case a GET is sent instead, and its body is closed without
// being read. Each request has a 5-second timeout, so the worst case is about
// 10 seconds. Any transport error (malformed URL, unsupported scheme, DNS
// failure, timeout) yields false. Redirects are followed by the client.
func checkURLExists(url string) bool {
	client := &http.Client{
		Timeout: 5 * time.Second,
	}

	resp, err := client.Head(url)
	if err != nil {
		return false
	}
	resp.Body.Close()
	status := resp.StatusCode

	if status == http.StatusMethodNotAllowed || status == http.StatusNotImplemented {
		resp, err = client.Get(url)
		if err != nil {
			return false
		}
		resp.Body.Close()
		status = resp.StatusCode
	}

	return status >= 200 && status < 400
}
|
||
|
||
// handleLogsStats 处理日志统计请求
|
||
func (s *Server) handleLogsStats(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 构建缓存键
|
||
cacheKey := "logs_stats"
|
||
|
||
// 如果启用缓存,先尝试从缓存获取
|
||
if s.cacheEnabled {
|
||
if cachedStats, found := s.statsCache.Get(cacheKey); found {
|
||
// 缓存命中,直接返回
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(cachedStats)
|
||
return
|
||
}
|
||
}
|
||
|
||
// 缓存未命中,获取最新统计数据
|
||
logStats := s.dnsServer.GetQueryStats()
|
||
|
||
// 存入缓存
|
||
if s.cacheEnabled {
|
||
s.statsCache.Set(cacheKey, logStats)
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(logStats)
|
||
}
|
||
|
||
// handleLogsQuery 处理日志查询请求
|
||
func (s *Server) handleLogsQuery(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 获取查询参数
|
||
limit := 100 // 默认返回 100 条日志
|
||
offset := 0
|
||
sortField := r.URL.Query().Get("sort")
|
||
sortDirection := r.URL.Query().Get("direction")
|
||
resultFilter := r.URL.Query().Get("result")
|
||
searchTerm := r.URL.Query().Get("search")
|
||
queryType := r.URL.Query().Get("queryType")
|
||
|
||
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
|
||
fmt.Sscanf(limitStr, "%d", &limit)
|
||
}
|
||
|
||
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
|
||
fmt.Sscanf(offsetStr, "%d", &offset)
|
||
}
|
||
|
||
// 构建缓存键,包含所有查询参数
|
||
// 已禁用缓存,每次都从数据库获取最新数据
|
||
// cacheKey := fmt.Sprintf("logs_%d_%d_%s_%s_%s_%s_%s", limit, offset, sortField, sortDirection, resultFilter, searchTerm, queryType)
|
||
|
||
// 缓存未命中,获取日志数据(已禁用缓存)
|
||
logs := s.dnsServer.GetQueryLogs(limit, offset, sortField, sortDirection, resultFilter, searchTerm, queryType)
|
||
|
||
// 存入缓存(已禁用)
|
||
// if s.cacheEnabled {
|
||
// s.queryCache.Set(cacheKey, logs)
|
||
// }
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(logs)
|
||
}
|
||
|
||
// handleLogsCount 处理日志总数请求
|
||
func (s *Server) handleLogsCount(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 获取过滤参数
|
||
resultFilter := r.URL.Query().Get("result")
|
||
searchTerm := r.URL.Query().Get("search")
|
||
queryType := r.URL.Query().Get("queryType")
|
||
|
||
// 构建缓存键(已禁用)
|
||
// cacheKey := fmt.Sprintf("logs_count_%s_%s_%s", resultFilter, searchTerm, queryType)
|
||
|
||
// 缓存未命中,获取带过滤条件的日志总数(已禁用缓存)
|
||
count := s.dnsServer.GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType)
|
||
|
||
// 存入缓存(已禁用)
|
||
// if s.cacheEnabled {
|
||
// s.queryCache.Set(cacheKey, count)
|
||
// }
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(map[string]int{"count": count})
|
||
}
|
||
|
||
// handleDomainInfo 处理域名信息查询请求
|
||
func (s *Server) handleDomainInfo(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 解析请求体
|
||
var req struct {
|
||
Domain string `json:"domain"`
|
||
}
|
||
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
if req.Domain == "" {
|
||
http.Error(w, "Domain parameter is required", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
// 从域名信息数据库中查询
|
||
domainInfo, err := domain.GetDomainInfo(req.Domain)
|
||
if err != nil {
|
||
http.Error(w, "Failed to query domain info", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
// 返回域名信息
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(domainInfo)
|
||
}
|
||
|
||
// handleRestart 处理重启服务请求
|
||
func (s *Server) handleRestart(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
logger.Info("收到重启服务请求")
|
||
|
||
// 停止DNS服务器
|
||
s.dnsServer.Stop()
|
||
|
||
// 重新加载屏蔽规则
|
||
if err := s.shieldManager.LoadRules(); err != nil {
|
||
logger.Error("重新加载屏蔽规则失败", "error", err)
|
||
}
|
||
|
||
// 重新启动DNS服务器
|
||
go func() {
|
||
if err := s.dnsServer.Start(); err != nil {
|
||
logger.Error("DNS服务器重启失败", "error", err)
|
||
}
|
||
}()
|
||
|
||
// 重新启动定时更新任务
|
||
s.shieldManager.StopAutoUpdate()
|
||
s.shieldManager.StartAutoUpdate()
|
||
|
||
// 返回成功响应
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "服务已重启"})
|
||
logger.Info("服务重启成功")
|
||
}
|
||
|
||
// handleLogin 处理登录请求
|
||
func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 解析请求体
|
||
var loginData struct {
|
||
Username string `json:"username"`
|
||
Password string `json:"password"`
|
||
}
|
||
|
||
if err := json.NewDecoder(r.Body).Decode(&loginData); err != nil {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusBadRequest)
|
||
json.NewEncoder(w).Encode(map[string]string{"error": "无效的请求体"})
|
||
return
|
||
}
|
||
|
||
// 验证用户名和密码
|
||
if loginData.Username != s.config.Username || loginData.Password != s.config.Password {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusUnauthorized)
|
||
json.NewEncoder(w).Encode(map[string]string{"error": "用户名或密码错误"})
|
||
return
|
||
}
|
||
|
||
// 生成会话ID
|
||
sessionID := fmt.Sprintf("%d_%d", time.Now().UnixNano(), len(s.sessions))
|
||
|
||
// 保存会话
|
||
s.sessionsMutex.Lock()
|
||
s.sessions[sessionID] = time.Now().Add(s.sessionTTL)
|
||
s.sessionsMutex.Unlock()
|
||
|
||
// 设置Cookie
|
||
cookie := &http.Cookie{
|
||
Name: "session_id",
|
||
Value: sessionID,
|
||
Path: "/",
|
||
Expires: time.Now().Add(s.sessionTTL),
|
||
HttpOnly: true,
|
||
Secure: false, // 开发环境下使用false,生产环境应使用true
|
||
}
|
||
http.SetCookie(w, cookie)
|
||
|
||
// 返回成功响应
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "登录成功"})
|
||
logger.Info(fmt.Sprintf("用户 %s 登录成功", loginData.Username))
|
||
}
|
||
|
||
// handleLogout 处理注销请求
|
||
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 从Cookie中获取会话ID
|
||
cookie, err := r.Cookie("session_id")
|
||
if err == nil {
|
||
// 删除会话
|
||
s.sessionsMutex.Lock()
|
||
delete(s.sessions, cookie.Value)
|
||
s.sessionsMutex.Unlock()
|
||
}
|
||
|
||
// 清除Cookie
|
||
clearCookie := &http.Cookie{
|
||
Name: "session_id",
|
||
Value: "",
|
||
Path: "/",
|
||
Expires: time.Unix(0, 0),
|
||
HttpOnly: true,
|
||
Secure: false,
|
||
}
|
||
http.SetCookie(w, clearCookie)
|
||
|
||
// 返回成功响应
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "注销成功"})
|
||
logger.Info("用户注销成功")
|
||
}
|
||
|
||
// handleDomainInfoList 处理域名信息列表请求
|
||
func (s *Server) handleDomainInfoList(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 获取查询参数
|
||
query := r.URL.Query()
|
||
|
||
if query.Has("domains") {
|
||
// 处理域名信息,支持过滤特定域名
|
||
domainFilter := query.Get("domains")
|
||
handleDomainsInfo(w, domainFilter)
|
||
} else if query.Has("trackers") {
|
||
// 处理跟踪器信息,支持过滤特定域名
|
||
trackerFilter := query.Get("trackers")
|
||
handleTrackersInfo(w, trackerFilter)
|
||
} else if query.Has("threats") {
|
||
// 处理威胁域名信息,支持过滤特定域名
|
||
threatFilter := query.Get("threats")
|
||
handleThreatsInfo(w, threatFilter)
|
||
} else {
|
||
// 直接访问 /domain-info 不提供任何内容
|
||
http.Error(w, "No content provided", http.StatusNoContent)
|
||
return
|
||
}
|
||
}
|
||
|
||
// isService reports whether obj describes a single service entry rather than
// a group of services. An entry counts as a service when it has a
// "categoryId" key, or when it has both a "name" and a "url" key. Only key
// presence matters; the values may be anything, including nil.
func isService(obj map[string]interface{}) bool {
	has := func(key string) bool {
		_, present := obj[key]
		return present
	}
	return has("categoryId") || (has("name") && has("url"))
}
|
||
|
||
// handleAlert 处理威胁告警请求
|
||
func (s *Server) handleAlert(w http.ResponseWriter, r *http.Request) {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
|
||
switch r.Method {
|
||
case http.MethodGet:
|
||
// 获取告警列表
|
||
limit := 100
|
||
offset := 0
|
||
level := r.URL.Query().Get("level")
|
||
|
||
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
|
||
fmt.Sscanf(limitStr, "%d", &limit)
|
||
}
|
||
|
||
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
|
||
fmt.Sscanf(offsetStr, "%d", &offset)
|
||
}
|
||
|
||
// 获取告警列表
|
||
alerts := s.dnsServer.GetAlerts(limit, offset, level)
|
||
|
||
// 构建响应
|
||
response := map[string]interface{}{
|
||
"alerts": alerts,
|
||
"total": s.dnsServer.GetAlertCount(level),
|
||
}
|
||
|
||
json.NewEncoder(w).Encode(response)
|
||
|
||
default:
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
}
|
||
}
|
||
|
||
// handleAlertResolve 处理威胁告警解决请求
|
||
func (s *Server) handleAlertResolve(w http.ResponseWriter, r *http.Request) {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 解析请求体
|
||
var req struct {
|
||
AlertID string `json:"alertId"`
|
||
Action string `json:"action"` // blocked, allowed
|
||
}
|
||
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
if req.AlertID == "" || req.Action == "" {
|
||
http.Error(w, "AlertID and Action are required", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
// 验证动作
|
||
if req.Action != threat.ActionBlocked && req.Action != threat.ActionAllowed {
|
||
http.Error(w, "Invalid action", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
// 解决告警
|
||
success := s.dnsServer.ResolveAlert(req.AlertID, req.Action)
|
||
|
||
if success {
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||
} else {
|
||
http.Error(w, "Failed to resolve alert", http.StatusInternalServerError)
|
||
}
|
||
}
|
||
|
||
// handleThreatDomain 处理威胁域名管理请求
|
||
func (s *Server) handleThreatDomain(w http.ResponseWriter, r *http.Request) {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
|
||
switch r.Method {
|
||
case http.MethodGet:
|
||
// 获取所有威胁域名
|
||
threats := s.dnsServer.GetThreatDomains()
|
||
json.NewEncoder(w).Encode(threats)
|
||
|
||
case http.MethodPost:
|
||
// 添加威胁域名
|
||
var req struct {
|
||
Type string `json:"type"`
|
||
Name string `json:"name"`
|
||
RiskLevel int `json:"riskLevel"`
|
||
Domain string `json:"domain"`
|
||
}
|
||
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
if req.Domain == "" {
|
||
http.Error(w, "Domain is required", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
// 设置默认值
|
||
if req.Type == "" {
|
||
req.Type = "未知"
|
||
}
|
||
if req.Name == "" {
|
||
req.Name = "未知"
|
||
}
|
||
if req.RiskLevel == 0 {
|
||
req.RiskLevel = 1
|
||
}
|
||
|
||
err := s.dnsServer.AddThreatDomain(req.Type, req.Name, req.RiskLevel, req.Domain)
|
||
if err != nil {
|
||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||
|
||
case http.MethodDelete:
|
||
// 删除威胁域名
|
||
var req struct {
|
||
Domain string `json:"domain"`
|
||
}
|
||
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
if req.Domain == "" {
|
||
http.Error(w, "Domain is required", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
err := s.dnsServer.RemoveThreatDomain(req.Domain)
|
||
if err != nil {
|
||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||
|
||
default:
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
}
|
||
}
|
||
|
||
// processServiceItem 递归处理服务或分组
|
||
func processServiceItem(
|
||
serviceName string,
|
||
service interface{},
|
||
companyLevelCompany string,
|
||
domainFilter string,
|
||
categories map[string]string,
|
||
result *[]map[string]interface{},
|
||
) {
|
||
serviceMap, ok := service.(map[string]interface{})
|
||
if !ok {
|
||
return
|
||
}
|
||
|
||
// 跳过 company 字段
|
||
if serviceName == "company" {
|
||
return
|
||
}
|
||
|
||
// 判断是服务还是分组
|
||
if isService(serviceMap) {
|
||
// 这是一个服务,进行处理
|
||
urlValue := serviceMap["url"]
|
||
match := false
|
||
|
||
// 检查是否需要过滤
|
||
if domainFilter != "" {
|
||
// 检查服务名称是否包含过滤条件
|
||
if serviceName == domainFilter {
|
||
match = true
|
||
} else {
|
||
// 检查 URL 是否包含过滤条件
|
||
switch v := urlValue.(type) {
|
||
case string:
|
||
if strings.Contains(v, domainFilter) {
|
||
match = true
|
||
}
|
||
case map[string]interface{}:
|
||
for _, url := range v {
|
||
if urlStr, ok := url.(string); ok && strings.Contains(urlStr, domainFilter) {
|
||
match = true
|
||
break
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if !match {
|
||
return
|
||
}
|
||
}
|
||
|
||
// 确定公司名:优先使用服务级别的 company 字段,否则使用公司级别的 company 字段
|
||
itemCompany := companyLevelCompany
|
||
if serviceCompany, ok := serviceMap["company"].(string); ok {
|
||
itemCompany = serviceCompany
|
||
}
|
||
|
||
// 构建响应对象
|
||
item := map[string]interface{}{
|
||
"icon": serviceMap["icon"],
|
||
"name": serviceMap["name"],
|
||
"company": itemCompany,
|
||
}
|
||
|
||
// 添加类别
|
||
if categoryId, ok := serviceMap["categoryId"].(float64); ok {
|
||
categoryIdStr := fmt.Sprintf("%.0f", categoryId)
|
||
if category, exists := categories[categoryIdStr]; exists {
|
||
item["category"] = category
|
||
}
|
||
}
|
||
|
||
*result = append(*result, item)
|
||
} else {
|
||
// 这是一个分组,递归处理其下的子项
|
||
for subName, subService := range serviceMap {
|
||
processServiceItem(subName, subService, companyLevelCompany, domainFilter, categories, result)
|
||
}
|
||
}
|
||
}
|
||
func handleDomainsInfo(w http.ResponseWriter, domainFilter string) {
|
||
// 如果过滤参数为空字符串,返回空数组
|
||
if domainFilter == "" {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode([]map[string]interface{}{})
|
||
return
|
||
}
|
||
|
||
filePath := "./static/domain-info/domains/domain-info.json"
|
||
data, err := os.ReadFile(filePath)
|
||
if err != nil {
|
||
http.Error(w, "Failed to read domain info file", http.StatusInternalServerError)
|
||
logger.Error(fmt.Sprintf("读取域名信息文件失败: %v", err))
|
||
return
|
||
}
|
||
|
||
// 解析JSON
|
||
var domainInfo struct {
|
||
Categories map[string]string `json:"categories"`
|
||
Domains map[string]map[string]interface{} `json:"domains"`
|
||
}
|
||
|
||
if err := json.Unmarshal(data, &domainInfo); err != nil {
|
||
http.Error(w, "Failed to parse domain info file", http.StatusInternalServerError)
|
||
logger.Error(fmt.Sprintf("解析域名信息文件失败: %v", err))
|
||
return
|
||
}
|
||
|
||
// 转换为所需格式
|
||
var result []map[string]interface{}
|
||
for _, services := range domainInfo.Domains {
|
||
// 获取公司级别的 company 字段
|
||
companyLevelCompany := ""
|
||
if companyData, ok := services["company"].(string); ok {
|
||
companyLevelCompany = companyData
|
||
}
|
||
|
||
// 遍历所有服务(包括嵌套的分组)
|
||
for serviceName, service := range services {
|
||
processServiceItem(serviceName, service, companyLevelCompany, domainFilter, domainInfo.Categories, &result)
|
||
}
|
||
}
|
||
|
||
// 返回 JSON 响应
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(result)
|
||
}
|
||
|
||
// handleTrackersInfo 处理跟踪器信息请求,返回名称、类别、url、所属单位/公司
|
||
func handleTrackersInfo(w http.ResponseWriter, trackerFilter string) {
|
||
// 如果过滤参数为空字符串,返回空数组
|
||
if trackerFilter == "" {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode([]map[string]interface{}{})
|
||
return
|
||
}
|
||
|
||
filePath := "./static/domain-info/tracker/trackers.json"
|
||
data, err := os.ReadFile(filePath)
|
||
if err != nil {
|
||
http.Error(w, "Failed to read trackers file", http.StatusInternalServerError)
|
||
logger.Error(fmt.Sprintf("读取跟踪器文件失败: %v", err))
|
||
return
|
||
}
|
||
|
||
// 解析JSON
|
||
var trackersInfo struct {
|
||
Categories map[string]string `json:"categories"`
|
||
Trackers map[string]map[string]interface{} `json:"trackers"`
|
||
}
|
||
|
||
if err := json.Unmarshal(data, &trackersInfo); err != nil {
|
||
http.Error(w, "Failed to parse trackers file", http.StatusInternalServerError)
|
||
logger.Error(fmt.Sprintf("解析跟踪器文件失败: %v", err))
|
||
return
|
||
}
|
||
|
||
// 转换为所需格式
|
||
var result []map[string]interface{}
|
||
for trackerDomain, tracker := range trackersInfo.Trackers {
|
||
// 检查是否需要过滤
|
||
if trackerFilter != "" {
|
||
// 检查跟踪器域名是否包含过滤条件
|
||
if !strings.Contains(trackerDomain, trackerFilter) {
|
||
// 检查名称是否包含过滤条件
|
||
if name, ok := tracker["name"].(string); !ok || !strings.Contains(name, trackerFilter) {
|
||
// 检查URL是否包含过滤条件
|
||
if url, ok := tracker["url"].(string); !ok || !strings.Contains(url, trackerFilter) {
|
||
continue
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
item := map[string]interface{}{
|
||
"name": tracker["name"],
|
||
"url": tracker["url"],
|
||
"company": tracker["companyId"],
|
||
}
|
||
|
||
// 添加类别
|
||
if categoryId, ok := tracker["categoryId"].(float64); ok {
|
||
categoryIdStr := fmt.Sprintf("%.0f", categoryId)
|
||
if category, exists := trackersInfo.Categories[categoryIdStr]; exists {
|
||
item["category"] = category
|
||
}
|
||
}
|
||
|
||
result = append(result, item)
|
||
}
|
||
|
||
// 返回JSON响应
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(result)
|
||
}
|
||
|
||
// handleThreatsInfo 处理威胁域名信息请求,返回类型、名称、级别、域名
|
||
func handleThreatsInfo(w http.ResponseWriter, threatFilter string) {
|
||
// 如果过滤参数为空字符串,返回空数组
|
||
if threatFilter == "" {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode([]map[string]string{})
|
||
return
|
||
}
|
||
|
||
filePath := "./static/domain-info/threats/threats-database.csv"
|
||
data, err := os.ReadFile(filePath)
|
||
if err != nil {
|
||
http.Error(w, "Failed to read threats file", http.StatusInternalServerError)
|
||
logger.Error(fmt.Sprintf("读取威胁域名文件失败: %v", err))
|
||
return
|
||
}
|
||
|
||
// 解析CSV
|
||
reader := csv.NewReader(bytes.NewReader(data))
|
||
reader.FieldsPerRecord = -1 // 允许不同长度的记录
|
||
|
||
// 读取所有记录
|
||
records, err := reader.ReadAll()
|
||
if err != nil {
|
||
http.Error(w, "Failed to parse threats file", http.StatusInternalServerError)
|
||
logger.Error(fmt.Sprintf("解析威胁域名文件失败: %v", err))
|
||
return
|
||
}
|
||
|
||
// 转换为所需格式
|
||
var result []map[string]string
|
||
// 跳过标题行
|
||
for i, record := range records {
|
||
if i == 0 {
|
||
continue
|
||
}
|
||
|
||
if len(record) >= 4 {
|
||
// 检查是否需要过滤
|
||
if threatFilter != "" {
|
||
// 检查域名是否包含过滤条件
|
||
if !strings.Contains(record[3], threatFilter) {
|
||
// 检查名称是否包含过滤条件
|
||
if !strings.Contains(record[1], threatFilter) {
|
||
// 检查类型是否包含过滤条件
|
||
if !strings.Contains(record[0], threatFilter) {
|
||
continue
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
item := map[string]string{
|
||
"type": record[0],
|
||
"name": record[1],
|
||
"level": record[2],
|
||
"domain": record[3],
|
||
}
|
||
result = append(result, item)
|
||
}
|
||
}
|
||
|
||
// 返回JSON响应
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(result)
|
||
}
|
||
|
||
// handleChangePassword 处理修改密码请求
|
||
func (s *Server) handleChangePassword(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 解析请求体
|
||
var changePasswordData struct {
|
||
CurrentPassword string `json:"currentPassword"`
|
||
NewPassword string `json:"newPassword"`
|
||
}
|
||
|
||
if err := json.NewDecoder(r.Body).Decode(&changePasswordData); err != nil {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusBadRequest)
|
||
json.NewEncoder(w).Encode(map[string]string{"error": "无效的请求体"})
|
||
return
|
||
}
|
||
|
||
// 验证当前密码
|
||
if changePasswordData.CurrentPassword != s.config.Password {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusUnauthorized)
|
||
json.NewEncoder(w).Encode(map[string]string{"error": "当前密码错误"})
|
||
return
|
||
}
|
||
|
||
// 更新密码
|
||
s.config.Password = changePasswordData.NewPassword
|
||
|
||
// 保存配置到文件
|
||
if err := saveConfigToFile(s.globalConfig, "./config.json"); err != nil {
|
||
logger.Error("保存配置文件失败", "error", err)
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusInternalServerError)
|
||
json.NewEncoder(w).Encode(map[string]string{"error": "保存密码失败"})
|
||
return
|
||
}
|
||
|
||
// 返回成功响应
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "密码修改成功"})
|
||
logger.Info("密码修改成功")
|
||
}
|
||
|
||
// handleThreatQuery 处理威胁域名查询请求
|
||
// @Summary 查询威胁域名信息
|
||
// @Description 根据传入的域名参数查询威胁数据库,返回威胁类型、名称、风险等级和域名
|
||
// @Tags threat
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param domain query string true "要查询的域名"
|
||
// @Success 200 {string} string "威胁信息,格式:类型,名称,风险等级,域名"
|
||
// @Failure 400 {object} map[string]string "缺少域名参数"
|
||
// @Router /api/threat [get]
|
||
func (s *Server) handleThreatQuery(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
// 获取域名参数
|
||
domain := r.URL.Query().Get("domain")
|
||
if domain == "" {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusBadRequest)
|
||
json.NewEncoder(w).Encode(map[string]string{"error": "需要提供 domain 参数"})
|
||
return
|
||
}
|
||
|
||
// 读取威胁数据库 CSV 文件
|
||
filePath := "./static/domain-info/threats/threats-database.csv"
|
||
data, err := os.ReadFile(filePath)
|
||
if err != nil {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusInternalServerError)
|
||
json.NewEncoder(w).Encode(map[string]string{"error": "读取威胁数据库失败"})
|
||
logger.Error(fmt.Sprintf("读取威胁数据库文件失败:%v", err))
|
||
return
|
||
}
|
||
|
||
// 解析 CSV
|
||
reader := csv.NewReader(bytes.NewReader(data))
|
||
reader.FieldsPerRecord = -1 // 允许不同长度的记录
|
||
|
||
// 读取所有记录
|
||
records, err := reader.ReadAll()
|
||
if err != nil {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusInternalServerError)
|
||
json.NewEncoder(w).Encode(map[string]string{"error": "解析威胁数据库失败"})
|
||
logger.Error(fmt.Sprintf("解析威胁数据库文件失败:%v", err))
|
||
return
|
||
}
|
||
|
||
// 构建威胁域名映射(支持顶级域名匹配)
|
||
threatMap := make(map[string][]string)
|
||
for i, record := range records {
|
||
if i == 0 {
|
||
continue // 跳过标题行
|
||
}
|
||
if len(record) >= 4 {
|
||
threatType := record[0] // 第一列:类型
|
||
threatName := record[1] // 第二列:名称
|
||
riskLevel := record[2] // 第三列:风险等级
|
||
domain := record[3] // 第四列:域名
|
||
threatInfo := []string{threatType, threatName, riskLevel}
|
||
|
||
// 1. 完整域名匹配(所有类型都添加)
|
||
threatMap[domain] = threatInfo
|
||
|
||
// 2. 只有恶意网站类型才添加子域名匹配规则
|
||
// 类型判断:钓鱼网站、仿冒网站
|
||
// 逻辑:如果威胁数据库中有 sub.example.com,则所有子域名(a.sub.example.com)都应匹配
|
||
if threatType == "钓鱼网站" || threatType == "仿冒网站" {
|
||
// 对于恶意网站,添加子域名匹配规则
|
||
// 例如:sub.example.com -> 添加 .sub.example.com 规则
|
||
// 这样 a.sub.example.com 就会匹配
|
||
topLevelDomain := "." + domain
|
||
// 只有当该顶级域名规则不存在时才添加
|
||
if _, exists := threatMap[topLevelDomain]; !exists {
|
||
threatMap[topLevelDomain] = threatInfo
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 查询单个域名
|
||
var result string
|
||
|
||
// 1. 先检查完整匹配
|
||
if threat, exists := threatMap[domain]; exists {
|
||
result = fmt.Sprintf("%s,%s,%s,%s", threat[0], threat[1], threat[2], domain)
|
||
} else {
|
||
// 2. 检查子域名匹配(遍历顶级域名规则)
|
||
for threatDomain, threatInfo := range threatMap {
|
||
// 只检查以点开头的顶级域名规则
|
||
if strings.HasPrefix(threatDomain, ".") && strings.HasSuffix(domain, threatDomain) {
|
||
// 额外验证:确保是完整的域名部分匹配
|
||
prefix := strings.TrimSuffix(domain, threatDomain)
|
||
if len(prefix) > 0 && !strings.HasSuffix(prefix, ".") {
|
||
// 不是完整的子域名部分,跳过
|
||
continue
|
||
}
|
||
|
||
result = fmt.Sprintf("%s,%s,%s,%s", threatInfo[0], threatInfo[1], threatInfo[2], domain)
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
if result == "" {
|
||
// 未找到匹配的威胁信息
|
||
json.NewEncoder(w).Encode(map[string]string{"message": "无"})
|
||
} else {
|
||
// 返回威胁信息
|
||
json.NewEncoder(w).Encode(map[string]string{"data": result})
|
||
}
|
||
}
|
||
|
||
// handleThreatBatch 批量查询威胁域名
|
||
// @Summary 批量查询威胁域名
|
||
// @Description 批量查询多个域名是否是威胁域名
|
||
// @Tags threat
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param domains body []string true "域名列表"
|
||
// @Success 200 {object} map[string]interface{} "批量查询结果"
|
||
// @Router /api/threat/batch [post]
|
||
func (s *Server) handleThreatBatch(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
var req struct {
|
||
Domains []string `json:"domains"`
|
||
}
|
||
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusBadRequest)
|
||
json.NewEncoder(w).Encode(map[string]string{"error": "请求格式错误"})
|
||
return
|
||
}
|
||
|
||
// 读取威胁数据库 CSV 文件
|
||
filePath := "./static/domain-info/threats/threats-database.csv"
|
||
data, err := os.ReadFile(filePath)
|
||
if err != nil {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusInternalServerError)
|
||
json.NewEncoder(w).Encode(map[string]string{"error": "读取威胁数据库失败"})
|
||
logger.Error(fmt.Sprintf("读取威胁数据库文件失败:%v", err))
|
||
return
|
||
}
|
||
|
||
// 解析 CSV
|
||
reader := csv.NewReader(bytes.NewReader(data))
|
||
reader.FieldsPerRecord = -1 // 允许不同长度的记录
|
||
|
||
// 读取所有记录
|
||
records, err := reader.ReadAll()
|
||
if err != nil {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusInternalServerError)
|
||
json.NewEncoder(w).Encode(map[string]string{"error": "解析威胁数据库失败"})
|
||
logger.Error(fmt.Sprintf("解析威胁数据库文件失败:%v", err))
|
||
return
|
||
}
|
||
|
||
// 构建威胁域名映射(支持顶级域名匹配)
|
||
threatMap := make(map[string][]string)
|
||
for i, record := range records {
|
||
if i == 0 {
|
||
continue // 跳过标题行
|
||
}
|
||
if len(record) >= 4 {
|
||
threatType := record[0] // 第一列:类型
|
||
threatName := record[1] // 第二列:名称
|
||
riskLevel := record[2] // 第三列:风险等级
|
||
domain := record[3] // 第四列:域名
|
||
threatInfo := []string{threatType, threatName, riskLevel}
|
||
|
||
// 1. 完整域名匹配(所有类型都添加)
|
||
threatMap[domain] = threatInfo
|
||
|
||
// 2. 只有恶意网站类型才添加子域名匹配规则
|
||
// 类型判断:钓鱼网站、仿冒网站
|
||
// 逻辑:如果威胁数据库中有 sub.example.com,则所有子域名(a.sub.example.com)都应匹配
|
||
if threatType == "钓鱼网站" || threatType == "仿冒网站" {
|
||
// 对于恶意网站,添加子域名匹配规则
|
||
// 例如:sub.example.com -> 添加 .sub.example.com 规则
|
||
// 这样 a.sub.example.com 就会匹配
|
||
topLevelDomain := "." + domain
|
||
// 只有当该顶级域名规则不存在时才添加
|
||
if _, exists := threatMap[topLevelDomain]; !exists {
|
||
threatMap[topLevelDomain] = threatInfo
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 批量查询
|
||
results := make([]map[string]interface{}, 0, len(req.Domains))
|
||
for _, domain := range req.Domains {
|
||
// 1. 先检查完整匹配
|
||
if threat, exists := threatMap[domain]; exists {
|
||
results = append(results, map[string]interface{}{
|
||
"domain": domain,
|
||
"isThreat": true,
|
||
"data": fmt.Sprintf("%s,%s,%s,%s", threat[0], threat[1], threat[2], domain),
|
||
})
|
||
continue
|
||
}
|
||
|
||
// 2. 检查子域名匹配(遍历顶级域名规则)
|
||
matched := false
|
||
for threatDomain, threatInfo := range threatMap {
|
||
// 只检查以点开头的顶级域名规则
|
||
if strings.HasPrefix(threatDomain, ".") && strings.HasSuffix(domain, threatDomain) {
|
||
// 验证:确保是有效的子域名匹配
|
||
// 例如:test.example.com 匹配 .example.com ✅
|
||
// notexample.com 不应该匹配 .example.com ❌
|
||
// 去掉 threatDomain 的第一个字符(即去掉开头的点)
|
||
suffixToTrim := threatDomain[1:]
|
||
prefix := strings.TrimSuffix(domain, suffixToTrim)
|
||
|
||
// 验证逻辑:前缀不为空且以.结尾,或者前缀为空(完全匹配)
|
||
if len(prefix) == 0 || (len(prefix) > 0 && strings.HasSuffix(prefix, ".")) {
|
||
results = append(results, map[string]interface{}{
|
||
"domain": domain,
|
||
"isThreat": true,
|
||
"data": fmt.Sprintf("%s,%s,%s,%s", threatInfo[0], threatInfo[1], threatInfo[2], domain),
|
||
})
|
||
matched = true
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
if !matched {
|
||
results = append(results, map[string]interface{}{
|
||
"domain": domain,
|
||
"isThreat": false,
|
||
})
|
||
}
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||
"results": results,
|
||
})
|
||
}
|