package http

import (
	"bytes"
	"encoding/csv"
	"encoding/json"
	"fmt"
	"net/http"
	"os"
	"sort"
	"strings"
	"sync"
	"time"

	"dns-server/config"
	"dns-server/dns"
	"dns-server/domain"
	"dns-server/gfw"
	"dns-server/logger"
	"dns-server/shield"
	"dns-server/threat"

	"github.com/gorilla/websocket"
	"gopkg.in/ini.v1"
)

// CacheEntry is one cached value plus its bookkeeping data.
type CacheEntry struct {
	data      interface{} // cached payload
	timestamp time.Time   // creation time; the entry expires ttl after this
	hits      int         // number of successful Get calls; used by eviction
}

// QueryCache is a bounded, TTL-based cache for query results.
type QueryCache struct {
	data    map[string]*CacheEntry // cache key -> entry
	mutex   sync.RWMutex           // guards data and the entries' hit counters
	maxSize int                    // maximum number of entries
	ttl     time.Duration          // lifetime of an entry
}

// StatsCache is a bounded, TTL-based cache for statistics maps. It also
// remembers the map most recently passed to Set (see GetLastStats).
type StatsCache struct {
	data      map[string]*CacheEntry // cache key -> entry
	mutex     sync.RWMutex           // guards data, lastStats and hit counters
	maxSize   int                    // maximum number of entries
	ttl       time.Duration          // lifetime of an entry
	lastStats map[string]interface{} // last map stored via Set, for incremental updates
}

// cacheCleanupInterval returns how often expired entries are swept: half the
// TTL, or one second when that is not a positive duration (time.NewTicker
// panics on non-positive intervals, which would crash the process).
func cacheCleanupInterval(ttl time.Duration) time.Duration {
	if interval := ttl / 2; interval > 0 {
		return interval
	}
	return time.Second
}

// cacheEvictLeastHit removes the entry with the fewest hits (ties broken by
// map iteration order). The caller must hold the write lock.
func cacheEvictLeastHit(data map[string]*CacheEntry) {
	var victim string
	minHits := 0
	found := false
	for key, entry := range data {
		if !found || entry.hits < minHits {
			victim, minHits, found = key, entry.hits, true
		}
	}
	// Track "found" explicitly so that an entry stored under "" can be evicted too.
	if found {
		delete(data, victim)
	}
}

// cacheRemoveExpired deletes every entry older than ttl relative to now.
// The caller must hold the write lock.
func cacheRemoveExpired(data map[string]*CacheEntry, ttl time.Duration, now time.Time) {
	for key, entry := range data {
		if now.Sub(entry.timestamp) > ttl {
			delete(data, key)
		}
	}
}

// cacheLookup returns the live entry for key and increments its hit counter;
// an expired entry is deleted and reported as missing. The caller must hold
// the write lock, so the expiry check and the delete are atomic with respect
// to concurrent Set calls.
func cacheLookup(data map[string]*CacheEntry, key string, ttl time.Duration) (*CacheEntry, bool) {
	entry, ok := data[key]
	if !ok {
		return nil, false
	}
	if time.Since(entry.timestamp) > ttl {
		delete(data, key)
		return nil, false
	}
	entry.hits++
	return entry, true
}

// cacheStore inserts or replaces key. When inserting a NEW key into a full
// cache it first evicts the least-hit entry; replacing an existing key never
// evicts anything. The caller must hold the write lock.
func cacheStore(data map[string]*CacheEntry, maxSize int, key string, value interface{}) {
	if _, exists := data[key]; !exists && len(data) >= maxSize {
		cacheEvictLeastHit(data)
	}
	data[key] = &CacheEntry{
		data:      value,
		timestamp: time.Now(),
		hits:      0,
	}
}

// NewQueryCache creates a query-result cache holding at most maxSize entries
// for ttl each, and starts its background cleanup goroutine (which runs for
// the lifetime of the process).
func NewQueryCache(maxSize int, ttl time.Duration) *QueryCache {
	cache := &QueryCache{
		data:    make(map[string]*CacheEntry),
		maxSize: maxSize,
		ttl:     ttl,
	}
	go cache.startCleanupLoop()
	return cache
}

// NewStatsCache creates a statistics cache holding at most maxSize entries
// for ttl each, and starts its background cleanup goroutine (which runs for
// the lifetime of the process).
func NewStatsCache(maxSize int, ttl time.Duration) *StatsCache {
	cache := &StatsCache{
		data:      make(map[string]*CacheEntry),
		maxSize:   maxSize,
		ttl:       ttl,
		lastStats: make(map[string]interface{}),
	}
	go cache.startCleanupLoop()
	return cache
}

// Get returns the cached value for key and whether a live entry was found.
// A hit increments the entry's hit counter; an expired entry is removed.
func (c *QueryCache) Get(key string) (interface{}, bool) {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	entry, ok := cacheLookup(c.data, key, c.ttl)
	if !ok {
		return nil, false
	}
	return entry.data, true
}

// Set stores data under key with a fresh timestamp and zero hits.
func (c *QueryCache) Set(key string, data interface{}) {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	cacheStore(c.data, c.maxSize, key, data)
}

// Delete removes key from the cache, if present.
func (c *QueryCache) Delete(key string) {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	delete(c.data, key)
}

// Clear removes all entries.
func (c *QueryCache) Clear() {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	c.data = make(map[string]*CacheEntry)
}

// evictLRU evicts the entry with the fewest hits. Despite the name this is a
// least-frequently-used policy. The caller must hold the write lock.
func (c *QueryCache) evictLRU() {
	cacheEvictLeastHit(c.data)
}

// startCleanupLoop periodically removes expired entries. It never returns.
func (c *QueryCache) startCleanupLoop() {
	ticker := time.NewTicker(cacheCleanupInterval(c.ttl))
	defer ticker.Stop()

	for range ticker.C {
		c.cleanupExpired()
	}
}

// cleanupExpired removes every entry older than the TTL.
func (c *QueryCache) cleanupExpired() {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	cacheRemoveExpired(c.data, c.ttl, time.Now())
}

// Get returns the cached statistics map for key and whether a live entry of
// map type was found. A hit increments the entry's hit counter; an expired
// entry is removed.
func (c *StatsCache) Get(key string) (map[string]interface{}, bool) {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	entry, ok := cacheLookup(c.data, key, c.ttl)
	if !ok {
		return nil, false
	}
	stats, isMap := entry.data.(map[string]interface{})
	return stats, isMap
}

// Set stores data under key and remembers it as the last statistics map.
func (c *StatsCache) Set(key string, data map[string]interface{}) {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	cacheStore(c.data, c.maxSize, key, data)
	c.lastStats = data
}

// GetLastStats returns the map most recently passed to Set (not a copy).
func (c *StatsCache) GetLastStats() map[string]interface{} {
	c.mutex.RLock()
	defer c.mutex.RUnlock()
	return c.lastStats
}

// Delete removes key from the cache, if present.
func (c *StatsCache) Delete(key string) {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	delete(c.data, key)
}

// Clear removes all entries and forgets the last statistics map.
func (c *StatsCache) Clear() {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	c.data = make(map[string]*CacheEntry)
	c.lastStats = make(map[string]interface{})
}

// evictLRU evicts the entry with the fewest hits. Despite the name this is a
// least-frequently-used policy. The caller must hold the write lock.
func (c *StatsCache) evictLRU() {
	cacheEvictLeastHit(c.data)
}

// startCleanupLoop periodically removes expired entries. It never returns.
func (c *StatsCache) startCleanupLoop() {
	ticker := time.NewTicker(cacheCleanupInterval(c.ttl))
	defer ticker.Stop()

	for range ticker.C {
		c.cleanupExpired()
	}
}

// cleanupExpired removes every entry older than the TTL.
func (c *StatsCache) cleanupExpired() {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	cacheRemoveExpired(c.data, c.ttl, time.Now())
}

// ClearQueryCache empties the query-result cache, if one is configured.
func (s *Server) ClearQueryCache() {
	if s.queryCache != nil {
		s.queryCache.Clear()
	}
}

// ClearStatsCache empties the statistics cache, if one is configured.
func (s *Server) ClearStatsCache() {
	if s.statsCache != nil {
		s.statsCache.Clear()
	}
}

// ClearAllCache empties both the query-result and the statistics cache.
func (s *Server) ClearAllCache() {
	s.ClearQueryCache()
	s.ClearStatsCache()
}

// Server is the HTTP admin-console server.
type Server struct {
	globalConfig  *config.Config
	config        *config.HTTPConfig
	dnsServer     *dns.Server
	shieldManager *shield.ShieldManager
	gfwManager    *gfw.GFWListManager
	server        *http.Server

	// Session management.
	sessions      map[string]time.Time // session ID -> expiry time
	sessionsMutex sync.Mutex           // guards sessions
	sessionTTL    time.Duration        // sliding session lifetime

	// WebSocket state.
	upgrader      websocket.Upgrader
	clients       map[*websocket.Conn]bool
	clientsMutex  sync.Mutex  // guards clients
	wsWriteMutex  sync.Mutex  // serializes writes to WebSocket connections (one writer per conn)
	broadcastChan chan []byte // messages sent to every connected client

	// Caches.
	queryCache   *QueryCache   // query-result cache
	statsCache   *StatsCache   // statistics cache
	cacheEnabled bool          // whether caching is enabled
	cacheTTL     time.Duration // default cache TTL
	cacheMaxSize int           // default maximum number of cache entries
}

// NewServer creates the console server and starts its WebSocket broadcast
// goroutine and its session-cleanup goroutine.
func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager *shield.ShieldManager, gfwManager *gfw.GFWListManager) *Server {
	server := &Server{
		globalConfig:  globalConfig,
		config:        &globalConfig.HTTP,
		dnsServer:     dnsServer,
		shieldManager: shieldManager,
		gfwManager:    gfwManager,
		upgrader: websocket.Upgrader{
			ReadBufferSize:  1024,
			WriteBufferSize: 1024,
			// Accept WebSocket handshakes from any Origin.
			// NOTE(review): sessions are cookie-based, so accepting every Origin
			// permits cross-site WebSocket hijacking of /ws/stats. Consider
			// leaving CheckOrigin nil (gorilla's same-origin default) once it is
			// confirmed no deployment serves the UI from another origin.
			CheckOrigin: func(r *http.Request) bool {
				return true
			},
		},
		clients:       make(map[*websocket.Conn]bool),
		broadcastChan: make(chan []byte, 100),

		sessions:   make(map[string]time.Time),
		sessionTTL: 24 * time.Hour, // sessions stay valid for 24 hours

		queryCache:   NewQueryCache(100, 5*time.Second), // at most 100 entries, 5 s TTL
		statsCache:   NewStatsCache(10, 2*time.Second),  // at most 10 entries, 2 s TTL
		cacheEnabled: true,                              // caching on by default
		cacheTTL:     5 * time.Second,                   // default TTL: 5 s
		cacheMaxSize: 100,                               // default size: 100 entries
	}

	go server.startBroadcastLoop()
	go server.cleanupSessionsLoop()

	return server
}

