// Package http implements the HTTP control console of the DNS server: a JSON
// API for statistics, shield rules, hosts entries and runtime configuration,
// plus a file server for the ./static directory.
package http

import (
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
	"time"

	"dns-server/config"
	"dns-server/dns"
	"dns-server/logger"
	"dns-server/shield"
)

// maxRequestBodyBytes caps the size of a JSON request body accepted by the API
// so that a client cannot make the server buffer an arbitrarily large payload.
const maxRequestBodyBytes = 1 << 20

// Server is the HTTP control console server.
type Server struct {
	globalConfig  *config.Config
	config        *config.HTTPConfig
	dnsServer     *dns.Server
	shieldManager *shield.ShieldManager
	server        *http.Server
	// startTime is recorded by NewServer and used to report uptime.
	startTime time.Time
}

// NewServer creates an HTTP console server bound to the given global
// configuration, DNS server and shield manager.
func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager *shield.ShieldManager) *Server {
	return &Server{
		globalConfig:  globalConfig,
		config:        &globalConfig.HTTP,
		dnsServer:     dnsServer,
		shieldManager: shieldManager,
		startTime:     time.Now(),
	}
}

// Start registers the routes and serves HTTP until the listener fails or Stop
// is called. It blocks for the lifetime of the server.
//
// The returned error is whatever ListenAndServe returns, which is never nil;
// after Stop it is http.ErrServerClosed.
func (s *Server) Start() error {
	mux := http.NewServeMux()

	// API routes are only exposed when enabled in the configuration.
	if s.config.EnableAPI {
		mux.HandleFunc("/api/stats", s.handleStats)
		mux.HandleFunc("/api/shield", s.handleShield)
		mux.HandleFunc("/api/shield/hosts", s.handleShieldHosts)
		mux.HandleFunc("/api/query", s.handleQuery)
		mux.HandleFunc("/api/status", s.handleStatus)
		mux.HandleFunc("/api/config", s.handleConfig)

		// Statistics endpoints.
		mux.HandleFunc("/api/top-blocked", s.handleTopBlockedDomains)
		mux.HandleFunc("/api/top-resolved", s.handleTopResolvedDomains)
		mux.HandleFunc("/api/recent-blocked", s.handleRecentBlockedDomains)
		mux.HandleFunc("/api/hourly-stats", s.handleHourlyStats)
	}

	// Static file service (a front end can be added here later).
	mux.Handle("/", http.FileServer(http.Dir("./static")))

	s.server = &http.Server{
		Addr:         fmt.Sprintf("%s:%d", s.config.Host, s.config.Port),
		Handler:      mux,
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}

	logger.Info(fmt.Sprintf("HTTP控制台服务器启动,监听地址: %s:%d", s.config.Host, s.config.Port))
	return s.server.ListenAndServe()
}

// Stop closes the HTTP server if it has been started. Any error reported by
// Close is logged rather than dropped.
func (s *Server) Stop() {
	if s.server != nil {
		if err := s.server.Close(); err != nil {
			logger.Info(fmt.Sprintf("HTTP控制台服务器关闭出错: %v", err))
		}
	}
	logger.Info("HTTP控制台服务器已停止")
}

// writeJSON sends v as a JSON response body. An encoding or write failure is
// logged; by then the status line may already be on the wire, so no error
// response is attempted.
func writeJSON(w http.ResponseWriter, v interface{}) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(v); err != nil {
		logger.Info(fmt.Sprintf("HTTP控制台响应编码失败: %v", err))
	}
}

// methodNotAllowed answers 405 and advertises the permitted methods in the
// Allow header.
func methodNotAllowed(w http.ResponseWriter, allowed ...string) {
	w.Header().Set("Allow", strings.Join(allowed, ", "))
	http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}

// requireMethod reports whether the request uses the given method. When it
// does not, a 405 response has already been written and the caller must return.
func requireMethod(w http.ResponseWriter, r *http.Request, method string) bool {
	if r.Method == method {
		return true
	}
	methodNotAllowed(w, method)
	return false
}

// decodeJSONBody decodes the size-limited request body into dst.
func decodeJSONBody(w http.ResponseWriter, r *http.Request, dst interface{}) error {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodyBytes)
	return json.NewDecoder(r.Body).Decode(dst)
}

// readRule decodes a {"rule": "..."} body. It returns false after writing a
// 400 response when the body is malformed or the rule is blank.
func readRule(w http.ResponseWriter, r *http.Request) (string, bool) {
	var req struct {
		Rule string `json:"rule"`
	}
	if err := decodeJSONBody(w, r, &req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return "", false
	}
	if strings.TrimSpace(req.Rule) == "" {
		http.Error(w, "Rule is required", http.StatusBadRequest)
		return "", false
	}
	return req.Rule, true
}

// handleStats returns the DNS and shield statistics together with the current
// server time.
func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
	if !requireMethod(w, r, http.MethodGet) {
		return
	}

	writeJSON(w, map[string]interface{}{
		"dns":    s.dnsServer.GetStats(),
		"shield": s.shieldManager.GetStats(),
		"time":   time.Now(),
	})
}

