package http import ( "encoding/json" "fmt" "io/ioutil" "net/http" "sort" "strings" "sync" "time" "dns-server/config" "dns-server/dns" "dns-server/logger" "dns-server/shield" "github.com/gorilla/websocket" ) // Server HTTP控制台服务器 type Server struct { globalConfig *config.Config config *config.HTTPConfig dnsServer *dns.Server shieldManager *shield.ShieldManager server *http.Server // WebSocket相关字段 upgrader websocket.Upgrader clients map[*websocket.Conn]bool clientsMutex sync.Mutex broadcastChan chan []byte } // NewServer 创建HTTP服务器实例 func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager *shield.ShieldManager) *Server { server := &Server{ globalConfig: globalConfig, config: &globalConfig.HTTP, dnsServer: dnsServer, shieldManager: shieldManager, upgrader: websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, // 允许所有CORS请求 CheckOrigin: func(r *http.Request) bool { return true }, }, clients: make(map[*websocket.Conn]bool), broadcastChan: make(chan []byte, 100), } // 启动广播协程 go server.startBroadcastLoop() return server } // Start 启动HTTP服务器 func (s *Server) Start() error { mux := http.NewServeMux() // API路由 if s.config.EnableAPI { mux.HandleFunc("/api/stats", s.handleStats) mux.HandleFunc("/api/shield", s.handleShield) mux.HandleFunc("/api/shield/hosts", s.handleShieldHosts) mux.HandleFunc("/api/shield/blacklists", s.handleShieldBlacklists) mux.HandleFunc("/api/query", s.handleQuery) mux.HandleFunc("/api/status", s.handleStatus) mux.HandleFunc("/api/config", s.handleConfig) mux.HandleFunc("/api/config/restart", s.handleRestart) // 添加统计相关接口 mux.HandleFunc("/api/top-blocked", s.handleTopBlockedDomains) mux.HandleFunc("/api/top-resolved", s.handleTopResolvedDomains) mux.HandleFunc("/api/top-clients", s.handleTopClients) mux.HandleFunc("/api/top-domains", s.handleTopDomains) mux.HandleFunc("/api/recent-blocked", s.handleRecentBlockedDomains) mux.HandleFunc("/api/hourly-stats", s.handleHourlyStats) mux.HandleFunc("/api/daily-stats", s.handleDailyStats) mux.HandleFunc("/api/monthly-stats", s.handleMonthlyStats) mux.HandleFunc("/api/query/type", s.handleQueryTypeStats) // WebSocket端点 mux.HandleFunc("/ws/stats", s.handleWebSocketStats) } // 自定义静态文件服务处理器,用于禁用浏览器缓存 fileServer := http.FileServer(http.Dir("./static")) mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { // 添加Cache-Control头,禁用浏览器缓存 w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") w.Header().Set("Pragma", "no-cache") w.Header().Set("Expires", "Thu, 01 Jan 1970 00:00:00 GMT") // 使用StripPrefix处理路径 http.StripPrefix("/", fileServer).ServeHTTP(w, r) }) s.server = &http.Server{ Addr: fmt.Sprintf("%s:%d", s.config.Host, s.config.Port), Handler: mux, ReadTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second, } logger.Info(fmt.Sprintf("HTTP控制台服务器启动,监听地址: %s:%d", s.config.Host, s.config.Port)) return s.server.ListenAndServe() } // Stop 停止HTTP服务器 func (s *Server) Stop() { if s.server != nil { s.server.Close() } logger.Info("HTTP控制台服务器已停止") } // handleStats 处理统计信息请求 func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } dnsStats := s.dnsServer.GetStats() shieldStats := s.shieldManager.GetStats() // 获取最常用查询类型(如果有) topQueryType := "-" maxCount := int64(0) if len(dnsStats.QueryTypes) > 0 { for queryType, count := range dnsStats.QueryTypes { if count > maxCount { maxCount = count topQueryType = queryType } } } // 获取活跃来源IP数量 activeIPCount := len(dnsStats.SourceIPs) // 格式化平均响应时间为两位小数 formattedResponseTime := float64(int(dnsStats.AvgResponseTime*100)) / 100 // 构建响应数据,确保所有字段都反映服务器的真实状态 stats := map[string]interface{}{ "dns": map[string]interface{}{ "Queries": dnsStats.Queries, "Blocked": dnsStats.Blocked, "Allowed": dnsStats.Allowed, "Errors": dnsStats.Errors, "LastQuery": dnsStats.LastQuery, "AvgResponseTime": formattedResponseTime, "TotalResponseTime": dnsStats.TotalResponseTime, "QueryTypes": dnsStats.QueryTypes, "SourceIPs": dnsStats.SourceIPs, "CpuUsage": dnsStats.CpuUsage, }, "shield": shieldStats, "topQueryType": topQueryType, "activeIPs": activeIPCount, "avgResponseTime": formattedResponseTime, "cpuUsage": dnsStats.CpuUsage, "time": time.Now(), } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(stats) } // WebSocket相关方法 // handleWebSocketStats 处理WebSocket连接,用于实时推送统计数据 func (s *Server) handleWebSocketStats(w http.ResponseWriter, r *http.Request) { // 升级HTTP连接为WebSocket conn, err := s.upgrader.Upgrade(w, r, nil) if err != nil { logger.Error(fmt.Sprintf("WebSocket升级失败: %v", err)) return } defer conn.Close() // 将新客户端添加到客户端列表 s.clientsMutex.Lock() s.clients[conn] = true clientCount := len(s.clients) s.clientsMutex.Unlock() logger.Info(fmt.Sprintf("新WebSocket客户端连接,当前连接数: %d", clientCount)) // 发送初始数据 if err := s.sendInitialStats(conn); err != nil { logger.Error(fmt.Sprintf("发送初始数据失败: %v", err)) return } // 定期发送更新数据 ticker := time.NewTicker(500 * time.Millisecond) // 每500ms检查一次数据变化 defer ticker.Stop() // 最后一次发送的数据快照,用于检测变化 var lastStats map[string]interface{} // 保持连接并定期发送数据 for { select { case <-ticker.C: // 获取最新统计数据 currentStats := s.buildStatsData() // 检查数据是否有变化 if !s.areStatsEqual(lastStats, currentStats) { // 数据有变化,发送更新 data, err := json.Marshal(map[string]interface{}{ "type": "stats_update", "data": currentStats, "time": time.Now(), }) if err != nil { logger.Error(fmt.Sprintf("序列化统计数据失败: %v", err)) continue } if err := conn.WriteMessage(websocket.TextMessage, data); err != nil { logger.Error(fmt.Sprintf("发送WebSocket消息失败: %v", err)) return } // 更新最后发送的数据 lastStats = currentStats } case <-r.Context().Done(): // 客户端断开连接 s.clientsMutex.Lock() delete(s.clients, conn) clientCount := len(s.clients) s.clientsMutex.Unlock() logger.Info(fmt.Sprintf("WebSocket客户端断开连接,当前连接数: %d", clientCount)) return } } } // sendInitialStats 发送初始统计数据 func (s *Server) sendInitialStats(conn *websocket.Conn) error { stats := s.buildStatsData() data, err := json.Marshal(map[string]interface{}{ "type": "initial_data", "data": stats, "time": time.Now(), }) if err != nil { return err } return conn.WriteMessage(websocket.TextMessage, data) } // buildStatsData 构建统计数据 func (s *Server) buildStatsData() map[string]interface{} { dnsStats := s.dnsServer.GetStats() shieldStats := s.shieldManager.GetStats() // 获取最常用查询类型 topQueryType := "-" maxCount := int64(0) if len(dnsStats.QueryTypes) > 0 { for queryType, count := range dnsStats.QueryTypes { if count > maxCount { maxCount = count topQueryType = queryType } } } // 获取活跃来源IP数量 activeIPCount := len(dnsStats.SourceIPs) // 格式化平均响应时间 formattedResponseTime := float64(int(dnsStats.AvgResponseTime*100)) / 100 return map[string]interface{}{ "dns": map[string]interface{}{ "Queries": dnsStats.Queries, "Blocked": dnsStats.Blocked, "Allowed": dnsStats.Allowed, "Errors": dnsStats.Errors, "LastQuery": dnsStats.LastQuery, "AvgResponseTime": formattedResponseTime, "TotalResponseTime": dnsStats.TotalResponseTime, "QueryTypes": dnsStats.QueryTypes, "SourceIPs": dnsStats.SourceIPs, "CpuUsage": dnsStats.CpuUsage, }, "shield": shieldStats, "topQueryType": topQueryType, "activeIPs": activeIPCount, "avgResponseTime": formattedResponseTime, "cpuUsage": dnsStats.CpuUsage, } } // areStatsEqual 检查两次统计数据是否相等(用于检测变化) func (s *Server) areStatsEqual(stats1, stats2 map[string]interface{}) bool { if stats1 == nil || stats2 == nil { return false } // 只比较关键数值,避免频繁更新 if dns1, ok1 := stats1["dns"].(map[string]interface{}); ok1 { if dns2, ok2 := stats2["dns"].(map[string]interface{}); ok2 { // 检查主要计数器 if dns1["Queries"] != dns2["Queries"] || dns1["Blocked"] != dns2["Blocked"] || dns1["Allowed"] != dns2["Allowed"] || dns1["Errors"] != dns2["Errors"] { return false } } } return true } // startBroadcastLoop 启动广播循环 func (s *Server) startBroadcastLoop() { for message := range s.broadcastChan { s.clientsMutex.Lock() for client := range s.clients { if err := client.WriteMessage(websocket.TextMessage, message); err != nil { logger.Error(fmt.Sprintf("广播消息失败: %v", err)) client.Close() delete(s.clients, client) } } s.clientsMutex.Unlock() } } // handleTopBlockedDomains 处理TOP屏蔽域名请求 func (s *Server) handleTopBlockedDomains(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } domains := s.dnsServer.GetTopBlockedDomains(10) // 转换为前端需要的格式 result := make([]map[string]interface{}, len(domains)) for i, domain := range domains { result[i] = map[string]interface{}{ "domain": domain.Domain, "count": domain.Count, } } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(result) } // handleTopResolvedDomains 处理获取最常解析的域名请求 func (s *Server) handleTopResolvedDomains(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } domains := s.dnsServer.GetTopResolvedDomains(10) // 转换为前端需要的格式 result := make([]map[string]interface{}, len(domains)) for i, domain := range domains { result[i] = map[string]interface{}{ "domain": domain.Domain, "count": domain.Count, } } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(result) } // handleRecentBlockedDomains 处理最近屏蔽域名请求 func (s *Server) handleRecentBlockedDomains(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } domains := s.dnsServer.GetRecentBlockedDomains(10) // 转换为前端需要的格式 result := make([]map[string]interface{}, len(domains)) for i, domain := range domains { result[i] = map[string]interface{}{ "domain": domain.Domain, "time": domain.LastSeen.Format("15:04:05"), } } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(result) } // handleHourlyStats 处理24小时统计请求 func (s *Server) handleHourlyStats(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } hourlyStats := s.dnsServer.GetHourlyStats() // 生成最近24小时的数据 now := time.Now() labels := make([]string, 24) data := make([]int64, 24) for i := 23; i >= 0; i-- { hour := now.Add(time.Duration(-i) * time.Hour) hourKey := hour.Format("2006-01-02-15") labels[23-i] = hour.Format("15:00") data[23-i] = hourlyStats[hourKey] } result := map[string]interface{}{ "labels": labels, "data": data, } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(result) } // handleDailyStats 处理每日统计数据请求 func (s *Server) handleDailyStats(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 获取每日统计数据 dailyStats := s.dnsServer.GetDailyStats() // 生成过去7天的时间标签 labels := make([]string, 7) data := make([]int64, 7) now := time.Now() for i := 6; i >= 0; i-- { t := now.AddDate(0, 0, -i) key := t.Format("2006-01-02") labels[6-i] = t.Format("01-02") data[6-i] = dailyStats[key] } result := map[string]interface{}{ "labels": labels, "data": data, } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(result) } // handleMonthlyStats 处理每月统计数据请求 func (s *Server) handleMonthlyStats(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 获取每日统计数据(用于30天视图) dailyStats := s.dnsServer.GetDailyStats() // 生成过去30天的时间标签 labels := make([]string, 30) data := make([]int64, 30) now := time.Now() for i := 29; i >= 0; i-- { t := now.AddDate(0, 0, -i) key := t.Format("2006-01-02") labels[29-i] = t.Format("01-02") data[29-i] = dailyStats[key] } result := map[string]interface{}{ "labels": labels, "data": data, } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(result) } // handleQueryTypeStats 处理查询类型统计请求 func (s *Server) handleQueryTypeStats(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 获取DNS统计数据 dnsStats := s.dnsServer.GetStats() // 转换为前端需要的格式 result := make([]map[string]interface{}, 0, len(dnsStats.QueryTypes)) for queryType, count := range dnsStats.QueryTypes { result = append(result, map[string]interface{}{ "type": queryType, "count": count, }) } // 按计数降序排序 sort.Slice(result, func(i, j int) bool { return result[i]["count"].(int64) > result[j]["count"].(int64) }) w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(result) } // handleTopClients 处理TOP客户端请求 func (s *Server) handleTopClients(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 获取TOP客户端列表 clients := s.dnsServer.GetTopClients(10) // 转换为前端需要的格式 result := make([]map[string]interface{}, len(clients)) for i, client := range clients { result[i] = map[string]interface{}{ "ip": client.IP, "count": client.Count, "lastSeen": client.LastSeen, } } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(result) } // handleTopDomains 处理TOP域名请求 func (s *Server) handleTopDomains(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // 获取TOP被屏蔽域名 blockedDomains := s.dnsServer.GetTopBlockedDomains(10) // 获取TOP已解析域名 resolvedDomains := s.dnsServer.GetTopResolvedDomains(10) // 合并并去重域名统计 domainMap := make(map[string]int64) for _, domain := range blockedDomains { domainMap[domain.Domain] += domain.Count } for _, domain := range resolvedDomains { domainMap[domain.Domain] += domain.Count } // 转换为切片并排序 domainList := make([]map[string]interface{}, 0, len(domainMap)) for domain, count := range domainMap { domainList = append(domainList, map[string]interface{}{ "domain": domain, "count": count, }) } // 按计数降序排序 sort.Slice(domainList, func(i, j int) bool { return domainList[i]["count"].(int64) > domainList[j]["count"].(int64) }) // 返回限制数量 if len(domainList) > 10 { domainList = domainList[:10] } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(domainList) } // handleShield 处理屏蔽规则管理请求 func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") // 返回屏蔽规则的基本配置信息和统计数据,不返回完整规则列表 switch r.Method { case http.MethodGet: // 获取规则统计信息 stats := s.shieldManager.GetStats() shieldInfo := map[string]interface{}{ "updateInterval": s.globalConfig.Shield.UpdateInterval, "blockMethod": s.globalConfig.Shield.BlockMethod, "blacklistCount": len(s.globalConfig.Shield.Blacklists), "domainRulesCount": stats["domainRules"], "domainExceptionsCount": stats["domainExceptions"], "regexRulesCount": stats["regexRules"], "regexExceptionsCount": stats["regexExceptions"], "hostsRulesCount": stats["hostsRules"], } json.NewEncoder(w).Encode(shieldInfo) return case http.MethodPost: // 添加屏蔽规则 var req struct { Rule string `json:"rule"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return } if err := s.shieldManager.AddRule(req.Rule); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } json.NewEncoder(w).Encode(map[string]string{"status": "success"}) return case http.MethodDelete: // 删除屏蔽规则 var req struct { Rule string `json:"rule"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return } if err := s.shieldManager.RemoveRule(req.Rule); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } json.NewEncoder(w).Encode(map[string]string{"status": "success"}) return case http.MethodPut: // 重新加载规则 if err := s.shieldManager.LoadRules(); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "规则重新加载成功"}) return default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } } // handleShieldBlacklists 处理远程黑名单管理请求 func (s *Server) handleShieldBlacklists(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") // 处理更新单个黑名单 if strings.Contains(r.URL.Path, "/update") { if r.Method == http.MethodPost { // 提取黑名单URL或Name parts := strings.Split(r.URL.Path, "/") var targetURLOrName string for i, part := range parts { if part == "blacklists" && i+1 < len(parts) && parts[i+1] != "update" { targetURLOrName = parts[i+1] break } } if targetURLOrName == "" { http.Error(w, "黑名单标识不能为空", http.StatusBadRequest) return } // 获取黑名单列表 blacklists := s.shieldManager.GetBlacklists() var targetIndex = -1 for i, list := range blacklists { if list.URL == targetURLOrName || list.Name == targetURLOrName { targetIndex = i break } } if targetIndex == -1 { http.Error(w, "黑名单不存在", http.StatusNotFound) return } // 更新时间戳 blacklists[targetIndex].LastUpdateTime = time.Now().Format(time.RFC3339) // 保存更新后的黑名单列表 s.shieldManager.UpdateBlacklist(blacklists) // 重新加载规则以获取最新的远程规则 s.shieldManager.LoadRules() json.NewEncoder(w).Encode(map[string]string{"status": "success"}) return } } // 处理删除黑名单 parts := strings.Split(r.URL.Path, "/") if len(parts) > 4 && parts[3] == "blacklists" && parts[4] != "" && r.Method == http.MethodDelete { id := parts[4] blacklists := s.shieldManager.GetBlacklists() var newBlacklists []config.BlacklistEntry for _, list := range blacklists { if list.URL != id && list.Name != id { newBlacklists = append(newBlacklists, list) } } s.shieldManager.UpdateBlacklist(newBlacklists) json.NewEncoder(w).Encode(map[string]string{"status": "success"}) return } switch r.Method { case http.MethodGet: // 获取远程黑名单列表 blacklists := s.shieldManager.GetBlacklists() json.NewEncoder(w).Encode(blacklists) case http.MethodPost: // 添加远程黑名单 var req struct { Name string `json:"name"` URL string `json:"url"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return } if req.Name == "" || req.URL == "" { http.Error(w, "Name and URL are required", http.StatusBadRequest) return } // 获取现有黑名单 blacklists := s.shieldManager.GetBlacklists() // 检查是否已存在 for _, list := range blacklists { if list.URL == req.URL { http.Error(w, "黑名单URL已存在", http.StatusConflict) return } } // 添加新黑名单 newEntry := config.BlacklistEntry{ Name: req.Name, URL: req.URL, Enabled: true, } blacklists = append(blacklists, newEntry) s.shieldManager.UpdateBlacklist(blacklists) // 重新加载规则以获取新添加的远程规则 s.shieldManager.LoadRules() json.NewEncoder(w).Encode(map[string]string{"status": "success"}) case http.MethodPut: // 更新所有远程黑名单 blacklists := s.shieldManager.GetBlacklists() for i := range blacklists { // 更新每个黑名单的时间戳 blacklists[i].LastUpdateTime = time.Now().Format(time.RFC3339) } s.shieldManager.UpdateBlacklist(blacklists) // 重新加载所有规则 s.shieldManager.LoadRules() json.NewEncoder(w).Encode(map[string]string{"status": "success"}) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } } // handleShieldHosts 处理hosts管理请求 func (s *Server) handleShieldHosts(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") switch r.Method { case http.MethodPost: // 添加hosts条目 var req struct { IP string `json:"ip"` Domain string `json:"domain"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return } if req.IP == "" || req.Domain == "" { http.Error(w, "IP and Domain are required", http.StatusBadRequest) return } if err := s.shieldManager.AddHostsEntry(req.IP, req.Domain); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } json.NewEncoder(w).Encode(map[string]string{"status": "success"}) case http.MethodDelete: // 删除hosts条目 var req struct { Domain string `json:"domain"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return } if err := s.shieldManager.RemoveHostsEntry(req.Domain); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } json.NewEncoder(w).Encode(map[string]string{"status": "success"}) case http.MethodGet: // 获取hosts条目列表 hosts := s.shieldManager.GetAllHosts() hostsCount := s.shieldManager.GetHostsCount() // 转换为数组格式,便于前端展示 hostsList := make([]map[string]string, 0, len(hosts)) for domain, ip := range hosts { hostsList = append(hostsList, map[string]string{ "domain": domain, "ip": ip, }) } json.NewEncoder(w).Encode(map[string]interface{}{ "hosts": hostsList, "hostsCount": hostsCount, }) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } } // handleQuery 处理DNS查询请求 func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } domain := r.URL.Query().Get("domain") if domain == "" { http.Error(w, "Domain parameter is required", http.StatusBadRequest) return } // 获取域名屏蔽的详细信息 blockDetails := s.shieldManager.CheckDomainBlockDetails(domain) // 添加时间戳 blockDetails["timestamp"] = time.Now() w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(blockDetails) } // handleStatus 处理系统状态请求 func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } stats := s.dnsServer.GetStats() // 使用服务器的实际启动时间计算准确的运行时间 serverStartTime := s.dnsServer.GetStartTime() uptime := time.Since(serverStartTime) // 构建包含所有真实服务器统计数据的响应 status := map[string]interface{}{ "status": "running", "queries": stats.Queries, "blocked": stats.Blocked, "allowed": stats.Allowed, "errors": stats.Errors, "lastQuery": stats.LastQuery, "avgResponseTime": stats.AvgResponseTime, "activeIPs": len(stats.SourceIPs), "startTime": serverStartTime, "uptime": uptime.Milliseconds(), // 转换为毫秒数,方便前端处理 "cpuUsage": stats.CpuUsage, "timestamp": time.Now(), } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(status) } // saveConfigToFile 保存配置到文件 func saveConfigToFile(config *config.Config, filePath string) error { data, err := json.MarshalIndent(config, "", " ") if err != nil { return err } return ioutil.WriteFile(filePath, data, 0644) } // handleConfig 处理配置请求 func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") switch r.Method { case http.MethodGet: // 返回当前配置(包括黑名单配置) config := map[string]interface{}{ "shield": map[string]interface{}{ "blockMethod": s.globalConfig.Shield.BlockMethod, "customBlockIP": s.globalConfig.Shield.CustomBlockIP, "blacklists": s.globalConfig.Shield.Blacklists, "updateInterval": s.globalConfig.Shield.UpdateInterval, }, } json.NewEncoder(w).Encode(config) case http.MethodPost: // 更新配置 var req struct { Shield struct { BlockMethod string `json:"blockMethod"` CustomBlockIP string `json:"customBlockIP"` Blacklists []config.BlacklistEntry `json:"blacklists"` UpdateInterval int `json:"updateInterval"` } `json:"shield"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "无效的请求体", http.StatusBadRequest) return } // 更新屏蔽配置 if req.Shield.BlockMethod != "" { // 验证屏蔽方法是否有效 validMethods := map[string]bool{ "NXDOMAIN": true, "refused": true, "emptyIP": true, "customIP": true, } if !validMethods[req.Shield.BlockMethod] { http.Error(w, "无效的屏蔽方法", http.StatusBadRequest) return } s.globalConfig.Shield.BlockMethod = req.Shield.BlockMethod // 如果选择了customIP,验证IP地址 if req.Shield.BlockMethod == "customIP" { if req.Shield.CustomBlockIP == "" { http.Error(w, "自定义IP不能为空", http.StatusBadRequest) return } // 简单的IP地址验证 if !isValidIP(req.Shield.CustomBlockIP) { http.Error(w, "无效的IP地址", http.StatusBadRequest) return } } } if req.Shield.CustomBlockIP != "" { s.globalConfig.Shield.CustomBlockIP = req.Shield.CustomBlockIP } // 更新黑名单配置 if req.Shield.Blacklists != nil { // 验证黑名单配置 for i, bl := range req.Shield.Blacklists { if bl.URL == "" { http.Error(w, fmt.Sprintf("黑名单URL不能为空,索引: %d", i), http.StatusBadRequest) return } if !strings.HasPrefix(bl.URL, "http://") && !strings.HasPrefix(bl.URL, "https://") { http.Error(w, fmt.Sprintf("黑名单URL必须以http://或https://开头,索引: %d", i), http.StatusBadRequest) return } } s.globalConfig.Shield.Blacklists = req.Shield.Blacklists s.shieldManager.UpdateBlacklist(req.Shield.Blacklists) // 重新加载规则 if err := s.shieldManager.LoadRules(); err != nil { logger.Error("重新加载规则失败", "error", err) } } // 更新更新间隔 if req.Shield.UpdateInterval > 0 { s.globalConfig.Shield.UpdateInterval = req.Shield.UpdateInterval // 重新启动自动更新 s.shieldManager.StopAutoUpdate() s.shieldManager.StartAutoUpdate() } // 保存配置到文件 if err := saveConfigToFile(s.globalConfig, "./config.json"); err != nil { logger.Error("保存配置到文件失败", "error", err) // 不返回错误,只记录日志,因为配置已经在内存中更新成功 } // 返回成功响应 json.NewEncoder(w).Encode(map[string]interface{}{ "success": true, "message": "配置已更新", }) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } } // isValidIP 简单验证IP地址格式 func isValidIP(ip string) bool { // 简单的IPv4地址验证 parts := strings.Split(ip, ".") if len(parts) != 4 { return false } for _, part := range parts { // 检查是否为数字 for _, char := range part { if char < '0' || char > '9' { return false } } // 检查数字范围 var num int if _, err := fmt.Sscanf(part, "%d", &num); err != nil || num < 0 || num > 255 { return false } } return true } // handleRestart 处理重启服务请求 func (s *Server) handleRestart(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } logger.Info("收到重启服务请求") // 停止DNS服务器 s.dnsServer.Stop() // 重新加载屏蔽规则 if err := s.shieldManager.LoadRules(); err != nil { logger.Error("重新加载屏蔽规则失败", "error", err) } // 重新启动DNS服务器 go func() { if err := s.dnsServer.Start(); err != nil { logger.Error("DNS服务器重启失败", "error", err) } }() // 重新启动定时更新任务 s.shieldManager.StopAutoUpdate() s.shieldManager.StartAutoUpdate() // 返回成功响应 w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "服务已重启"}) logger.Info("服务重启成功") }