// Start registers all routes and serves HTTP until the server is closed.
// It returns the error from ListenAndServe (http.ErrServerClosed after Stop).
func (s *Server) Start() error {
	mux := http.NewServeMux()

	// noCache disables browser caching for console pages and assets.
	noCache := func(w http.ResponseWriter) {
		w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
		w.Header().Set("Pragma", "no-cache")
		w.Header().Set("Expires", "Thu, 01 Jan 1970 00:00:00 GMT")
	}

	// /login needs no authentication; it redirects to the login page.
	mux.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
		http.Redirect(w, r, "/login.html", http.StatusFound)
	})

	if s.config.EnableAPI {
		// Session endpoints: login and logout need no authentication.
		mux.HandleFunc("/api/login", s.handleLogin)
		mux.HandleFunc("/api/logout", s.handleLogout)
		mux.HandleFunc("/api/change-password", s.loginRequired(s.handleChangePassword))

		// /api redirects to the Swagger UI page.
		mux.HandleFunc("/api", s.loginRequired(func(w http.ResponseWriter, r *http.Request) {
			http.Redirect(w, r, "/api/index.html", http.StatusMovedPermanently)
		}))

		// Every remaining API endpoint requires a logged-in session.
		mux.HandleFunc("/api/stats", s.loginRequired(s.handleStats))
		mux.HandleFunc("/api/shield", s.loginRequired(s.handleShield))
		mux.HandleFunc("/api/shield/localrules", s.loginRequired(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Content-Type", "application/json")
			if r.Method == http.MethodGet {
				json.NewEncoder(w).Encode(s.shieldManager.GetLocalRules())
				return
			}
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		}))
		mux.HandleFunc("/api/shield/remoterules", s.loginRequired(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Content-Type", "application/json")
			if r.Method == http.MethodGet {
				json.NewEncoder(w).Encode(s.shieldManager.GetRemoteRules())
				return
			}
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		}))
		mux.HandleFunc("/api/shield/hosts", s.loginRequired(s.handleShieldHosts))
		mux.HandleFunc("/api/shield/blacklists", s.loginRequired(s.handleShieldBlacklists))
		// handleShieldBlacklists also serves /api/shield/blacklists/{name} (DELETE)
		// and /api/shield/blacklists/{name}/update (POST). A pattern without a
		// trailing slash only matches the exact path, so without this subtree
		// pattern those requests would fall through to the /api/ file server.
		mux.HandleFunc("/api/shield/blacklists/", s.loginRequired(s.handleShieldBlacklists))

		// Legacy query endpoint (kept for backward compatibility).
		mux.HandleFunc("/api/query", s.loginRequired(s.handleQuery))
		// RESTful domain query endpoint.
		mux.HandleFunc("/api/domains/", s.loginRequired(s.handleDomainQuery))
		mux.HandleFunc("/api/status", s.loginRequired(s.handleStatus))
		mux.HandleFunc("/api/config", s.loginRequired(s.handleConfig))
		mux.HandleFunc("/api/config/restart", s.loginRequired(s.handleRestart))

		// Statistics endpoints.
		mux.HandleFunc("/api/top-blocked", s.loginRequired(s.handleTopBlockedDomains))
		mux.HandleFunc("/api/top-resolved", s.loginRequired(s.handleTopResolvedDomains))
		mux.HandleFunc("/api/top-clients", s.loginRequired(s.handleTopClients))
		mux.HandleFunc("/api/top-domains", s.loginRequired(s.handleTopDomains))
		mux.HandleFunc("/api/recent-blocked", s.loginRequired(s.handleRecentBlockedDomains))
		mux.HandleFunc("/api/hourly-stats", s.loginRequired(s.handleHourlyStats))
		mux.HandleFunc("/api/daily-stats", s.loginRequired(s.handleDailyStats))
		mux.HandleFunc("/api/monthly-stats", s.loginRequired(s.handleMonthlyStats))
		mux.HandleFunc("/api/query/type", s.loginRequired(s.handleQueryTypeStats))

		// Log statistics endpoints.
		mux.HandleFunc("/api/logs/stats", s.loginRequired(s.handleLogsStats))
		mux.HandleFunc("/api/logs/query", s.loginRequired(s.handleLogsQuery))
		mux.HandleFunc("/api/logs/count", s.loginRequired(s.handleLogsCount))

		// Domain information endpoints.
		mux.HandleFunc("/api/domain/info", s.loginRequired(s.handleDomainInfo))
		mux.HandleFunc("/api/domain-info", s.loginRequired(s.handleDomainInfoList))

		// Threat intelligence and alert endpoints.
		mux.HandleFunc("/api/threat", s.loginRequired(s.handleThreatQuery))
		mux.HandleFunc("/api/threat/batch", s.loginRequired(s.handleThreatBatch))
		mux.HandleFunc("/api/alert", s.loginRequired(s.handleAlert))
		mux.HandleFunc("/api/alert/resolve", s.loginRequired(s.handleAlertResolve))
		mux.HandleFunc("/api/threat/domain", s.loginRequired(s.handleThreatDomain))

		// Live statistics over WebSocket.
		mux.HandleFunc("/ws/stats", s.loginRequired(s.handleWebSocketStats))

		// Static files under /api/ come from ./static/api. ServeMux prefers the
		// longest matching pattern, so the endpoints above still take precedence.
		apiFileServer := http.FileServer(http.Dir("./static/api"))
		mux.Handle("/api/", s.loginRequired(http.StripPrefix("/api", apiFileServer).ServeHTTP))
	}

	fileServer := http.FileServer(http.Dir("./static"))

	// login.html is served without authentication.
	mux.HandleFunc("/login.html", func(w http.ResponseWriter, r *http.Request) {
		noCache(w)
		http.ServeFile(w, r, "./static/login.html")
	})

	// Tracker assets require a session.
	trackerFileServer := http.FileServer(http.Dir("./tracker"))
	mux.HandleFunc("/tracker/", s.loginRequired(func(w http.ResponseWriter, r *http.Request) {
		noCache(w)
		http.StripPrefix("/tracker", trackerFileServer).ServeHTTP(w, r)
	}))

	// Every other static file requires a session.
	mux.HandleFunc("/", s.loginRequired(func(w http.ResponseWriter, r *http.Request) {
		noCache(w)
		http.StripPrefix("/", fileServer).ServeHTTP(w, r)
	}))

	s.server = &http.Server{
		Addr:         fmt.Sprintf("%s:%d", s.config.Host, s.config.Port),
		Handler:      mux,
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}

	logger.Info(fmt.Sprintf("HTTP控制台服务器启动,监听地址: %s:%d", s.config.Host, s.config.Port))
	return s.server.ListenAndServe()
}

// Stop closes the HTTP server and every open WebSocket connection.
func (s *Server) Stop() {
	if s.server != nil {
		if err := s.server.Close(); err != nil {
			logger.Error(fmt.Sprintf("关闭HTTP服务器失败: %v", err))
		}
	}

	// http.Server.Close does not track hijacked connections such as WebSockets,
	// so close them here; their handler goroutines exit on the resulting errors.
	s.clientsMutex.Lock()
	for conn := range s.clients {
		conn.Close()
		delete(s.clients, conn)
	}
	s.clientsMutex.Unlock()

	logger.Info("HTTP控制台服务器已停止")
}

// handleStats serves GET /api/stats: the buildStatsData snapshot plus the
// current time under "time".
// @Summary 获取系统统计信息
// @Description 获取DNS服务器和Shield的统计信息
// @Tags stats
// @Accept json
// @Produce json
// @Success 200 {object} map[string]interface{} "统计信息"
// @Failure 500 {object} map[string]string "服务器内部错误"
// @Router /api/stats [get]
func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	stats := s.buildStatsData()
	stats["time"] = time.Now()

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(stats)
}

// writeWebSocketMessage sends one text frame to conn. gorilla/websocket allows
// only one concurrent writer per connection, and both the per-client loop and
// the broadcast loop write to the same connections, so all writes are
// serialized through wsWriteMutex. The write deadline stops a stalled client
// from blocking writers indefinitely.
func (s *Server) writeWebSocketMessage(conn *websocket.Conn, data []byte) error {
	const writeWait = 10 * time.Second

	s.wsWriteMutex.Lock()
	defer s.wsWriteMutex.Unlock()

	if err := conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
		return err
	}
	return conn.WriteMessage(websocket.TextMessage, data)
}

// handleWebSocketStats upgrades the request to a WebSocket, sends an
// "initial_data" snapshot and then, every 500 ms, a "stats_update" message
// whenever the key counters changed. The client is unregistered on every exit
// path.
func (s *Server) handleWebSocketStats(w http.ResponseWriter, r *http.Request) {
	conn, err := s.upgrader.Upgrade(w, r, nil)
	if err != nil {
		logger.Error(fmt.Sprintf("WebSocket升级失败: %v", err))
		return
	}
	defer conn.Close()

	s.clientsMutex.Lock()
	s.clients[conn] = true
	clientCount := len(s.clients)
	s.clientsMutex.Unlock()
	logger.Info(fmt.Sprintf("新WebSocket客户端连接,当前连接数: %d", clientCount))

	defer func() {
		s.clientsMutex.Lock()
		delete(s.clients, conn)
		remaining := len(s.clients)
		s.clientsMutex.Unlock()
		logger.Info(fmt.Sprintf("WebSocket客户端断开连接,当前连接数: %d", remaining))
	}()

	// A hijacked connection is not monitored by net/http, so r.Context() is not
	// a reliable disconnect signal. This read loop processes control frames
	// (close/ping) and reports when the peer goes away.
	peerGone := make(chan struct{})
	go func() {
		defer close(peerGone)
		for {
			if _, _, err := conn.NextReader(); err != nil {
				return
			}
		}
	}()

	if err := s.sendInitialStats(conn); err != nil {
		logger.Error(fmt.Sprintf("发送初始数据失败: %v", err))
		return
	}

	ticker := time.NewTicker(500 * time.Millisecond)
	defer ticker.Stop()

	// Snapshot of the last stats sent, used to skip unchanged updates.
	var lastStats map[string]interface{}

	for {
		select {
		case <-ticker.C:
			currentStats := s.buildStatsData()
			if s.areStatsEqual(lastStats, currentStats) {
				continue
			}

			data, err := json.Marshal(map[string]interface{}{
				"type": "stats_update",
				"data": currentStats,
				"time": time.Now(),
			})
			if err != nil {
				logger.Error(fmt.Sprintf("序列化统计数据失败: %v", err))
				continue
			}

			if err := s.writeWebSocketMessage(conn, data); err != nil {
				logger.Error(fmt.Sprintf("发送WebSocket消息失败: %v", err))
				return
			}
			lastStats = currentStats

		case <-peerGone:
			return

		case <-r.Context().Done():
			return
		}
	}
}