// handleTopBlockedDomains returns the ten most frequently blocked domains as a
// list of {domain, count} objects.
func (s *Server) handleTopBlockedDomains(w http.ResponseWriter, r *http.Request) {
	if !requireMethod(w, r, http.MethodGet) {
		return
	}

	domains := s.dnsServer.GetTopBlockedDomains(10)

	// Convert to the shape expected by the front end. The slice is allocated
	// up front so that an empty result encodes as [] rather than null.
	result := make([]map[string]interface{}, len(domains))
	for i, d := range domains {
		result[i] = map[string]interface{}{
			"domain": d.Domain,
			"count":  d.Count,
		}
	}

	writeJSON(w, result)
}

// handleTopResolvedDomains returns the ten most frequently resolved domains as
// a list of {domain, count} objects.
func (s *Server) handleTopResolvedDomains(w http.ResponseWriter, r *http.Request) {
	if !requireMethod(w, r, http.MethodGet) {
		return
	}

	domains := s.dnsServer.GetTopResolvedDomains(10)

	// Convert to the shape expected by the front end.
	result := make([]map[string]interface{}, len(domains))
	for i, d := range domains {
		result[i] = map[string]interface{}{
			"domain": d.Domain,
			"count":  d.Count,
		}
	}

	writeJSON(w, result)
}

// handleRecentBlockedDomains returns the ten most recently blocked domains as a
// list of {domain, time} objects, with time formatted as HH:MM:SS.
func (s *Server) handleRecentBlockedDomains(w http.ResponseWriter, r *http.Request) {
	if !requireMethod(w, r, http.MethodGet) {
		return
	}

	domains := s.dnsServer.GetRecentBlockedDomains(10)

	// Convert to the shape expected by the front end.
	result := make([]map[string]interface{}, len(domains))
	for i, d := range domains {
		result[i] = map[string]interface{}{
			"domain": d.Domain,
			"time":   d.LastSeen.Format("15:04:05"),
		}
	}

	writeJSON(w, result)
}

// handleHourlyStats returns per-hour counters for the last 24 hours, oldest
// first, as parallel "labels" and "data" arrays.
func (s *Server) handleHourlyStats(w http.ResponseWriter, r *http.Request) {
	if !requireMethod(w, r, http.MethodGet) {
		return
	}

	const hours = 24
	hourlyStats := s.dnsServer.GetHourlyStats()

	now := time.Now()
	labels := make([]string, hours)
	data := make([]int64, hours)
	for i := 0; i < hours; i++ {
		// Slot 0 is 23 hours ago; the last slot is the current hour.
		hour := now.Add(time.Duration(i-(hours-1)) * time.Hour)
		labels[i] = hour.Format("15:00")
		// Hours missing from the map read as zero.
		data[i] = hourlyStats[hour.Format("2006-01-02-15")]
	}

	writeJSON(w, map[string]interface{}{
		"labels": labels,
		"data":   data,
	})
}

// handleShield manages shield rules: GET lists them, POST adds one, DELETE
// removes one and PUT reloads them through the shield manager.
func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) {
	// Defensive delegation for the hosts sub-route. Start registers
	// /api/shield/hosts on its own pattern, so this is only taken if the
	// handler is ever mounted on a wider pattern.
	if strings.HasPrefix(r.URL.Path, "/api/shield/hosts") {
		s.handleShieldHosts(w, r)
		return
	}

	success := map[string]string{"status": "success"}

	switch r.Method {
	case http.MethodGet:
		// Return the complete rule list.
		writeJSON(w, s.shieldManager.GetRules())

	case http.MethodPost:
		// Add a shield rule.
		rule, ok := readRule(w, r)
		if !ok {
			return
		}
		if err := s.shieldManager.AddRule(rule); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		writeJSON(w, success)

	case http.MethodDelete:
		// Remove a shield rule.
		rule, ok := readRule(w, r)
		if !ok {
			return
		}
		if err := s.shieldManager.RemoveRule(rule); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		writeJSON(w, success)

	case http.MethodPut:
		// Reload the rules.
		if err := s.shieldManager.LoadRules(); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		writeJSON(w, map[string]string{"status": "success", "message": "规则重新加载成功"})

	default:
		methodNotAllowed(w, http.MethodGet, http.MethodPost, http.MethodDelete, http.MethodPut)
	}
}

