Files
dns-server/http/server.go
2025-11-23 18:37:24 +08:00

468 lines
12 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package http
import (
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"strings"
	"time"

	"dns-server/config"
	"dns-server/dns"
	"dns-server/logger"
	"dns-server/shield"
)
// Server is the HTTP console server: it serves the JSON management API
// (registered in Start when EnableAPI is set) and static files from ./static.
type Server struct {
	// globalConfig is the full application config; handleConfig reads and
	// updates its Shield section.
	globalConfig *config.Config
	// config points at globalConfig.HTTP (see NewServer).
	config *config.HTTPConfig
	// dnsServer supplies the query statistics reported by the API handlers.
	dnsServer *dns.Server
	// shieldManager is used by the API handlers for block rules and hosts entries.
	shieldManager *shield.ShieldManager
	// server is the underlying net/http server; nil until Start runs.
	server *http.Server
}
// NewServer returns a console Server wired to the given application config,
// DNS server and shield manager. No listener is created until Start is called.
func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager *shield.ShieldManager) *Server {
	srv := new(Server)
	srv.globalConfig = globalConfig
	srv.config = &globalConfig.HTTP
	srv.dnsServer = dnsServer
	srv.shieldManager = shieldManager
	return srv
}
// Start builds the console's route table and serves HTTP on the configured
// host and port. It blocks until the server stops and returns the error from
// http.Server.ListenAndServe (http.ErrServerClosed after Stop).
func (s *Server) Start() error {
	mux := http.NewServeMux()

	// JSON API routes, only registered when the API is enabled in config.
	if s.config.EnableAPI {
		mux.HandleFunc("/api/stats", s.handleStats)
		mux.HandleFunc("/api/shield", s.handleShield)
		mux.HandleFunc("/api/shield/hosts", s.handleShieldHosts)
		mux.HandleFunc("/api/query", s.handleQuery)
		mux.HandleFunc("/api/status", s.handleStatus)
		mux.HandleFunc("/api/config", s.handleConfig)
		// Statistics endpoints used by the dashboard.
		mux.HandleFunc("/api/top-blocked", s.handleTopBlockedDomains)
		mux.HandleFunc("/api/top-resolved", s.handleTopResolvedDomains)
		mux.HandleFunc("/api/recent-blocked", s.handleRecentBlockedDomains)
		mux.HandleFunc("/api/hourly-stats", s.handleHourlyStats)
	}

	// Every other path is served from ./static (front-end assets).
	mux.Handle("/", http.FileServer(http.Dir("./static")))

	// net.JoinHostPort brackets IPv6 hosts ("::1" -> "[::1]:8080"); plain
	// "%s:%d" formatting yields an address ListenAndServe cannot parse.
	// An empty host still produces ":port" (listen on all interfaces).
	addr := net.JoinHostPort(s.config.Host, fmt.Sprint(s.config.Port))

	s.server = &http.Server{
		Addr:         addr,
		Handler:      mux,
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}

	logger.Info(fmt.Sprintf("HTTP控制台服务器启动监听地址: %s", addr))
	return s.server.ListenAndServe()
}
// Stop shuts the console server down via http.Server.Close (an immediate
// close, not a graceful drain) and logs that it has stopped. Calling it
// before Start is harmless: with no server yet, only the log line is emitted.
func (s *Server) Stop() {
	if srv := s.server; srv != nil {
		// Stop has no error return, so Close's error is intentionally dropped.
		_ = srv.Close()
	}
	logger.Info("HTTP控制台服务器已停止")
}
// handleStats serves GET /api/stats: a JSON object combining the DNS server
// statistics ("dns"), the shield statistics ("shield") and the current time.
func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	payload := make(map[string]interface{}, 3)
	payload["dns"] = s.dnsServer.GetStats()
	payload["shield"] = s.shieldManager.GetStats()
	payload["time"] = time.Now()

	w.Header().Set("Content-Type", "application/json")
	// Once the encoder writes, the 200 status is committed, so an encoding
	// failure cannot be reported as an HTTP error anymore.
	_ = json.NewEncoder(w).Encode(payload)
}
// handleTopBlockedDomains serves GET /api/top-blocked: the first 10 entries
// returned by dnsServer.GetTopBlockedDomains, as a JSON array of
// {"domain", "count"} objects.
func (s *Server) handleTopBlockedDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	top := s.dnsServer.GetTopBlockedDomains(10)

	// Re-shape into the objects the front end expects. The slice is non-nil,
	// so an empty result encodes as [] rather than null.
	entries := make([]map[string]interface{}, 0, len(top))
	for _, d := range top {
		entries = append(entries, map[string]interface{}{
			"domain": d.Domain,
			"count":  d.Count,
		})
	}

	w.Header().Set("Content-Type", "application/json")
	// Response already committed once encoding starts; nothing to report.
	_ = json.NewEncoder(w).Encode(entries)
}
// handleTopResolvedDomains serves GET /api/top-resolved: the first 10 entries
// returned by dnsServer.GetTopResolvedDomains, as a JSON array of
// {"domain", "count"} objects.
func (s *Server) handleTopResolvedDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	top := s.dnsServer.GetTopResolvedDomains(10)

	// Re-shape into the objects the front end expects. The slice is non-nil,
	// so an empty result encodes as [] rather than null.
	entries := make([]map[string]interface{}, 0, len(top))
	for _, d := range top {
		entries = append(entries, map[string]interface{}{
			"domain": d.Domain,
			"count":  d.Count,
		})
	}

	w.Header().Set("Content-Type", "application/json")
	// Response already committed once encoding starts; nothing to report.
	_ = json.NewEncoder(w).Encode(entries)
}
// handleRecentBlockedDomains serves GET /api/recent-blocked: the first 10
// entries returned by dnsServer.GetRecentBlockedDomains, as a JSON array of
// {"domain", "time"} objects where "time" is LastSeen formatted as HH:MM:SS
// (no date component).
func (s *Server) handleRecentBlockedDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	recent := s.dnsServer.GetRecentBlockedDomains(10)

	// Re-shape into the objects the front end expects. The slice is non-nil,
	// so an empty result encodes as [] rather than null.
	entries := make([]map[string]interface{}, 0, len(recent))
	for _, d := range recent {
		entries = append(entries, map[string]interface{}{
			"domain": d.Domain,
			"time":   d.LastSeen.Format("15:04:05"),
		})
	}

	w.Header().Set("Content-Type", "application/json")
	// Response already committed once encoding starts; nothing to report.
	_ = json.NewEncoder(w).Encode(entries)
}
// handleHourlyStats serves GET /api/hourly-stats: per-hour counts for the
// last 24 hours (oldest first, ending with the current hour) as
// {"labels": ["HH:00", ...], "data": [count, ...]}. Hours missing from
// dnsServer.GetHourlyStats are reported as 0.
func (s *Server) handleHourlyStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	const hours = 24
	counts := s.dnsServer.GetHourlyStats()

	now := time.Now()
	labels := make([]string, hours)
	data := make([]int64, hours)
	// idx 0 is 23 hours ago; idx hours-1 is the current hour. Buckets are
	// looked up by the "2006-01-02-15" key format.
	for idx := 0; idx < hours; idx++ {
		slot := now.Add(-time.Duration(hours-1-idx) * time.Hour)
		labels[idx] = slot.Format("15:00")
		data[idx] = counts[slot.Format("2006-01-02-15")]
	}

	w.Header().Set("Content-Type", "application/json")
	// Response already committed once encoding starts; nothing to report.
	_ = json.NewEncoder(w).Encode(map[string]interface{}{
		"labels": labels,
		"data":   data,
	})
}
// handleShield serves /api/shield, the block-rule management endpoint:
//
//	GET    returns the rule set from shieldManager.GetRules
//	POST   {"rule": "..."} adds a rule
//	DELETE {"rule": "..."} removes a rule
//	PUT    reloads the rules via shieldManager.LoadRules
//
// POST and DELETE reject a missing or blank rule with 400. Shield-manager
// failures are returned as 500 with the error text.
func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	// Delegate the hosts sub-route. With the routes registered in Start
	// ("/api/shield" is an exact-match pattern and "/api/shield/hosts" has
	// its own handler) this is a safety net for other mount points.
	if strings.HasPrefix(r.URL.Path, "/api/shield/hosts") {
		s.handleShieldHosts(w, r)
		return
	}

	switch r.Method {
	case http.MethodGet:
		// Return the complete rule list.
		rules := s.shieldManager.GetRules()
		json.NewEncoder(w).Encode(rules)
	case http.MethodPost:
		// Add a block rule.
		var req struct {
			Rule string `json:"rule"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		// Blank rules are a client error (mirrors the required-field check in
		// handleShieldHosts) rather than something to hand to the manager.
		if strings.TrimSpace(req.Rule) == "" {
			http.Error(w, "Rule is required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.AddRule(req.Rule); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})
	case http.MethodDelete:
		// Remove a block rule.
		var req struct {
			Rule string `json:"rule"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if strings.TrimSpace(req.Rule) == "" {
			http.Error(w, "Rule is required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.RemoveRule(req.Rule); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})
	case http.MethodPut:
		// Reload rules from their source.
		if err := s.shieldManager.LoadRules(); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "规则重新加载成功"})
	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// handleShieldHosts serves /api/shield/hosts, the hosts-entry endpoint:
//
//	POST   {"ip": "...", "domain": "..."} adds an entry (both fields required)
//	DELETE {"domain": "..."} removes an entry (domain required)
//	GET    currently returns only the hosts rule count plus a
//	       "not implemented yet" message
//
// Missing required fields yield 400. Shield-manager failures are returned
// as 500 with the error text.
func (s *Server) handleShieldHosts(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	switch r.Method {
	case http.MethodPost:
		// Add a hosts entry.
		var req struct {
			IP     string `json:"ip"`
			Domain string `json:"domain"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if req.IP == "" || req.Domain == "" {
			http.Error(w, "IP and Domain are required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.AddHostsEntry(req.IP, req.Domain); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})
	case http.MethodDelete:
		// Remove a hosts entry.
		var req struct {
			Domain string `json:"domain"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		// Same required-field rule as POST: an empty domain is a client error.
		if req.Domain == "" {
			http.Error(w, "Domain is required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.RemoveHostsEntry(req.Domain); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})
	case http.MethodGet:
		// Listing individual hosts entries needs a new shieldManager method;
		// until then only the count from GetStats is returned.
		stats := s.shieldManager.GetStats()
		json.NewEncoder(w).Encode(map[string]interface{}{
			"hostsCount": stats["hostsRules"],
			"message":    "获取hosts列表功能待实现",
		})
	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// handleQuery serves GET /api/query?domain=<name>. It reports whether the
// shield would block the domain and whether a hosts entry maps it to an IP.
// No upstream DNS lookup is performed here.
func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	domain := r.URL.Query().Get("domain")
	if domain == "" {
		http.Error(w, "Domain parameter is required", http.StatusBadRequest)
		return
	}

	isBlocked := s.shieldManager.IsBlocked(domain)
	ip, found := s.shieldManager.GetHostsIP(domain)

	w.Header().Set("Content-Type", "application/json")
	// Response already committed once encoding starts; nothing to report.
	_ = json.NewEncoder(w).Encode(map[string]interface{}{
		"domain":    domain,
		"blocked":   isBlocked,
		"hasHosts":  found,
		"hostsIP":   ip,
		"timestamp": time.Now(),
	})
}
// statusStartTime is captured at package initialisation and is the reference
// point for the "uptime" field reported by handleStatus.
var statusStartTime = time.Now()

// handleStatus serves GET /api/status with a small liveness summary: the
// total query count, the time of the last query, uptime and the current time.
//
// Uptime is measured from statusStartTime. Deriving it from stats.LastQuery
// (as this handler previously did) reports the time since the most recent
// query instead. The value is a time.Duration, which encoding/json emits as
// an integer nanosecond count.
func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	stats := s.dnsServer.GetStats()
	status := map[string]interface{}{
		"status":    "running",
		"queries":   stats.Queries,
		"lastQuery": stats.LastQuery,
		"uptime":    time.Since(statusStartTime),
		"timestamp": time.Now(),
	}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(status)
}
// handleConfig serves /api/config, exposing the shield block settings.
//
// GET returns {"shield": {"blockMethod": ..., "customBlockIP": ...}}.
// POST accepts the same shape. An empty field leaves the current value
// unchanged. blockMethod must be one of NXDOMAIN, refused, emptyIP or
// customIP. customIP requires a customBlockIP in the same request, and any
// supplied customBlockIP must be a valid IPv4 address (isValidIP).
//
// Every check runs before anything is written, so a rejected request never
// leaves s.globalConfig half-updated.
func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	switch r.Method {
	case http.MethodGet:
		// Return only the settings the console is allowed to see.
		resp := map[string]interface{}{
			"shield": map[string]string{
				"blockMethod":   s.globalConfig.Shield.BlockMethod,
				"customBlockIP": s.globalConfig.Shield.CustomBlockIP,
			},
		}
		json.NewEncoder(w).Encode(resp)
	case http.MethodPost:
		var req struct {
			Shield struct {
				BlockMethod   string `json:"blockMethod"`
				CustomBlockIP string `json:"customBlockIP"`
			} `json:"shield"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "无效的请求体", http.StatusBadRequest)
			return
		}
		method := req.Shield.BlockMethod
		customIP := req.Shield.CustomBlockIP

		// Validate everything first.
		if method != "" {
			validMethods := map[string]bool{
				"NXDOMAIN": true,
				"refused":  true,
				"emptyIP":  true,
				"customIP": true,
			}
			if !validMethods[method] {
				http.Error(w, "无效的屏蔽方法", http.StatusBadRequest)
				return
			}
			if method == "customIP" && customIP == "" {
				http.Error(w, "自定义IP不能为空", http.StatusBadRequest)
				return
			}
		}
		// Check the IP whenever it is supplied, not only alongside
		// blockMethod=customIP, so an invalid address is never stored.
		if customIP != "" && !isValidIP(customIP) {
			http.Error(w, "无效的IP地址", http.StatusBadRequest)
			return
		}

		// All checks passed: apply the changes.
		if method != "" {
			s.globalConfig.Shield.BlockMethod = method
		}
		if customIP != "" {
			s.globalConfig.Shield.CustomBlockIP = customIP
		}

		json.NewEncoder(w).Encode(map[string]interface{}{
			"success": true,
			"message": "配置已更新",
		})
	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// isValidIP reports whether ip is a dotted-decimal IPv4 address such as
// "192.168.1.1".
//
// Parsing is delegated to net.ParseIP. It range-checks each octet and, since
// Go 1.17, rejects octets with leading zeros. Such octets are ambiguous:
// inet_aton-style parsers read "010" as octal 8. net.ParseIP also accepts
// IPv6 literals, including IPv4-mapped forms like "::ffff:1.2.3.4". Rejecting
// any ':' up front keeps this check IPv4-only.
func isValidIP(ip string) bool {
	if strings.Contains(ip, ":") {
		return false
	}
	return net.ParseIP(ip) != nil
}