// sendInitialStats writes the "initial_data" message with a fresh snapshot.
func (s *Server) sendInitialStats(conn *websocket.Conn) error {
	data, err := json.Marshal(map[string]interface{}{
		"type": "initial_data",
		"data": s.buildStatsData(),
		"time": time.Now(),
	})
	if err != nil {
		return err
	}
	return s.writeWebSocketMessage(conn, data)
}

// buildStatsData assembles the statistics snapshot shared by /api/stats and
// the WebSocket stream. Response time and DNSSEC usage are truncated (not
// rounded) to two decimals.
func (s *Server) buildStatsData() map[string]interface{} {
	dnsStats := s.dnsServer.GetStats()
	shieldStats := s.shieldManager.GetStats()

	// Most frequent query type, or "-" if none. Ties go to the
	// lexicographically smallest name so the value is stable between calls.
	topQueryType := "-"
	var maxCount int64
	for queryType, count := range dnsStats.QueryTypes {
		if count > maxCount || (count == maxCount && count > 0 && queryType < topQueryType) {
			maxCount = count
			topQueryType = queryType
		}
	}

	activeIPCount := len(dnsStats.SourceIPs)

	formattedResponseTime := float64(int(dnsStats.AvgResponseTime*100)) / 100

	// DNSSEC usage as a percentage of all queries.
	dnssecUsage := float64(0)
	if dnsStats.Queries > 0 {
		dnssecUsage = float64(dnsStats.DNSSECQueries) / float64(dnsStats.Queries) * 100
	}

	return map[string]interface{}{
		"dns": map[string]interface{}{
			"Queries":           dnsStats.Queries,
			"Blocked":           dnsStats.Blocked,
			"Allowed":           dnsStats.Allowed,
			"Errors":            dnsStats.Errors,
			"LastQuery":         dnsStats.LastQuery,
			"AvgResponseTime":   formattedResponseTime,
			"TotalResponseTime": dnsStats.TotalResponseTime,
			"QueryTypes":        dnsStats.QueryTypes,
			"SourceIPs":         dnsStats.SourceIPs,
			"CpuUsage":          dnsStats.CpuUsage,
			"DNSSECQueries":     dnsStats.DNSSECQueries,
			"DNSSECSuccess":     dnsStats.DNSSECSuccess,
			"DNSSECFailed":      dnsStats.DNSSECFailed,
			"DNSSECEnabled":     dnsStats.DNSSECEnabled,
		},
		"shield":          shieldStats,
		"topQueryType":    topQueryType,
		"activeIPs":       activeIPCount,
		"avgResponseTime": formattedResponseTime,
		"cpuUsage":        dnsStats.CpuUsage,
		"dnssecEnabled":   dnsStats.DNSSECEnabled,
		"dnssecQueries":   dnsStats.DNSSECQueries,
		"dnssecSuccess":   dnsStats.DNSSECSuccess,
		"dnssecFailed":    dnsStats.DNSSECFailed,
		"dnssecUsage":     float64(int(dnssecUsage*100)) / 100,
	}
}

// areStatsEqual reports whether two snapshots agree on the key DNS counters
// (Queries, Blocked, Allowed, Errors, AvgResponseTime). Other fields are
// deliberately ignored to avoid pushing an update for every minor change.
// A nil snapshot is never equal to anything.
func (s *Server) areStatsEqual(stats1, stats2 map[string]interface{}) bool {
	if stats1 == nil || stats2 == nil {
		return false
	}

	dns1, ok1 := stats1["dns"].(map[string]interface{})
	dns2, ok2 := stats2["dns"].(map[string]interface{})
	if !ok1 || !ok2 {
		return true
	}

	for _, key := range []string{"Queries", "Blocked", "Allowed", "Errors", "AvgResponseTime"} {
		if dns1[key] != dns2[key] {
			return false
		}
	}
	return true
}

// startBroadcastLoop sends every message from broadcastChan to all connected
// clients, closing and dropping clients whose write fails. It returns when
// broadcastChan is closed.
func (s *Server) startBroadcastLoop() {
	for message := range s.broadcastChan {
		s.clientsMutex.Lock()
		for client := range s.clients {
			if err := s.writeWebSocketMessage(client, message); err != nil {
				logger.Error(fmt.Sprintf("广播消息失败: %v", err))
				client.Close()
				delete(s.clients, client)
			}
		}
		s.clientsMutex.Unlock()
	}
}

// cleanupSessionsLoop removes expired sessions once an hour. It never returns.
func (s *Server) cleanupSessionsLoop() {
	ticker := time.NewTicker(time.Hour)
	defer ticker.Stop()

	for range ticker.C {
		s.sessionsMutex.Lock()
		now := time.Now()
		for sessionID, expiryTime := range s.sessions {
			if now.After(expiryTime) {
				delete(s.sessions, sessionID)
			}
		}
		s.sessionsMutex.Unlock()
	}
}

// isAuthenticated reports whether the request carries a live "session_id"
// cookie. A valid session's expiry is pushed out by sessionTTL (sliding
// expiration); an expired one is deleted.
func (s *Server) isAuthenticated(r *http.Request) bool {
	cookie, err := r.Cookie("session_id")
	if err != nil {
		return false
	}
	sessionID := cookie.Value

	s.sessionsMutex.Lock()
	defer s.sessionsMutex.Unlock()

	expiryTime, exists := s.sessions[sessionID]
	if !exists {
		return false
	}
	if time.Now().After(expiryTime) {
		delete(s.sessions, sessionID)
		return false
	}

	s.sessions[sessionID] = time.Now().Add(s.sessionTTL)
	return true
}

// loginRequired wraps next so that unauthenticated requests are rejected:
// /api/ and /ws/ paths get a 401 JSON error, everything else is redirected to
// /login. /login and /api/login always pass through.
func (s *Server) loginRequired(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/login" || r.URL.Path == "/api/login" {
			next.ServeHTTP(w, r)
			return
		}

		if !s.isAuthenticated(r) {
			if strings.HasPrefix(r.URL.Path, "/api/") || strings.HasPrefix(r.URL.Path, "/ws/") {
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusUnauthorized)
				json.NewEncoder(w).Encode(map[string]string{"error": "未授权访问"})
				return
			}
			http.Redirect(w, r, "/login", http.StatusFound)
			return
		}

		next.ServeHTTP(w, r)
	}
}

// handleTopBlockedDomains serves GET /api/top-blocked: up to 50 most-blocked
// domains as [{"domain", "count"}].
func (s *Server) handleTopBlockedDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	domains := s.dnsServer.GetTopBlockedDomains(50)

	result := make([]map[string]interface{}, len(domains))
	for i, d := range domains {
		result[i] = map[string]interface{}{
			"domain": d.Domain,
			"count":  d.Count,
		}
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(result)
}

// handleTopResolvedDomains serves GET /api/top-resolved: the 10 most-resolved
// domains as [{"domain", "count"}].
func (s *Server) handleTopResolvedDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	domains := s.dnsServer.GetTopResolvedDomains(10)

	result := make([]map[string]interface{}, len(domains))
	for i, d := range domains {
		result[i] = map[string]interface{}{
			"domain": d.Domain,
			"count":  d.Count,
		}
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(result)
}

// handleRecentBlockedDomains serves GET /api/recent-blocked: the 10 most
// recently blocked domains as [{"domain", "time"}], where time is the
// LastSeen Unix timestamp formatted as HH:MM:SS in local time.
func (s *Server) handleRecentBlockedDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	domains := s.dnsServer.GetRecentBlockedDomains(10)

	result := make([]map[string]interface{}, len(domains))
	for i, d := range domains {
		result[i] = map[string]interface{}{
			"domain": d.Domain,
			"time":   time.Unix(d.LastSeen, 0).Format("15:04:05"),
		}
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(result)
}

// handleHourlyStats serves GET /api/hourly-stats: counts for the last 24
// hours, oldest first, as {"labels": ["HH:00", ...], "data": [...]}.
func (s *Server) handleHourlyStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	hourlyStats := s.dnsServer.GetHourlyStats()

	const hours = 24
	now := time.Now()
	labels := make([]string, hours)
	data := make([]int64, hours)
	for i := 0; i < hours; i++ {
		hour := now.Add(time.Duration(i-(hours-1)) * time.Hour)
		labels[i] = hour.Format("15:00")
		data[i] = hourlyStats[hour.Format("2006-01-02-15")]
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]interface{}{
		"labels": labels,
		"data":   data,
	})
}

// writeDailySeries writes per-day counts for the last `days` days, oldest
// first and today last, as {"labels": ["MM-DD", ...], "data": [...]}.
func (s *Server) writeDailySeries(w http.ResponseWriter, days int) {
	dailyStats := s.dnsServer.GetDailyStats()

	now := time.Now()
	labels := make([]string, days)
	data := make([]int64, days)
	for i := 0; i < days; i++ {
		day := now.AddDate(0, 0, i-(days-1))
		labels[i] = day.Format("01-02")
		data[i] = dailyStats[day.Format("2006-01-02")]
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]interface{}{
		"labels": labels,
		"data":   data,
	})
}

// handleDailyStats serves GET /api/daily-stats: the last 7 days.
func (s *Server) handleDailyStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	s.writeDailySeries(w, 7)
}

// handleMonthlyStats serves GET /api/monthly-stats: the last 30 days.
func (s *Server) handleMonthlyStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	s.writeDailySeries(w, 30)
}

// handleQueryTypeStats serves GET /api/query/type: [{"type", "count"}]
// sorted by count descending, ties by type name.
func (s *Server) handleQueryTypeStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	dnsStats := s.dnsServer.GetStats()

	result := make([]map[string]interface{}, 0, len(dnsStats.QueryTypes))
	for queryType, count := range dnsStats.QueryTypes {
		result = append(result, map[string]interface{}{
			"type":  queryType,
			"count": count,
		})
	}

	sort.Slice(result, func(i, j int) bool {
		ci, cj := result[i]["count"].(int64), result[j]["count"].(int64)
		if ci != cj {
			return ci > cj
		}
		return result[i]["type"].(string) < result[j]["type"].(string)
	})

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(result)
}