// handleShieldHosts manages hosts entries: POST adds an IP/domain pair, DELETE
// removes the entry for a domain and GET currently reports only the entry count.
func (s *Server) handleShieldHosts(w http.ResponseWriter, r *http.Request) {
	success := map[string]string{"status": "success"}

	switch r.Method {
	case http.MethodPost:
		// Add a hosts entry.
		var req struct {
			IP     string `json:"ip"`
			Domain string `json:"domain"`
		}
		if err := decodeJSONBody(w, r, &req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if req.IP == "" || req.Domain == "" {
			http.Error(w, "IP and Domain are required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.AddHostsEntry(req.IP, req.Domain); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		writeJSON(w, success)

	case http.MethodDelete:
		// Remove a hosts entry.
		var req struct {
			Domain string `json:"domain"`
		}
		if err := decodeJSONBody(w, r, &req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if req.Domain == "" {
			http.Error(w, "Domain is required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.RemoveHostsEntry(req.Domain); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		writeJSON(w, success)

	case http.MethodGet:
		// Listing the entries needs a shield manager method that returns all
		// hosts entries; until that exists only the count is reported.
		stats := s.shieldManager.GetStats()
		writeJSON(w, map[string]interface{}{
			"hostsCount": stats["hostsRules"],
			"message":    "获取hosts列表功能待实现",
		})

	default:
		methodNotAllowed(w, http.MethodGet, http.MethodPost, http.MethodDelete)
	}
}

// handleQuery reports, for the domain given in the "domain" query parameter,
// whether it is blocked and whether a hosts entry matches it.
func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) {
	if !requireMethod(w, r, http.MethodGet) {
		return
	}

	domain := r.URL.Query().Get("domain")
	if domain == "" {
		http.Error(w, "Domain parameter is required", http.StatusBadRequest)
		return
	}

	// Check the block list and the hosts table.
	blocked := s.shieldManager.IsBlocked(domain)
	hostsIP, hasHosts := s.shieldManager.GetHostsIP(domain)

	writeJSON(w, map[string]interface{}{
		"domain":    domain,
		"blocked":   blocked,
		"hasHosts":  hasHosts,
		"hostsIP":   hostsIP,
		"timestamp": time.Now(),
	})
}

// handleStatus returns the running status, the query counter, the time of the
// last query and the uptime of this console server.
func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
	if !requireMethod(w, r, http.MethodGet) {
		return
	}

	stats := s.dnsServer.GetStats()

	writeJSON(w, map[string]interface{}{
		"status":    "running",
		"queries":   stats.Queries,
		"lastQuery": stats.LastQuery,
		// Measured from the moment NewServer ran, not from the last query.
		"uptime":    time.Since(s.startTime),
		"timestamp": time.Now(),
	})
}

// handleConfig exposes the shield part of the configuration: GET returns the
// block method and custom block IP, POST updates them.
//
// A POST is validated in full before anything is written, so a rejected
// request leaves the configuration untouched.
//
// NOTE(review): the update writes to the shared global configuration without
// synchronisation; confirm how the readers of these fields are protected.
func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		// Return only the part of the configuration the console needs.
		writeJSON(w, map[string]interface{}{
			"shield": map[string]string{
				"blockMethod":   s.globalConfig.Shield.BlockMethod,
				"customBlockIP": s.globalConfig.Shield.CustomBlockIP,
			},
		})

	case http.MethodPost:
		var req struct {
			Shield struct {
				BlockMethod   string `json:"blockMethod"`
				CustomBlockIP string `json:"customBlockIP"`
			} `json:"shield"`
		}
		if err := decodeJSONBody(w, r, &req); err != nil {
			http.Error(w, "无效的请求体", http.StatusBadRequest)
			return
		}

		method := req.Shield.BlockMethod
		customIP := req.Shield.CustomBlockIP

		// Validate the block method when one is supplied.
		switch method {
		case "":
			// Not supplied: the current method is kept.
		case "NXDOMAIN", "refused", "emptyIP":
		case "customIP":
			// Selecting customIP requires the address in the same request.
			if customIP == "" {
				http.Error(w, "自定义IP不能为空", http.StatusBadRequest)
				return
			}
		default:
			http.Error(w, "无效的屏蔽方法", http.StatusBadRequest)
			return
		}

		// A supplied address must be well formed whichever method is active.
		if customIP != "" && !isValidIP(customIP) {
			http.Error(w, "无效的IP地址", http.StatusBadRequest)
			return
		}

		// Everything checked out: apply the update.
		if method != "" {
			s.globalConfig.Shield.BlockMethod = method
		}
		if customIP != "" {
			s.globalConfig.Shield.CustomBlockIP = customIP
		}

		writeJSON(w, map[string]interface{}{
			"success": true,
			"message": "配置已更新",
		})

	default:
		methodNotAllowed(w, http.MethodGet, http.MethodPost)
	}
}

// isValidIP reports whether ip is a dotted-quad IPv4 address: exactly four
// non-empty groups of ASCII digits, each with a value from 0 to 255. Leading
// zeros are accepted.
func isValidIP(ip string) bool {
	parts := strings.Split(ip, ".")
	if len(parts) != 4 {
		return false
	}

	for _, part := range parts {
		if part == "" {
			return false
		}
		num := 0
		for _, c := range part {
			if c < '0' || c > '9' {
				return false
			}
			num = num*10 + int(c-'0')
			// Bail out as soon as the range is exceeded, which also keeps num
			// from overflowing on very long digit runs.
			if num > 255 {
				return false
			}
		}
	}
	return true
}