// handleTopClients serves GET /api/top-clients: the 10 busiest clients as
// [{"ip", "count", "lastSeen"}].
func (s *Server) handleTopClients(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	clients := s.dnsServer.GetTopClients(10)

	result := make([]map[string]interface{}, len(clients))
	for i, client := range clients {
		result[i] = map[string]interface{}{
			"ip":       client.IP,
			"count":    client.Count,
			"lastSeen": client.LastSeen,
		}
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(result)
}

// handleTopDomains serves GET /api/top-domains: the top-50 blocked and top-50
// resolved lists merged by domain (counts summed), sorted by count descending
// (ties by name) and capped at 50 entries. When a domain appears in both
// lists, the resolved list's DNSSEC flag wins.
func (s *Server) handleTopDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	const limit = 50
	blockedDomains := s.dnsServer.GetTopBlockedDomains(limit)
	resolvedDomains := s.dnsServer.GetTopResolvedDomains(limit)

	counts := make(map[string]int64)
	dnssecByDomain := make(map[string]bool)
	for _, d := range blockedDomains {
		counts[d.Domain] += d.Count
		dnssecByDomain[d.Domain] = d.DNSSEC
	}
	for _, d := range resolvedDomains {
		counts[d.Domain] += d.Count
		dnssecByDomain[d.Domain] = d.DNSSEC
	}

	domainList := make([]map[string]interface{}, 0, len(counts))
	for name, count := range counts {
		domainList = append(domainList, map[string]interface{}{
			"domain": name,
			"count":  count,
			"dnssec": dnssecByDomain[name],
		})
	}

	sort.Slice(domainList, func(i, j int) bool {
		ci, cj := domainList[i]["count"].(int64), domainList[j]["count"].(int64)
		if ci != cj {
			return ci > cj
		}
		return domainList[i]["domain"].(string) < domainList[j]["domain"].(string)
	})

	if len(domainList) > limit {
		domainList = domainList[:limit]
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(domainList)
}

// handleShield manages shield rules.
//   - GET: rule statistics, or the full rule set with ?all=true.
//   - POST: add the rule in {"rule": ...}.
//   - DELETE: remove the rule in {"rule": ...}.
//   - PUT: reload all rules.
// Every mutation clears the DNS cache so the change applies immediately.
// @Summary 管理Shield配置
// @Description 获取或更新Shield的配置信息
// @Tags shield
// @Accept json
// @Produce json
// @Param config body map[string]interface{} false "Shield配置信息"
// @Success 200 {object} map[string]interface{} "配置信息"
// @Failure 400 {object} map[string]string "请求参数错误"
// @Failure 500 {object} map[string]string "服务器内部错误"
// @Router /api/shield [get]
// @Router /api/shield [post]
func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	switch r.Method {
	case http.MethodGet:
		if r.URL.Query().Get("all") == "true" {
			json.NewEncoder(w).Encode(s.shieldManager.GetRules())
			return
		}

		stats := s.shieldManager.GetStats()
		json.NewEncoder(w).Encode(map[string]interface{}{
			"updateInterval":        s.globalConfig.Shield.UpdateInterval,
			"blockMethod":           s.globalConfig.Shield.BlockMethod,
			"blacklistCount":        len(s.globalConfig.Shield.Blacklists),
			"domainRulesCount":      stats["domainRules"],
			"domainExceptionsCount": stats["domainExceptions"],
			"regexRulesCount":       stats["regexRules"],
			"regexExceptionsCount":  stats["regexExceptions"],
			"hostsRulesCount":       stats["hostsRules"],
		})

	case http.MethodPost:
		var req struct {
			Rule string `json:"rule"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if strings.TrimSpace(req.Rule) == "" {
			http.Error(w, "Rule is required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.AddRule(req.Rule); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		s.dnsServer.DnsCache.Clear()
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	case http.MethodDelete:
		var req struct {
			Rule string `json:"rule"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if strings.TrimSpace(req.Rule) == "" {
			http.Error(w, "Rule is required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.RemoveRule(req.Rule); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		s.dnsServer.DnsCache.Clear()
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	case http.MethodPut:
		if err := s.shieldManager.LoadRules(); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		// Reloaded rules may change what is blocked; drop stale cached answers.
		s.dnsServer.DnsCache.Clear()
		json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "规则重新加载成功"})

	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}

// commitBlacklists makes lists the active blacklist set. It updates the shield
// manager and the global config, saves config.ini and reloads the rules, then
// clears the DNS cache. If saving fails it writes the error JSON and returns
// false.
func (s *Server) commitBlacklists(w http.ResponseWriter, lists []config.BlacklistEntry) bool {
	s.shieldManager.UpdateBlacklist(lists)
	s.globalConfig.Shield.Blacklists = lists

	if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
		logger.Error("保存配置文件失败", "error", err)
		json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "保存配置失败"})
		return false
	}

	if err := s.shieldManager.LoadRules(); err != nil {
		logger.Error("重新加载规则失败", "error", err)
	}
	s.dnsServer.DnsCache.Clear()
	return true
}

// refreshBlacklistTimestamp handles POST /api/shield/blacklists/{urlOrName}/update.
// It stamps the matching blacklist's LastUpdateTime with the current time and
// commits the list.
func (s *Server) refreshBlacklistTimestamp(w http.ResponseWriter, r *http.Request) {
	parts := strings.Split(r.URL.Path, "/")
	target := ""
	for i, part := range parts {
		if part == "blacklists" && i+1 < len(parts) && parts[i+1] != "update" {
			target = parts[i+1]
			break
		}
	}
	if target == "" {
		json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "黑名单标识不能为空"})
		return
	}

	blacklists := s.shieldManager.GetBlacklists()
	index := -1
	for i, list := range blacklists {
		if list.URL == target || list.Name == target {
			index = i
			break
		}
	}
	if index == -1 {
		json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "黑名单不存在"})
		return
	}

	blacklists[index].LastUpdateTime = time.Now().Format(time.RFC3339)
	if !s.commitBlacklists(w, blacklists) {
		return
	}
	json.NewEncoder(w).Encode(map[string]string{"status": "success"})
}

// deleteBlacklistEntry handles DELETE /api/shield/blacklists/{urlOrName}. It
// removes every blacklist whose URL or Name equals id, and reports an error
// when nothing matched.
func (s *Server) deleteBlacklistEntry(w http.ResponseWriter, id string) {
	current := s.shieldManager.GetBlacklists()

	var kept []config.BlacklistEntry
	for _, list := range current {
		if list.URL != id && list.Name != id {
			kept = append(kept, list)
		}
	}
	if len(kept) == len(current) {
		json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "黑名单不存在"})
		return
	}

	if !s.commitBlacklists(w, kept) {
		return
	}
	json.NewEncoder(w).Encode(map[string]string{"status": "success"})
}

// addBlacklistEntry handles POST /api/shield/blacklists with
// {"name", "url"}. It rejects duplicate or unreachable URLs and otherwise
// appends an enabled entry.
func (s *Server) addBlacklistEntry(w http.ResponseWriter, r *http.Request) {
	var req struct {
		Name string `json:"name"`
		URL  string `json:"url"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "无效的请求体"})
		return
	}
	if req.Name == "" || req.URL == "" {
		json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "名称和URL不能为空"})
		return
	}

	blacklists := s.shieldManager.GetBlacklists()
	for _, list := range blacklists {
		if list.URL == req.URL {
			json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "黑名单URL已存在"})
			return
		}
	}

	if !checkURLExists(req.URL) {
		json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "URL不存在或无法访问"})
		return
	}

	blacklists = append(blacklists, config.BlacklistEntry{
		Name:    req.Name,
		URL:     req.URL,
		Enabled: true,
	})
	if !s.commitBlacklists(w, blacklists) {
		return
	}
	json.NewEncoder(w).Encode(map[string]string{"status": "success"})
}

// replaceBlacklists handles PUT /api/shield/blacklists: the body is the full
// new list (including enabled flags), which replaces the current one.
func (s *Server) replaceBlacklists(w http.ResponseWriter, r *http.Request) {
	// encoding/json matches field names case-insensitively, so these tags
	// accept both "Name" and "name" (and so on).
	var updatedBlacklists []struct {
		Name           string `json:"name"`
		URL            string `json:"url"`
		Enabled        bool   `json:"enabled"`
		LastUpdateTime string `json:"lastUpdateTime"`
	}
	if err := json.NewDecoder(r.Body).Decode(&updatedBlacklists); err != nil {
		json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "无效的请求体"})
		return
	}

	newBlacklists := make([]config.BlacklistEntry, 0, len(updatedBlacklists))
	for _, entry := range updatedBlacklists {
		newBlacklists = append(newBlacklists, config.BlacklistEntry{
			Name:           entry.Name,
			URL:            entry.URL,
			Enabled:        entry.Enabled,
			LastUpdateTime: entry.LastUpdateTime,
		})
	}

	if !s.commitBlacklists(w, newBlacklists) {
		return
	}
	json.NewEncoder(w).Encode(map[string]string{"status": "success"})
}

// handleShieldBlacklists dispatches blacklist CRUD requests:
//   - GET: list all blacklists.
//   - POST: add one.
//   - PUT: replace the whole list.
//   - DELETE .../blacklists/{urlOrName}: remove one.
//   - POST .../blacklists/{urlOrName}/update: refresh one.
// Errors are reported as {"status": "error", "message": ...}.
// @Summary 管理黑名单
// @Description 处理黑名单的CRUD操作,包括获取列表、添加、更新和删除黑名单
// @Tags shield
// @Accept json
// @Produce json
// @Param name path string false "黑名单名称(用于删除操作)"
// @Param blacklist body map[string]interface{} false "黑名单信息(用于添加/更新操作)"
// @Success 200 {object} map[string]interface{} "操作成功"
// @Failure 400 {object} map[string]string "请求参数错误"
// @Failure 404 {object} map[string]string "黑名单不存在"
// @Failure 500 {object} map[string]string "服务器内部错误"
// @Router /api/shield/blacklists [get]
// @Router /api/shield/blacklists [post]
// @Router /api/shield/blacklists [put]
// @Router /api/shield/blacklists/{name} [delete]
func (s *Server) handleShieldBlacklists(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	if r.Method == http.MethodPost && strings.HasSuffix(strings.TrimSuffix(r.URL.Path, "/"), "/update") {
		s.refreshBlacklistTimestamp(w, r)
		return
	}

	parts := strings.Split(r.URL.Path, "/")
	if r.Method == http.MethodDelete && len(parts) > 4 && parts[3] == "blacklists" && parts[4] != "" {
		s.deleteBlacklistEntry(w, parts[4])
		return
	}

	switch r.Method {
	case http.MethodGet:
		json.NewEncoder(w).Encode(s.shieldManager.GetBlacklists())
	case http.MethodPost:
		s.addBlacklistEntry(w, r)
	case http.MethodPut:
		s.replaceBlacklists(w, r)
	default:
		w.WriteHeader(http.StatusMethodNotAllowed)
		json.NewEncoder(w).Encode(map[string]string{"status": "error", "message": "Method not allowed"})
	}
}

// handleShieldHosts manages hosts entries.
//   - GET: {"hosts": [{"domain", "ip"}] sorted by domain, "hostsCount"}.
//   - POST: add {"ip", "domain"}.
//   - DELETE: remove {"domain"}.
// Mutations clear the DNS cache so they take effect immediately.
func (s *Server) handleShieldHosts(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	switch r.Method {
	case http.MethodPost:
		var req struct {
			IP     string `json:"ip"`
			Domain string `json:"domain"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if req.IP == "" || req.Domain == "" {
			http.Error(w, "IP and Domain are required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.AddHostsEntry(req.IP, req.Domain); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		s.dnsServer.DnsCache.Clear()
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	case http.MethodDelete:
		var req struct {
			Domain string `json:"domain"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.RemoveHostsEntry(req.Domain); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		s.dnsServer.DnsCache.Clear()
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	case http.MethodGet:
		hosts := s.shieldManager.GetAllHosts()
		hostsCount := s.shieldManager.GetHostsCount()

		hostsList := make([]map[string]string, 0, len(hosts))
		for name, ip := range hosts {
			hostsList = append(hostsList, map[string]string{
				"domain": name,
				"ip":     ip,
			})
		}
		// Map iteration order is random; sort so the UI list is stable.
		sort.Slice(hostsList, func(i, j int) bool {
			return hostsList[i]["domain"] < hostsList[j]["domain"]
		})

		json.NewEncoder(w).Encode(map[string]interface{}{
			"hosts":      hostsList,
			"hostsCount": hostsCount,
		})

	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}

// handleQuery serves the legacy GET /api/query?domain=...: the shield's
// block details for the domain plus a "timestamp" field.
func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	name := r.URL.Query().Get("domain")
	if name == "" {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(map[string]string{"error": "需要提供domain参数"})
		return
	}

	blockDetails := s.shieldManager.CheckDomainBlockDetails(name)
	blockDetails["timestamp"] = time.Now()

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(blockDetails)
}

// handleDomainQuery 处理RESTful风格的域名查询请求
func (s *Server) handleDomainQuery(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// 从URL路径中提取域名参数
	// 路径格式: /api/domains/{domain}
	path := r.URL.Path
	parts := strings.Split(path, "/")
	if len(parts) < 4 {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(map[string]string{"error": "需要提供domain参数"})
		return
	}

	domain := parts[3]
	if domain == "" {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(map[string]string{"error": "需要提供domain参数"})
		return
	}

	// 获取域名屏蔽的详细信息
	blockDetails := s.shieldManager.CheckDomainBlockDetails(domain)

	// 构建RESTful风格的响应
	response := map[string]interface{}{
		"domain":    domain,
		"status":    blockDetails["blocked"],
		"timestamp": time.Now(),
		"details":   blockDetails,
	}
w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(response) } // handleStatus 处理系统状态请求 func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } stats := s.dnsServer.GetStats() // 使用服务器的实际启动时间计算准确的运行时间 serverStartTime := s.dnsServer.GetStartTime() uptime := time.Since(serverStartTime) // 构建包含所有真实服务器统计数据的响应 status := map[string]interface{}{ "status": "running", "queries": stats.Queries, "blocked": stats.Blocked, "allowed": stats.Allowed, "errors": stats.Errors, "lastQuery": stats.LastQuery, "avgResponseTime": stats.AvgResponseTime, "activeIPs": len(stats.SourceIPs), "startTime": serverStartTime, "uptime": uptime.Milliseconds(), // 转换为毫秒数,方便前端处理 "cpuUsage": stats.CpuUsage, "timestamp": time.Now(), } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(status) } // saveConfigToFile 保存配置到文件 func saveConfigToFile(config *config.Config, filePath string) error { // 创建新的INI文件 cfg := ini.Empty() // DNS配置 dnsSection := cfg.Section("dns") dnsSection.Key("port").SetValue(fmt.Sprintf("%d", config.DNS.Port)) dnsSection.Key("upstreamDNS").SetValue(strings.Join(config.DNS.UpstreamDNS, ", ")) dnsSection.Key("dnssecUpstreamDNS").SetValue(strings.Join(config.DNS.DNSSECUpstreamDNS, ", ")) dnsSection.Key("saveInterval").SetValue(fmt.Sprintf("%d", config.DNS.SaveInterval)) dnsSection.Key("cacheTTL").SetValue(fmt.Sprintf("%d", config.DNS.CacheTTL)) dnsSection.Key("enableDNSSEC").SetValue(fmt.Sprintf("%t", config.DNS.EnableDNSSEC)) dnsSection.Key("queryMode").SetValue(config.DNS.QueryMode) dnsSection.Key("queryTimeout").SetValue(fmt.Sprintf("%d", config.DNS.QueryTimeout)) dnsSection.Key("enableFastReturn").SetValue(fmt.Sprintf("%t", config.DNS.EnableFastReturn)) dnsSection.Key("noDNSSECDomains").SetValue(strings.Join(config.DNS.NoDNSSECDomains, ", ")) dnsSection.Key("enableIPv6").SetValue(fmt.Sprintf("%t", 
config.DNS.EnableIPv6)) dnsSection.Key("cacheMode").SetValue(config.DNS.CacheMode) dnsSection.Key("cacheSize").SetValue(fmt.Sprintf("%d", config.DNS.CacheSize)) dnsSection.Key("maxCacheTTL").SetValue(fmt.Sprintf("%d", config.DNS.MaxCacheTTL)) dnsSection.Key("minCacheTTL").SetValue(fmt.Sprintf("%d", config.DNS.MinCacheTTL)) // 域名特定DNS服务器配置 for domain, servers := range config.DNS.DomainSpecificDNS { dnsSection.Key(fmt.Sprintf("domain_%s", domain)).SetValue(strings.Join(servers, ", ")) } // HTTP配置 httpSection := cfg.Section("http") httpSection.Key("port").SetValue(fmt.Sprintf("%d", config.HTTP.Port)) httpSection.Key("host").SetValue(config.HTTP.Host) httpSection.Key("enableAPI").SetValue(fmt.Sprintf("%t", config.HTTP.EnableAPI)) httpSection.Key("username").SetValue(config.HTTP.Username) httpSection.Key("password").SetValue(config.HTTP.Password) // Shield配置 shieldSection := cfg.Section("shield") shieldSection.Key("updateInterval").SetValue(fmt.Sprintf("%d", config.Shield.UpdateInterval)) shieldSection.Key("blockMethod").SetValue(config.Shield.BlockMethod) shieldSection.Key("customBlockIP").SetValue(config.Shield.CustomBlockIP) shieldSection.Key("statsSaveInterval").SetValue(fmt.Sprintf("%d", config.Shield.StatsSaveInterval)) // 黑名单配置 for _, bl := range config.Shield.Blacklists { shieldSection.Key(fmt.Sprintf("blacklist_%s", bl.Name)).SetValue(fmt.Sprintf("%s,%t", bl.URL, bl.Enabled)) } // GFWList配置 gfwListSection := cfg.Section("gfwList") gfwListSection.Key("ip").SetValue(config.GFWList.IP) gfwListSection.Key("content").SetValue(config.GFWList.Content) gfwListSection.Key("enabled").SetValue(fmt.Sprintf("%t", config.GFWList.Enabled)) // Log配置 logSection := cfg.Section("log") logSection.Key("level").SetValue(config.Log.Level) logSection.Key("maxSize").SetValue(fmt.Sprintf("%d", config.Log.MaxSize)) logSection.Key("maxBackups").SetValue(fmt.Sprintf("%d", config.Log.MaxBackups)) logSection.Key("maxAge").SetValue(fmt.Sprintf("%d", config.Log.MaxAge)) // Threat配置 
threatSection := cfg.Section("threat") threatSection.Key("enabled").SetValue(fmt.Sprintf("%t", config.Threat.Enabled)) threatSection.Key("queryRateThreshold").SetValue(fmt.Sprintf("%d", config.Threat.QueryRateThreshold)) threatSection.Key("nxDomainThreshold").SetValue(fmt.Sprintf("%d", config.Threat.NXDomainThreshold)) threatSection.Key("maxDomainLength").SetValue(fmt.Sprintf("%d", config.Threat.MaxDomainLength)) threatSection.Key("suspiciousPatterns").SetValue(strings.Join(config.Threat.SuspiciousPatterns, ",")) threatSection.Key("unusualQueryTypes").SetValue(strings.Join(config.Threat.UnusualQueryTypes, ",")) threatSection.Key("alertRetentionDays").SetValue(fmt.Sprintf("%d", config.Threat.AlertRetentionDays)) threatSection.Key("threatDatabasePath").SetValue(config.Threat.ThreatDatabasePath) // 保存到文件 return cfg.SaveTo(filePath) } // handleConfig 处理配置请求 func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") switch r.Method { case http.MethodGet: // 每次从配置文件重新读取最新配置 cfg, err := config.LoadConfig("config.ini") if err != nil { logger.Error("加载配置文件失败", "error", err) // 如果加载失败,返回内存中的配置 cfg = s.globalConfig } // 返回当前配置(包括黑名单配置) // 注意:key 名必须与前端期望的一致 config := map[string]interface{}{ "Shield": map[string]interface{}{ "blockMethod": cfg.Shield.BlockMethod, "customBlockIP": cfg.Shield.CustomBlockIP, "blacklists": cfg.Shield.Blacklists, "updateInterval": cfg.Shield.UpdateInterval, "statsSaveInterval": cfg.Shield.StatsSaveInterval, }, "GFWList": map[string]interface{}{ "ip": cfg.GFWList.IP, "content": cfg.GFWList.Content, "enabled": cfg.GFWList.Enabled, }, "DNSServer": map[string]interface{}{ "port": cfg.DNS.Port, "QueryMode": cfg.DNS.QueryMode, "UpstreamServers": cfg.DNS.UpstreamDNS, "DNSSECUpstreamServers": cfg.DNS.DNSSECUpstreamDNS, "saveInterval": cfg.DNS.SaveInterval, "queryTimeout": cfg.DNS.QueryTimeout, "enableIPv6": cfg.DNS.EnableIPv6, "enableDNSSEC": cfg.DNS.EnableDNSSEC, "enableFastReturn": 
cfg.DNS.EnableFastReturn, "noDNSSECDomains": cfg.DNS.NoDNSSECDomains, "CacheMode": cfg.DNS.CacheMode, "CacheSize": cfg.DNS.CacheSize, "MaxCacheTTL": cfg.DNS.MaxCacheTTL, "MinCacheTTL": cfg.DNS.MinCacheTTL, "domainSpecificDNS": cfg.DNS.DomainSpecificDNS, }, "HTTPServer": map[string]interface{}{ "port": cfg.HTTP.Port, "host": cfg.HTTP.Host, "enableAPI": cfg.HTTP.EnableAPI, "username": cfg.HTTP.Username, "password": cfg.HTTP.Password, }, } json.NewEncoder(w).Encode(config) case http.MethodPost: // 更新配置 var req struct { DNSServer struct { Port int `json:"port"` QueryMode string `json:"queryMode"` UpstreamServers []string `json:"upstreamServers"` DnssecUpstreamServers []string `json:"dnssecUpstreamServers"` Timeout int `json:"timeout"` SaveInterval int `json:"saveInterval"` EnableIPv6 bool `json:"enableIPv6"` EnableDNSSEC bool `json:"enableDNSSEC"` EnableFastReturn *bool `json:"enableFastReturn"` NoDNSSECDomains []string `json:"noDNSSECDomains"` CacheMode string `json:"cacheMode"` CacheSize int `json:"cacheSize"` MaxCacheTTL int `json:"maxCacheTTL"` MinCacheTTL int `json:"minCacheTTL"` DomainSpecificDNS map[string][]string `json:"domainSpecificDNS"` } `json:"dnsserver"` HTTPServer struct { Port int `json:"port"` Host string `json:"host"` EnableAPI bool `json:"enableAPI"` Username string `json:"username"` Password string `json:"password"` } `json:"httpserver"` Shield struct { BlockMethod string `json:"blockMethod"` CustomBlockIP string `json:"customBlockIP"` Blacklists []config.BlacklistEntry `json:"blacklists"` UpdateInterval int `json:"updateInterval"` StatsSaveInterval int `json:"statsSaveInterval"` } `json:"shield"` GFWList struct { IP string `json:"ip"` Content string `json:"content"` Enabled bool `json:"enabled"` } `json:"gfwList"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "无效的请求体", http.StatusBadRequest) return } // 更新DNS配置 if req.DNSServer.Port > 0 { s.globalConfig.DNS.Port = req.DNSServer.Port } if 
len(req.DNSServer.UpstreamServers) > 0 { s.globalConfig.DNS.UpstreamDNS = req.DNSServer.UpstreamServers } if len(req.DNSServer.DnssecUpstreamServers) > 0 { s.globalConfig.DNS.DNSSECUpstreamDNS = req.DNSServer.DnssecUpstreamServers } if req.DNSServer.SaveInterval > 0 { s.globalConfig.DNS.SaveInterval = req.DNSServer.SaveInterval } if req.DNSServer.Timeout > 0 { s.globalConfig.DNS.QueryTimeout = req.DNSServer.Timeout } s.globalConfig.DNS.EnableIPv6 = req.DNSServer.EnableIPv6 s.globalConfig.DNS.EnableDNSSEC = req.DNSServer.EnableDNSSEC // 更新查询模式 if req.DNSServer.QueryMode != "" { s.globalConfig.DNS.QueryMode = req.DNSServer.QueryMode } // 更新缓存配置 if req.DNSServer.CacheMode != "" { s.globalConfig.DNS.CacheMode = req.DNSServer.CacheMode } if req.DNSServer.CacheSize > 0 { s.globalConfig.DNS.CacheSize = req.DNSServer.CacheSize } if req.DNSServer.MaxCacheTTL > 0 { s.globalConfig.DNS.MaxCacheTTL = req.DNSServer.MaxCacheTTL } if req.DNSServer.MinCacheTTL > 0 { s.globalConfig.DNS.MinCacheTTL = req.DNSServer.MinCacheTTL } // 更新enableFastReturn if req.DNSServer.EnableFastReturn != nil { s.globalConfig.DNS.EnableFastReturn = *req.DNSServer.EnableFastReturn } // 更新noDNSSECDomains if len(req.DNSServer.NoDNSSECDomains) > 0 { s.globalConfig.DNS.NoDNSSECDomains = req.DNSServer.NoDNSSECDomains } // 更新domainSpecificDNS if req.DNSServer.DomainSpecificDNS != nil { s.globalConfig.DNS.DomainSpecificDNS = req.DNSServer.DomainSpecificDNS } // 更新HTTP配置 if req.HTTPServer.Port > 0 { s.globalConfig.HTTP.Port = req.HTTPServer.Port } if req.HTTPServer.Host != "" { s.globalConfig.HTTP.Host = req.HTTPServer.Host } s.globalConfig.HTTP.EnableAPI = req.HTTPServer.EnableAPI if req.HTTPServer.Username != "" { s.globalConfig.HTTP.Username = req.HTTPServer.Username } if req.HTTPServer.Password != "" { s.globalConfig.HTTP.Password = req.HTTPServer.Password } // 更新屏蔽配置 if req.Shield.BlockMethod != "" { // 验证屏蔽方法是否有效 validMethods := map[string]bool{ "NXDOMAIN": true, "refused": true, "emptyIP": true, 
"customIP": true, } if !validMethods[req.Shield.BlockMethod] { http.Error(w, "无效的屏蔽方法", http.StatusBadRequest) return } s.globalConfig.Shield.BlockMethod = req.Shield.BlockMethod // 如果选择了customIP,验证IP地址 if req.Shield.BlockMethod == "customIP" { if req.Shield.CustomBlockIP == "" { http.Error(w, "自定义IP不能为空", http.StatusBadRequest) return } // 简单的IP地址验证 if !isValidIP(req.Shield.CustomBlockIP) { http.Error(w, "无效的IP地址", http.StatusBadRequest) return } } } if req.Shield.CustomBlockIP != "" { s.globalConfig.Shield.CustomBlockIP = req.Shield.CustomBlockIP } // 更新更新间隔 if req.Shield.UpdateInterval > 0 { s.globalConfig.Shield.UpdateInterval = req.Shield.UpdateInterval // 重新启动自动更新 s.shieldManager.StopAutoUpdate() s.shieldManager.StartAutoUpdate() } // 更新统计保存间隔 if req.Shield.StatsSaveInterval > 0 { s.globalConfig.Shield.StatsSaveInterval = req.Shield.StatsSaveInterval } // 更新GFWList配置 s.globalConfig.GFWList.IP = req.GFWList.IP s.globalConfig.GFWList.Content = req.GFWList.Content s.globalConfig.GFWList.Enabled = req.GFWList.Enabled // 重新加载GFWList规则 if s.gfwManager != nil { if err := s.gfwManager.LoadRules(); err != nil { logger.Error("重新加载GFWList规则失败", "error", err) } } // 更新黑名单配置 if req.Shield.Blacklists != nil { // 验证黑名单配置 for i, bl := range req.Shield.Blacklists { if bl.URL == "" { http.Error(w, fmt.Sprintf("黑名单URL不能为空,索引: %d", i), http.StatusBadRequest) return } if !strings.HasPrefix(bl.URL, "http://") && !strings.HasPrefix(bl.URL, "https://") { http.Error(w, fmt.Sprintf("黑名单URL必须以http://或https://开头,索引: %d", i), http.StatusBadRequest) return } } s.globalConfig.Shield.Blacklists = req.Shield.Blacklists s.shieldManager.UpdateBlacklist(req.Shield.Blacklists) // 重新加载规则 if err := s.shieldManager.LoadRules(); err != nil { logger.Error("重新加载规则失败", "error", err) } } // 更新现有的DNSCache实例配置 // 最大和最小TTL(秒) maxCacheTTL := time.Duration(s.globalConfig.DNS.MaxCacheTTL) * time.Second minCacheTTL := time.Duration(s.globalConfig.DNS.MinCacheTTL) * time.Second // 最大缓存大小(字节) maxCacheSize := 
int64(s.globalConfig.DNS.CacheSize) * 1024 * 1024 // 更新缓存配置 s.dnsServer.DnsCache.SetMaxCacheTTL(maxCacheTTL) s.dnsServer.DnsCache.SetMinCacheTTL(minCacheTTL) s.dnsServer.DnsCache.SetCacheMode(s.globalConfig.DNS.CacheMode) s.dnsServer.DnsCache.SetMaxCacheSize(maxCacheSize) // 保存配置到文件 if err := saveConfigToFile(s.globalConfig, "./config.ini"); err != nil { logger.Error("保存配置到文件失败", "error", err) // 不返回错误,只记录日志,因为配置已经在内存中更新成功 } // 返回成功响应 json.NewEncoder(w).Encode(map[string]interface{}{ "success": true, "message": "配置已更新", }) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } } // isValidIP 简单验证IP地址格式 func isValidIP(ip string) bool { // 简单的IPv4地址验证 parts := strings.Split(ip, ".") if len(parts) != 4 { return false } for _, part := range parts { // 检查是否为数字 for _, char := range part { if char < '0' || char > '9' { return false } } // 检查数字范围 var num int if _, err := fmt.Sscanf(part, "%d", &num); err != nil || num < 0 || num > 255 { return false } } return true } // checkURLExists 检查URL是否存在且可访问 func checkURLExists(url string) bool { // 创建一个带有超时的HTTP客户端 client := &http.Client{ Timeout: 5 * time.Second, } // 发送HEAD请求来检查URL是否存在 resp, err := client.Head(url) if err != nil { return false } defer resp.Body.Close() // 检查状态码,2xx和3xx表示成功 return resp.StatusCode >= 200 && resp.StatusCode < 400 } // handleLogsStats 处理日志统计请求 func (s *Server) handleLogsStats(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 构建缓存键 cacheKey := "logs_stats" // 如果启用缓存,先尝试从缓存获取 if s.cacheEnabled { if cachedStats, found := s.statsCache.Get(cacheKey); found { // 缓存命中,直接返回 w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(cachedStats) return } } // 缓存未命中,获取最新统计数据 logStats := s.dnsServer.GetQueryStats() // 存入缓存 if s.cacheEnabled { s.statsCache.Set(cacheKey, logStats) } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(logStats) } // 
handleLogsQuery 处理日志查询请求 func (s *Server) handleLogsQuery(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 获取查询参数 limit := 100 // 默认返回 100 条日志 offset := 0 sortField := r.URL.Query().Get("sort") sortDirection := r.URL.Query().Get("direction") resultFilter := r.URL.Query().Get("result") searchTerm := r.URL.Query().Get("search") queryType := r.URL.Query().Get("queryType") if limitStr := r.URL.Query().Get("limit"); limitStr != "" { fmt.Sscanf(limitStr, "%d", &limit) } if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" { fmt.Sscanf(offsetStr, "%d", &offset) } // 构建缓存键,包含所有查询参数 // 已禁用缓存,每次都从数据库获取最新数据 // cacheKey := fmt.Sprintf("logs_%d_%d_%s_%s_%s_%s_%s", limit, offset, sortField, sortDirection, resultFilter, searchTerm, queryType) // 缓存未命中,获取日志数据(已禁用缓存) logs := s.dnsServer.GetQueryLogs(limit, offset, sortField, sortDirection, resultFilter, searchTerm, queryType) // 存入缓存(已禁用) // if s.cacheEnabled { // s.queryCache.Set(cacheKey, logs) // } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(logs) } // handleLogsCount 处理日志总数请求 func (s *Server) handleLogsCount(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 获取过滤参数 resultFilter := r.URL.Query().Get("result") searchTerm := r.URL.Query().Get("search") queryType := r.URL.Query().Get("queryType") // 构建缓存键(已禁用) // cacheKey := fmt.Sprintf("logs_count_%s_%s_%s", resultFilter, searchTerm, queryType) // 缓存未命中,获取带过滤条件的日志总数(已禁用缓存) count := s.dnsServer.GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType) // 存入缓存(已禁用) // if s.cacheEnabled { // s.queryCache.Set(cacheKey, count) // } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]int{"count": count}) } // handleDomainInfo 处理域名信息查询请求 func (s *Server) handleDomainInfo(w http.ResponseWriter, r 
*http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 解析请求体 var req struct { Domain string `json:"domain"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return } if req.Domain == "" { http.Error(w, "Domain parameter is required", http.StatusBadRequest) return } // 从域名信息数据库中查询 domainInfo, err := domain.GetDomainInfo(req.Domain) if err != nil { http.Error(w, "Failed to query domain info", http.StatusInternalServerError) return } // 返回域名信息 w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(domainInfo) } // handleRestart 处理重启服务请求 func (s *Server) handleRestart(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } logger.Info("收到重启服务请求") // 停止DNS服务器 s.dnsServer.Stop() // 重新加载屏蔽规则 if err := s.shieldManager.LoadRules(); err != nil { logger.Error("重新加载屏蔽规则失败", "error", err) } // 重新启动DNS服务器 go func() { if err := s.dnsServer.Start(); err != nil { logger.Error("DNS服务器重启失败", "error", err) } }() // 重新启动定时更新任务 s.shieldManager.StopAutoUpdate() s.shieldManager.StartAutoUpdate() // 返回成功响应 w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "服务已重启"}) logger.Info("服务重启成功") } // handleLogin 处理登录请求 func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 解析请求体 var loginData struct { Username string `json:"username"` Password string `json:"password"` } if err := json.NewDecoder(r.Body).Decode(&loginData); err != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) json.NewEncoder(w).Encode(map[string]string{"error": "无效的请求体"}) return } // 验证用户名和密码 if loginData.Username != 
s.config.Username || loginData.Password != s.config.Password { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) json.NewEncoder(w).Encode(map[string]string{"error": "用户名或密码错误"}) return } // 生成会话ID sessionID := fmt.Sprintf("%d_%d", time.Now().UnixNano(), len(s.sessions)) // 保存会话 s.sessionsMutex.Lock() s.sessions[sessionID] = time.Now().Add(s.sessionTTL) s.sessionsMutex.Unlock() // 设置Cookie cookie := &http.Cookie{ Name: "session_id", Value: sessionID, Path: "/", Expires: time.Now().Add(s.sessionTTL), HttpOnly: true, Secure: false, // 开发环境下使用false,生产环境应使用true } http.SetCookie(w, cookie) // 返回成功响应 w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "登录成功"}) logger.Info(fmt.Sprintf("用户 %s 登录成功", loginData.Username)) } // handleLogout 处理注销请求 func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 从Cookie中获取会话ID cookie, err := r.Cookie("session_id") if err == nil { // 删除会话 s.sessionsMutex.Lock() delete(s.sessions, cookie.Value) s.sessionsMutex.Unlock() } // 清除Cookie clearCookie := &http.Cookie{ Name: "session_id", Value: "", Path: "/", Expires: time.Unix(0, 0), HttpOnly: true, Secure: false, } http.SetCookie(w, clearCookie) // 返回成功响应 w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "注销成功"}) logger.Info("用户注销成功") } // handleDomainInfoList 处理域名信息列表请求 func (s *Server) handleDomainInfoList(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 获取查询参数 query := r.URL.Query() if query.Has("domains") { // 处理域名信息,支持过滤特定域名 domainFilter := query.Get("domains") handleDomainsInfo(w, domainFilter) } else if query.Has("trackers") { // 处理跟踪器信息,支持过滤特定域名 trackerFilter := 
query.Get("trackers") handleTrackersInfo(w, trackerFilter) } else if query.Has("threats") { // 处理威胁域名信息,支持过滤特定域名 threatFilter := query.Get("threats") handleThreatsInfo(w, threatFilter) } else { // 直接访问 /domain-info 不提供任何内容 http.Error(w, "No content provided", http.StatusNoContent) return } } // isService 判断一个对象是否是服务(而不是分组) func isService(obj map[string]interface{}) bool { // 服务通常包含 name、url、categoryId 字段 _, hasName := obj["name"] _, hasUrl := obj["url"] _, hasCategoryId := obj["categoryId"] // 如果有 name 和 url,则认为是服务 if hasName && hasUrl { return true } // 如果有 categoryId,也认为是服务 if hasCategoryId { return true } return false } // handleAlert 处理威胁告警请求 func (s *Server) handleAlert(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") switch r.Method { case http.MethodGet: // 获取告警列表 limit := 100 offset := 0 level := r.URL.Query().Get("level") if limitStr := r.URL.Query().Get("limit"); limitStr != "" { fmt.Sscanf(limitStr, "%d", &limit) } if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" { fmt.Sscanf(offsetStr, "%d", &offset) } // 获取告警列表 alerts := s.dnsServer.GetAlerts(limit, offset, level) // 构建响应 response := map[string]interface{}{ "alerts": alerts, "total": s.dnsServer.GetAlertCount(level), } json.NewEncoder(w).Encode(response) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } } // handleAlertResolve 处理威胁告警解决请求 func (s *Server) handleAlertResolve(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 解析请求体 var req struct { AlertID string `json:"alertId"` Action string `json:"action"` // blocked, allowed } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return } if req.AlertID == "" || req.Action == "" { http.Error(w, "AlertID and Action are required", http.StatusBadRequest) 
return } // 验证动作 if req.Action != threat.ActionBlocked && req.Action != threat.ActionAllowed { http.Error(w, "Invalid action", http.StatusBadRequest) return } // 解决告警 success := s.dnsServer.ResolveAlert(req.AlertID, req.Action) if success { json.NewEncoder(w).Encode(map[string]string{"status": "success"}) } else { http.Error(w, "Failed to resolve alert", http.StatusInternalServerError) } } // handleThreatDomain 处理威胁域名管理请求 func (s *Server) handleThreatDomain(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") switch r.Method { case http.MethodGet: // 获取所有威胁域名 threats := s.dnsServer.GetThreatDomains() json.NewEncoder(w).Encode(threats) case http.MethodPost: // 添加威胁域名 var req struct { Type string `json:"type"` Name string `json:"name"` RiskLevel int `json:"riskLevel"` Domain string `json:"domain"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return } if req.Domain == "" { http.Error(w, "Domain is required", http.StatusBadRequest) return } // 设置默认值 if req.Type == "" { req.Type = "未知" } if req.Name == "" { req.Name = "未知" } if req.RiskLevel == 0 { req.RiskLevel = 1 } err := s.dnsServer.AddThreatDomain(req.Type, req.Name, req.RiskLevel, req.Domain) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } json.NewEncoder(w).Encode(map[string]string{"status": "success"}) case http.MethodDelete: // 删除威胁域名 var req struct { Domain string `json:"domain"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return } if req.Domain == "" { http.Error(w, "Domain is required", http.StatusBadRequest) return } err := s.dnsServer.RemoveThreatDomain(req.Domain) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } json.NewEncoder(w).Encode(map[string]string{"status": "success"}) default: http.Error(w, "Method not allowed", 
http.StatusMethodNotAllowed) } } // processServiceItem 递归处理服务或分组 func processServiceItem( serviceName string, service interface{}, companyLevelCompany string, domainFilter string, categories map[string]string, result *[]map[string]interface{}, ) { serviceMap, ok := service.(map[string]interface{}) if !ok { return } // 跳过 company 字段 if serviceName == "company" { return } // 判断是服务还是分组 if isService(serviceMap) { // 这是一个服务,进行处理 urlValue := serviceMap["url"] match := false // 检查是否需要过滤 if domainFilter != "" { // 检查服务名称是否包含过滤条件 if serviceName == domainFilter { match = true } else { // 检查 URL 是否包含过滤条件 switch v := urlValue.(type) { case string: if strings.Contains(v, domainFilter) { match = true } case map[string]interface{}: for _, url := range v { if urlStr, ok := url.(string); ok && strings.Contains(urlStr, domainFilter) { match = true break } } } } if !match { return } } // 确定公司名:优先使用服务级别的 company 字段,否则使用公司级别的 company 字段 itemCompany := companyLevelCompany if serviceCompany, ok := serviceMap["company"].(string); ok { itemCompany = serviceCompany } // 构建响应对象 item := map[string]interface{}{ "icon": serviceMap["icon"], "name": serviceMap["name"], "company": itemCompany, } // 添加类别 if categoryId, ok := serviceMap["categoryId"].(float64); ok { categoryIdStr := fmt.Sprintf("%.0f", categoryId) if category, exists := categories[categoryIdStr]; exists { item["category"] = category } } *result = append(*result, item) } else { // 这是一个分组,递归处理其下的子项 for subName, subService := range serviceMap { processServiceItem(subName, subService, companyLevelCompany, domainFilter, categories, result) } } } func handleDomainsInfo(w http.ResponseWriter, domainFilter string) { // 如果过滤参数为空字符串,返回空数组 if domainFilter == "" { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode([]map[string]interface{}{}) return } filePath := "./static/domain-info/domains/domain-info.json" data, err := os.ReadFile(filePath) if err != nil { http.Error(w, "Failed to read domain info file", 
http.StatusInternalServerError) logger.Error(fmt.Sprintf("读取域名信息文件失败: %v", err)) return } // 解析JSON var domainInfo struct { Categories map[string]string `json:"categories"` Domains map[string]map[string]interface{} `json:"domains"` } if err := json.Unmarshal(data, &domainInfo); err != nil { http.Error(w, "Failed to parse domain info file", http.StatusInternalServerError) logger.Error(fmt.Sprintf("解析域名信息文件失败: %v", err)) return } // 转换为所需格式 var result []map[string]interface{} for _, services := range domainInfo.Domains { // 获取公司级别的 company 字段 companyLevelCompany := "" if companyData, ok := services["company"].(string); ok { companyLevelCompany = companyData } // 遍历所有服务(包括嵌套的分组) for serviceName, service := range services { processServiceItem(serviceName, service, companyLevelCompany, domainFilter, domainInfo.Categories, &result) } } // 返回 JSON 响应 w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(result) } // handleTrackersInfo 处理跟踪器信息请求,返回名称、类别、url、所属单位/公司 func handleTrackersInfo(w http.ResponseWriter, trackerFilter string) { // 如果过滤参数为空字符串,返回空数组 if trackerFilter == "" { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode([]map[string]interface{}{}) return } filePath := "./static/domain-info/tracker/trackers.json" data, err := os.ReadFile(filePath) if err != nil { http.Error(w, "Failed to read trackers file", http.StatusInternalServerError) logger.Error(fmt.Sprintf("读取跟踪器文件失败: %v", err)) return } // 解析JSON var trackersInfo struct { Categories map[string]string `json:"categories"` Trackers map[string]map[string]interface{} `json:"trackers"` } if err := json.Unmarshal(data, &trackersInfo); err != nil { http.Error(w, "Failed to parse trackers file", http.StatusInternalServerError) logger.Error(fmt.Sprintf("解析跟踪器文件失败: %v", err)) return } // 转换为所需格式 var result []map[string]interface{} for trackerDomain, tracker := range trackersInfo.Trackers { // 检查是否需要过滤 if trackerFilter != "" { // 检查跟踪器域名是否包含过滤条件 if 
!strings.Contains(trackerDomain, trackerFilter) {
				// Check whether the name contains the filter.
				if name, ok := tracker["name"].(string); !ok || !strings.Contains(name, trackerFilter) {
					// Check whether the URL contains the filter.
					if url, ok := tracker["url"].(string); !ok || !strings.Contains(url, trackerFilter) {
						continue
					}
				}
			}
		}
		item := map[string]interface{}{
			"name":    tracker["name"],
			"url":     tracker["url"],
			"company": tracker["companyId"],
		}
		// Attach the category name when the category ID is known.
		if categoryId, ok := tracker["categoryId"].(float64); ok {
			categoryIdStr := fmt.Sprintf("%.0f", categoryId)
			if category, exists := trackersInfo.Categories[categoryIdStr]; exists {
				item["category"] = category
			}
		}
		result = append(result, item)
	}
	// Write the JSON response.
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(result)
}

// threatDatabaseCSVPath is the CSV file read by the threat endpoints. Columns
// are: type, name, risk level, domain; the first row is a header.
const threatDatabaseCSVPath = "./static/domain-info/threats/threats-database.csv"

// threatEntry holds the type, name and risk-level columns of one threat row.
type threatEntry struct {
	kind  string
	name  string
	level string
}

// format renders the entry as "type,name,level,host", the wire format used by
// the threat query endpoints.
func (e threatEntry) format(host string) string {
	return e.kind + "," + e.name + "," + e.level + "," + host
}

// threatLookupIndex answers "is this domain a known threat?" for one snapshot
// of the threat database.
type threatLookupIndex struct {
	// exact holds every row, keyed by its normalised domain. Later rows with
	// the same domain override earlier ones.
	exact map[string]threatEntry
	// subdomain holds only phishing/counterfeit rows, keyed by their
	// normalised domain. Any subdomain of a key matches it; the first row for
	// a domain wins.
	subdomain map[string]threatEntry
}

// normalizeThreatDomain trims whitespace and a trailing root dot and
// lower-cases d, because DNS names compare case-insensitively.
func normalizeThreatDomain(d string) string {
	return strings.ToLower(strings.TrimSuffix(strings.TrimSpace(d), "."))
}

// parseThreatRecords parses threat-database CSV content and returns its data
// rows with the header row dropped. Rows may have differing field counts.
func parseThreatRecords(data []byte) ([][]string, error) {
	reader := csv.NewReader(bytes.NewReader(data))
	reader.FieldsPerRecord = -1 // allow rows of differing length
	records, err := reader.ReadAll()
	if err != nil {
		return nil, err
	}
	if len(records) == 0 {
		return nil, nil
	}
	return records[1:], nil
}

// buildThreatLookupIndex indexes rows (type, name, level, domain). Rows with
// fewer than four fields or an empty domain are ignored.
func buildThreatLookupIndex(rows [][]string) *threatLookupIndex {
	idx := &threatLookupIndex{
		exact:     make(map[string]threatEntry, len(rows)),
		subdomain: make(map[string]threatEntry),
	}
	for _, row := range rows {
		if len(row) < 4 {
			continue
		}
		key := normalizeThreatDomain(row[3])
		if key == "" {
			continue
		}
		entry := threatEntry{kind: row[0], name: row[1], level: row[2]}
		idx.exact[key] = entry

		// Only phishing ("钓鱼网站") and counterfeit ("仿冒网站") sites also
		// cover their subdomains: a rule for sub.example.com matches
		// a.sub.example.com.
		if entry.kind == "钓鱼网站" || entry.kind == "仿冒网站" {
			if _, exists := idx.subdomain[key]; !exists {
				idx.subdomain[key] = entry
			}
		}
	}
	return idx
}

// lookup returns the entry for host. It tries an exact match first. If that
// fails, it walks up the parent domains from most to least specific and
// returns the first subdomain rule found. Only whole labels match, so
// notexample.com never matches a rule for example.com.
func (idx *threatLookupIndex) lookup(host string) (threatEntry, bool) {
	host = normalizeThreatDomain(host)
	if host == "" {
		return threatEntry{}, false
	}
	if entry, ok := idx.exact[host]; ok {
		return entry, true
	}
	for rest := host; ; {
		dot := strings.IndexByte(rest, '.')
		if dot < 0 {
			return threatEntry{}, false
		}
		rest = rest[dot+1:]
		if entry, ok := idx.subdomain[rest]; ok {
			return entry, true
		}
	}
}

// loadThreatIndex reads the threat database and builds a lookup index. On
// failure it writes a JSON 500 response, logs the cause, and returns false.
// NOTE(review): the file is re-read on every request; consider caching it
// keyed on the file's modification time if these endpoints are hot.
func loadThreatIndex(w http.ResponseWriter) (*threatLookupIndex, bool) {
	data, err := os.ReadFile(threatDatabaseCSVPath)
	if err != nil {
		writeJSONReply(w, http.StatusInternalServerError, map[string]string{"error": "读取威胁数据库失败"})
		logger.Error(fmt.Sprintf("读取威胁数据库文件失败:%v", err))
		return nil, false
	}
	rows, err := parseThreatRecords(data)
	if err != nil {
		writeJSONReply(w, http.StatusInternalServerError, map[string]string{"error": "解析威胁数据库失败"})
		logger.Error(fmt.Sprintf("解析威胁数据库文件失败:%v", err))
		return nil, false
	}
	return buildThreatLookupIndex(rows), true
}

// writeJSONReply sends v as a JSON body with the given status code. The status
// line is already sent when encoding runs, so encode errors are only logged.
func writeJSONReply(w http.ResponseWriter, status int, v interface{}) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if err := json.NewEncoder(w).Encode(v); err != nil {
		logger.Error(fmt.Sprintf("写入JSON响应失败: %v", err))
	}
}

// handleThreatsInfo writes, as a JSON array, every threat-database row whose
// domain, name, or type contains threatFilter. This is a case-sensitive
// substring match. Each element has the keys "type", "name", "level" and
// "domain". An empty filter yields an empty array without reading the file.
// So does a filter that matches nothing: the result is never JSON null.
func handleThreatsInfo(w http.ResponseWriter, threatFilter string) {
	result := make([]map[string]string, 0)

	if threatFilter == "" {
		writeJSONReply(w, http.StatusOK, result)
		return
	}

	data, err := os.ReadFile(threatDatabaseCSVPath)
	if err != nil {
		http.Error(w, "Failed to read threats file", http.StatusInternalServerError)
		logger.Error(fmt.Sprintf("读取威胁域名文件失败: %v", err))
		return
	}

	rows, err := parseThreatRecords(data)
	if err != nil {
		http.Error(w, "Failed to parse threats file", http.StatusInternalServerError)
		logger.Error(fmt.Sprintf("解析威胁域名文件失败: %v", err))
		return
	}

	for _, row := range rows {
		if len(row) < 4 {
			continue
		}
		kind, name, level, host := row[0], row[1], row[2], row[3]
		if !strings.Contains(host, threatFilter) &&
			!strings.Contains(name, threatFilter) &&
			!strings.Contains(kind, threatFilter) {
			continue
		}
		result = append(result, map[string]string{
			"type":   kind,
			"name":   name,
			"level":  level,
			"domain": host,
		})
	}

	writeJSONReply(w, http.StatusOK, result)
}

// handleChangePassword handles POST {"currentPassword", "newPassword"}. If the
// current password matches s.config.Password, it stores the new one and
// persists the config to ./config.json. If persisting fails, the in-memory
// password is restored, so the server keeps accepting the old password it
// just reported as unchanged.
func (s *Server) handleChangePassword(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var changePasswordData struct {
		CurrentPassword string `json:"currentPassword"`
		NewPassword     string `json:"newPassword"`
	}
	if err := json.NewDecoder(r.Body).Decode(&changePasswordData); err != nil {
		writeJSONReply(w, http.StatusBadRequest, map[string]string{"error": "无效的请求体"})
		return
	}

	// NOTE(review): != is not constant-time; crypto/subtle.ConstantTimeCompare
	// would avoid leaking password prefixes through response timing.
	if changePasswordData.CurrentPassword != s.config.Password {
		writeJSONReply(w, http.StatusUnauthorized, map[string]string{"error": "当前密码错误"})
		return
	}

	// NOTE(review): an empty newPassword is accepted. Confirm whether an empty
	// password is meaningful (e.g. disables auth) before rejecting it.
	previousPassword := s.config.Password
	s.config.Password = changePasswordData.NewPassword

	if err := saveConfigToFile(s.globalConfig, "./config.json"); err != nil {
		// Roll back so the running server matches the failure we report.
		s.config.Password = previousPassword
		logger.Error("保存配置文件失败", "error", err)
		writeJSONReply(w, http.StatusInternalServerError, map[string]string{"error": "保存密码失败"})
		return
	}

	writeJSONReply(w, http.StatusOK, map[string]string{"status": "success", "message": "密码修改成功"})
	logger.Info("密码修改成功")
}

// handleThreatQuery looks up a single domain (query parameter "domain") in the
// threat database. It responds with {"data": "type,name,level,domain"} on a
// match and {"message": "无"} otherwise. Phishing and counterfeit rules also
// match their subdomains.
// @Summary 查询威胁域名信息
// @Description 根据传入的域名参数查询威胁数据库,返回威胁类型、名称、风险等级和域名
// @Tags threat
// @Accept json
// @Produce json
// @Param domain query string true "要查询的域名"
// @Success 200 {string} string "威胁信息,格式:类型,名称,风险等级,域名"
// @Failure 400 {object} map[string]string "缺少域名参数"
// @Router /api/threat [get]
func (s *Server) handleThreatQuery(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	host := r.URL.Query().Get("domain")
	if host == "" {
		writeJSONReply(w, http.StatusBadRequest, map[string]string{"error": "需要提供 domain 参数"})
		return
	}

	idx, ok := loadThreatIndex(w)
	if !ok {
		return
	}

	entry, found := idx.lookup(host)
	if !found {
		writeJSONReply(w, http.StatusOK, map[string]string{"message": "无"})
		return
	}
	writeJSONReply(w, http.StatusOK, map[string]string{"data": entry.format(host)})
}

// handleThreatBatch looks up every domain in the POST body {"domains": [...]}
// and responds with {"results": [...]}. Each result is either
// {"domain", "isThreat": true, "data"} or {"domain", "isThreat": false}, in
// input order. Matching is the same as handleThreatQuery.
// @Summary 批量查询威胁域名
// @Description 批量查询多个域名是否是威胁域名
// @Tags threat
// @Accept json
// @Produce json
// @Param domains body []string true "域名列表"
// @Success 200 {object} map[string]interface{} "批量查询结果"
// @Router /api/threat/batch [post]
func (s *Server) handleThreatBatch(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req struct {
		Domains []string `json:"domains"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeJSONReply(w, http.StatusBadRequest, map[string]string{"error": "请求格式错误"})
		return
	}

	idx, ok := loadThreatIndex(w)
	if !ok {
		return
	}

	results := make([]map[string]interface{}, 0, len(req.Domains))
	for _, host := range req.Domains {
		entry, found := idx.lookup(host)
		if !found {
			results = append(results, map[string]interface{}{
				"domain":   host,
				"isThreat": false,
			})
			continue
		}
		results = append(results, map[string]interface{}{
			"domain":   host,
			"isThreat": true,
			"data":     entry.format(host),
		})
	}

	writeJSONReply(w, http.StatusOK, map[string]interface{}{
		"results": results,
	})
}