Files
dns-server/http/server.go
2025-11-26 01:11:37 +08:00

994 lines
27 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package http
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"sort"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"dns-server/config"
"dns-server/dns"
"dns-server/logger"
"dns-server/shield"
)
// Server is the HTTP console: it exposes the JSON management API, the
// /ws/stats WebSocket feed and the static front-end files.
type Server struct {
	// Configuration and the services the API exposes.
	globalConfig  *config.Config     // whole application config; read and modified by handlers
	config        *config.HTTPConfig // the HTTP section of globalConfig
	dnsServer     *dns.Server
	shieldManager *shield.ShieldManager

	// server is the underlying listener; it stays nil until Start runs.
	server *http.Server

	// WebSocket state.
	upgrader      websocket.Upgrader
	clientsMutex  sync.Mutex               // guards clients
	clients       map[*websocket.Conn]bool // currently registered WebSocket connections
	broadcastChan chan []byte              // messages fanned out to every client by startBroadcastLoop
}
// NewServer builds the HTTP console around the given configuration, DNS
// server and shield manager, and starts the goroutine that drains
// broadcastChan (capacity 100).
func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager *shield.ShieldManager) *Server {
	// Accept WebSocket upgrades from any Origin (all cross-origin requests allowed).
	// NOTE(review): this lets any web page open /ws/stats from a visitor's
	// browser; consider restricting the allowed origins.
	allowAnyOrigin := func(*http.Request) bool { return true }

	s := &Server{
		globalConfig:  globalConfig,
		config:        &globalConfig.HTTP,
		dnsServer:     dnsServer,
		shieldManager: shieldManager,
		clients:       make(map[*websocket.Conn]bool),
		broadcastChan: make(chan []byte, 100),
	}
	s.upgrader = websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		CheckOrigin:     allowAnyOrigin,
	}

	go s.startBroadcastLoop()
	return s
}
// Start registers the routes and serves HTTP on config.Host:config.Port.
// It blocks until the server stops.
//
// API and WebSocket routes are only registered when config.EnableAPI is set.
// Everything else is served from ./static. It returns nil when the server was
// shut down deliberately via Stop, and the listener error otherwise.
func (s *Server) Start() error {
	mux := http.NewServeMux()

	if s.config.EnableAPI {
		mux.HandleFunc("/api/stats", s.handleStats)
		mux.HandleFunc("/api/shield", s.handleShield)
		mux.HandleFunc("/api/shield/hosts", s.handleShieldHosts)
		mux.HandleFunc("/api/shield/blacklists", s.handleShieldBlacklists)
		// handleShieldBlacklists also handles per-list sub-paths
		// (/api/shield/blacklists/{id} and /api/shield/blacklists/{id}/update).
		// An exact pattern never matches those, so register the subtree too;
		// otherwise they fall through to the static file server and 404.
		mux.HandleFunc("/api/shield/blacklists/", s.handleShieldBlacklists)
		mux.HandleFunc("/api/query", s.handleQuery)
		mux.HandleFunc("/api/status", s.handleStatus)
		mux.HandleFunc("/api/config", s.handleConfig)

		// Statistics endpoints.
		mux.HandleFunc("/api/top-blocked", s.handleTopBlockedDomains)
		mux.HandleFunc("/api/top-resolved", s.handleTopResolvedDomains)
		mux.HandleFunc("/api/recent-blocked", s.handleRecentBlockedDomains)
		mux.HandleFunc("/api/hourly-stats", s.handleHourlyStats)
		mux.HandleFunc("/api/daily-stats", s.handleDailyStats)
		mux.HandleFunc("/api/monthly-stats", s.handleMonthlyStats)
		mux.HandleFunc("/api/query/type", s.handleQueryTypeStats)

		// Real-time statistics over WebSocket.
		mux.HandleFunc("/ws/stats", s.handleWebSocketStats)
	}

	// Static files (front-end UI).
	mux.Handle("/", http.FileServer(http.Dir("./static")))

	s.server = &http.Server{
		Addr:         fmt.Sprintf("%s:%d", s.config.Host, s.config.Port),
		Handler:      mux,
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}

	logger.Info(fmt.Sprintf("HTTP控制台服务器启动监听地址: %s:%d", s.config.Host, s.config.Port))

	// ListenAndServe always returns a non-nil error. ErrServerClosed only means
	// Stop closed the server on purpose, which is not a failure.
	if err := s.server.ListenAndServe(); err != http.ErrServerClosed {
		return err
	}
	return nil
}
// Stop shuts the HTTP console down immediately.
//
// http.Server.Close does not touch hijacked connections such as WebSockets,
// so the registered WebSocket clients are closed and unregistered here
// explicitly. Otherwise their handler goroutines could keep running after
// Stop. A failure to close the listener is logged rather than dropped.
func (s *Server) Stop() {
	if s.server != nil {
		if err := s.server.Close(); err != nil {
			logger.Error(fmt.Sprintf("关闭HTTP服务器失败: %v", err))
		}
	}

	s.clientsMutex.Lock()
	for conn := range s.clients {
		conn.Close()
		delete(s.clients, conn)
	}
	s.clientsMutex.Unlock()

	logger.Info("HTTP控制台服务器已停止")
}
// handleStats serves GET /api/stats.
//
// The response is the same snapshot that buildStatsData produces for the
// WebSocket feed, plus a "time" field with the current server time. It reuses
// buildStatsData instead of duplicating it, so the REST and WebSocket views
// cannot drift apart.
func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	stats := s.buildStatsData()
	stats["time"] = time.Now()

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(stats)
}
// WebSocket-related methods.

// handleWebSocketStats upgrades the request to a WebSocket and pushes stats to
// the client.
//
// It first sends one "initial_data" message. After that it polls every 500ms
// and sends a "stats_update" message whenever areStatsEqual reports a change.
// The connection is registered in s.clients while the handler runs and is
// always unregistered on exit, whatever the exit reason.
//
// NOTE(review): startBroadcastLoop may also write to this connection from
// another goroutine; gorilla/websocket allows only one concurrent writer, so
// confirm nothing sends on broadcastChan while clients are being served here.
func (s *Server) handleWebSocketStats(w http.ResponseWriter, r *http.Request) {
	conn, err := s.upgrader.Upgrade(w, r, nil)
	if err != nil {
		logger.Error(fmt.Sprintf("WebSocket升级失败: %v", err))
		return
	}
	defer conn.Close()

	s.clientsMutex.Lock()
	s.clients[conn] = true
	clientCount := len(s.clients)
	s.clientsMutex.Unlock()
	logger.Info(fmt.Sprintf("新WebSocket客户端连接当前连接数: %d", clientCount))

	// Unregister on every exit path (initial send failure, write failure,
	// peer close, request cancellation). Previously only the context branch
	// removed the entry, leaking it otherwise. Deleting a missing key is a
	// no-op, so this is safe even if the broadcast loop already removed it.
	defer func() {
		s.clientsMutex.Lock()
		delete(s.clients, conn)
		remaining := len(s.clients)
		s.clientsMutex.Unlock()
		logger.Info(fmt.Sprintf("WebSocket客户端断开连接当前连接数: %d", remaining))
	}()

	if err := s.sendInitialStats(conn); err != nil {
		logger.Error(fmt.Sprintf("发送初始数据失败: %v", err))
		return
	}

	// gorilla/websocket processes ping/close control frames only while the
	// application reads. The request context is also not a reliable disconnect
	// signal once the connection is hijacked. Drain incoming frames in the
	// background so a disconnect is noticed even when no update is written.
	// This endpoint is push-only, so a small read limit bounds memory per client.
	conn.SetReadLimit(512)
	peerGone := make(chan struct{})
	go func() {
		defer close(peerGone)
		for {
			if _, _, err := conn.ReadMessage(); err != nil {
				return
			}
		}
	}()

	ticker := time.NewTicker(500 * time.Millisecond)
	defer ticker.Stop()

	// Last snapshot sent, used to skip unchanged updates.
	var lastStats map[string]interface{}

	for {
		select {
		case <-ticker.C:
			currentStats := s.buildStatsData()
			if s.areStatsEqual(lastStats, currentStats) {
				continue
			}

			data, err := json.Marshal(map[string]interface{}{
				"type": "stats_update",
				"data": currentStats,
				"time": time.Now(),
			})
			if err != nil {
				logger.Error(fmt.Sprintf("序列化统计数据失败: %v", err))
				continue
			}
			if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
				logger.Error(fmt.Sprintf("发送WebSocket消息失败: %v", err))
				return
			}
			lastStats = currentStats

		case <-peerGone:
			// Client closed the socket, or the connection was closed (e.g. by Stop).
			return

		case <-r.Context().Done():
			return
		}
	}
}
// sendInitialStats writes a single "initial_data" text message containing the
// current buildStatsData snapshot and a timestamp to conn. It returns the
// marshalling or write error, if any.
func (s *Server) sendInitialStats(conn *websocket.Conn) error {
	// Composite-literal calls run left to right: the snapshot is taken before
	// the timestamp, as before.
	payload, err := json.Marshal(map[string]interface{}{
		"type": "initial_data",
		"data": s.buildStatsData(),
		"time": time.Now(),
	})
	if err != nil {
		return err
	}
	return conn.WriteMessage(websocket.TextMessage, payload)
}
// buildStatsData assembles the statistics snapshot shared by /api/stats and
// the /ws/stats feed.
//
// Alongside the raw DNS counters (under "dns") and the shield stats (under
// "shield") it derives the following fields:
//   - topQueryType: the query type with the highest count, or "-" when no
//     type has a positive count. Ties go to the lexically smallest name so the
//     result does not depend on map iteration order.
//   - activeIPs: len(SourceIPs).
//   - avgResponseTime: AvgResponseTime rounded to two decimal places.
func (s *Server) buildStatsData() map[string]interface{} {
	dnsStats := s.dnsServer.GetStats()
	shieldStats := s.shieldManager.GetStats()

	topQueryType := "-"
	var maxCount int64
	for queryType, count := range dnsStats.QueryTypes {
		if count > maxCount || (count == maxCount && count > 0 && queryType < topQueryType) {
			maxCount = count
			topQueryType = queryType
		}
	}

	activeIPCount := len(dnsStats.SourceIPs)

	// Round half up to two decimals. The previous int() conversion truncated,
	// e.g. 1.239 became 1.23. Assumes AvgResponseTime is non-negative.
	formattedResponseTime := float64(int64(dnsStats.AvgResponseTime*100+0.5)) / 100

	return map[string]interface{}{
		"dns": map[string]interface{}{
			"Queries":           dnsStats.Queries,
			"Blocked":           dnsStats.Blocked,
			"Allowed":           dnsStats.Allowed,
			"Errors":            dnsStats.Errors,
			"LastQuery":         dnsStats.LastQuery,
			"AvgResponseTime":   formattedResponseTime,
			"TotalResponseTime": dnsStats.TotalResponseTime,
			"QueryTypes":        dnsStats.QueryTypes,
			"SourceIPs":         dnsStats.SourceIPs,
			"CpuUsage":          dnsStats.CpuUsage,
		},
		"shield":          shieldStats,
		"topQueryType":    topQueryType,
		"activeIPs":       activeIPCount,
		"avgResponseTime": formattedResponseTime,
		"cpuUsage":        dnsStats.CpuUsage,
	}
}
// areStatsEqual reports whether two buildStatsData snapshots agree on the key
// DNS counters (Queries, Blocked, Allowed, Errors). Other fields are ignored on
// purpose, to avoid pushing an update for every minor fluctuation.
//
// A nil snapshot, or one without a well-formed "dns" map, is never considered
// equal. Previously a malformed snapshot compared as "equal", which would
// suppress updates indefinitely.
func (s *Server) areStatsEqual(stats1, stats2 map[string]interface{}) bool {
	if stats1 == nil || stats2 == nil {
		return false
	}

	dns1, ok1 := stats1["dns"].(map[string]interface{})
	dns2, ok2 := stats2["dns"].(map[string]interface{})
	if !ok1 || !ok2 {
		return false
	}

	// The counters are compared as interface values. This assumes they hold
	// comparable (numeric) types, which holds for the values buildStatsData
	// stores.
	for _, key := range []string{"Queries", "Blocked", "Allowed", "Errors"} {
		if dns1[key] != dns2[key] {
			return false
		}
	}
	return true
}
// startBroadcastLoop forwards every message received on broadcastChan to all
// registered WebSocket clients, holding clientsMutex for the whole fan-out.
//
// A client whose write fails is logged, closed and unregistered. The loop
// returns only once broadcastChan is closed.
func (s *Server) startBroadcastLoop() {
	for {
		msg, open := <-s.broadcastChan
		if !open {
			return
		}

		s.clientsMutex.Lock()
		for conn := range s.clients {
			err := conn.WriteMessage(websocket.TextMessage, msg)
			if err == nil {
				continue
			}
			logger.Error(fmt.Sprintf("广播消息失败: %v", err))
			conn.Close()
			// Deleting during range is permitted in Go.
			delete(s.clients, conn)
		}
		s.clientsMutex.Unlock()
	}
}
// handleTopBlockedDomains serves GET /api/top-blocked.
//
// It returns the list from GetTopBlockedDomains(10) as a JSON array of
// {"domain", "count"} objects. An empty result encodes as [].
func (s *Server) handleTopBlockedDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	top := s.dnsServer.GetTopBlockedDomains(10)

	payload := make([]map[string]interface{}, 0, len(top))
	for _, entry := range top {
		payload = append(payload, map[string]interface{}{
			"domain": entry.Domain,
			"count":  entry.Count,
		})
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(payload)
}
// handleTopResolvedDomains serves GET /api/top-resolved.
//
// It returns the list from GetTopResolvedDomains(10) as a JSON array of
// {"domain", "count"} objects.
func (s *Server) handleTopResolvedDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		resolved := s.dnsServer.GetTopResolvedDomains(10)

		rows := make([]map[string]interface{}, 0, len(resolved))
		for _, item := range resolved {
			row := map[string]interface{}{"domain": item.Domain, "count": item.Count}
			rows = append(rows, row)
		}

		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(rows)
		return
	}

	http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
// handleRecentBlockedDomains 处理最近屏蔽域名请求
func (s *Server) handleRecentBlockedDomains(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
domains := s.dnsServer.GetRecentBlockedDomains(10)
// 转换为前端需要的格式
result := make([]map[string]interface{}, len(domains))
for i, domain := range domains {
result[i] = map[string]interface{}{
"domain": domain.Domain,
"time": domain.LastSeen.Format("15:04:05"),
}
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(result)
}
// handleHourlyStats 处理24小时统计请求
func (s *Server) handleHourlyStats(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
hourlyStats := s.dnsServer.GetHourlyStats()
// 生成最近24小时的数据
now := time.Now()
labels := make([]string, 24)
data := make([]int64, 24)
for i := 23; i >= 0; i-- {
hour := now.Add(time.Duration(-i) * time.Hour)
hourKey := hour.Format("2006-01-02-15")
labels[23-i] = hour.Format("15:00")
data[23-i] = hourlyStats[hourKey]
}
result := map[string]interface{}{
"labels": labels,
"data": data,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(result)
}
// handleDailyStats 处理每日统计数据请求
func (s *Server) handleDailyStats(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 获取每日统计数据
dailyStats := s.dnsServer.GetDailyStats()
// 生成过去7天的时间标签
labels := make([]string, 7)
data := make([]int64, 7)
now := time.Now()
for i := 6; i >= 0; i-- {
t := now.AddDate(0, 0, -i)
key := t.Format("2006-01-02")
labels[6-i] = t.Format("01-02")
data[6-i] = dailyStats[key]
}
result := map[string]interface{}{
"labels": labels,
"data": data,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(result)
}
// handleMonthlyStats 处理每月统计数据请求
func (s *Server) handleMonthlyStats(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 获取每日统计数据用于30天视图
dailyStats := s.dnsServer.GetDailyStats()
// 生成过去30天的时间标签
labels := make([]string, 30)
data := make([]int64, 30)
now := time.Now()
for i := 29; i >= 0; i-- {
t := now.AddDate(0, 0, -i)
key := t.Format("2006-01-02")
labels[29-i] = t.Format("01-02")
data[29-i] = dailyStats[key]
}
result := map[string]interface{}{
"labels": labels,
"data": data,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(result)
}
// handleQueryTypeStats 处理查询类型统计请求
func (s *Server) handleQueryTypeStats(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 获取DNS统计数据
dnsStats := s.dnsServer.GetStats()
// 转换为前端需要的格式
result := make([]map[string]interface{}, 0, len(dnsStats.QueryTypes))
for queryType, count := range dnsStats.QueryTypes {
result = append(result, map[string]interface{}{
"type": queryType,
"count": count,
})
}
// 按计数降序排序
sort.Slice(result, func(i, j int) bool {
return result[i]["count"].(int64) > result[j]["count"].(int64)
})
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(result)
}
// handleShield manages blocking rules via /api/shield.
//
//	GET    rule configuration and rule counts (not the full rule list)
//	POST   add the rule in {"rule": "..."}
//	DELETE remove the rule in {"rule": "..."}
//	PUT    reload all rules
//
// POST and DELETE reject a missing or blank rule with 400 instead of passing
// it on to the shield manager. Shield-manager failures are reported as 500
// with the error text.
func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	switch r.Method {
	case http.MethodGet:
		stats := s.shieldManager.GetStats()
		shieldInfo := map[string]interface{}{
			"updateInterval":        s.globalConfig.Shield.UpdateInterval,
			"blockMethod":           s.globalConfig.Shield.BlockMethod,
			"blacklistCount":        len(s.globalConfig.Shield.Blacklists),
			"domainRulesCount":      stats["domainRules"],
			"domainExceptionsCount": stats["domainExceptions"],
			"regexRulesCount":       stats["regexRules"],
			"regexExceptionsCount":  stats["regexExceptions"],
			"hostsRulesCount":       stats["hostsRules"],
		}
		json.NewEncoder(w).Encode(shieldInfo)

	case http.MethodPost:
		var req struct {
			Rule string `json:"rule"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if strings.TrimSpace(req.Rule) == "" {
			http.Error(w, "Rule is required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.AddRule(req.Rule); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	case http.MethodDelete:
		var req struct {
			Rule string `json:"rule"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if strings.TrimSpace(req.Rule) == "" {
			http.Error(w, "Rule is required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.RemoveRule(req.Rule); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	case http.MethodPut:
		if err := s.shieldManager.LoadRules(); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "规则重新加载成功"})

	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// handleShieldBlacklists manages remote blacklist subscriptions.
//
//	POST   /api/shield/blacklists/{id}/update  refresh one list ({id} is its URL or Name)
//	DELETE /api/shield/blacklists/{id}         remove every list whose URL or Name is {id}
//	GET    /api/shield/blacklists              list all entries
//	POST   /api/shield/blacklists              add {"name", "url"} (both required, URL unique)
//	PUT    /api/shield/blacklists              refresh every list
//
// Every mutation is followed by a rule reload so it takes effect at once.
// A reload failure is logged rather than returned, because the blacklist
// change has already been handed to the shield manager.
func (s *Server) handleShieldBlacklists(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	// reloadRules re-reads all rules; failures were previously dropped silently.
	reloadRules := func() {
		if err := s.shieldManager.LoadRules(); err != nil {
			logger.Error(fmt.Sprintf("重新加载规则失败: %v", err))
		}
	}

	parts := strings.Split(r.URL.Path, "/")

	// Refresh a single blacklist.
	if strings.Contains(r.URL.Path, "/update") && r.Method == http.MethodPost {
		// The identifier is the path segment right after "blacklists".
		var targetURLOrName string
		for i, part := range parts {
			if part == "blacklists" && i+1 < len(parts) && parts[i+1] != "update" {
				targetURLOrName = parts[i+1]
				break
			}
		}
		if targetURLOrName == "" {
			http.Error(w, "黑名单标识不能为空", http.StatusBadRequest)
			return
		}

		blacklists := s.shieldManager.GetBlacklists()
		targetIndex := -1
		for i, list := range blacklists {
			if list.URL == targetURLOrName || list.Name == targetURLOrName {
				targetIndex = i
				break
			}
		}
		if targetIndex == -1 {
			http.Error(w, "黑名单不存在", http.StatusNotFound)
			return
		}

		blacklists[targetIndex].LastUpdateTime = time.Now().Format(time.RFC3339)
		s.shieldManager.UpdateBlacklist(blacklists)
		reloadRules()
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})
		return
	}

	// Delete a blacklist: parts = ["", "api", "shield", "blacklists", "{id}", ...].
	if len(parts) > 4 && parts[3] == "blacklists" && parts[4] != "" && r.Method == http.MethodDelete {
		id := parts[4]
		blacklists := s.shieldManager.GetBlacklists()

		// Non-nil, so removing the last list yields an empty slice rather than nil.
		kept := make([]config.BlacklistEntry, 0, len(blacklists))
		for _, list := range blacklists {
			if list.URL != id && list.Name != id {
				kept = append(kept, list)
			}
		}
		s.shieldManager.UpdateBlacklist(kept)
		// Reload so the removed list's rules stop applying, as add/refresh do.
		reloadRules()
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})
		return
	}

	switch r.Method {
	case http.MethodGet:
		json.NewEncoder(w).Encode(s.shieldManager.GetBlacklists())

	case http.MethodPost:
		var req struct {
			Name string `json:"name"`
			URL  string `json:"url"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if req.Name == "" || req.URL == "" {
			http.Error(w, "Name and URL are required", http.StatusBadRequest)
			return
		}

		blacklists := s.shieldManager.GetBlacklists()
		for _, list := range blacklists {
			if list.URL == req.URL {
				http.Error(w, "黑名单URL已存在", http.StatusConflict)
				return
			}
		}

		blacklists = append(blacklists, config.BlacklistEntry{
			Name:    req.Name,
			URL:     req.URL,
			Enabled: true,
		})
		s.shieldManager.UpdateBlacklist(blacklists)
		reloadRules()
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	case http.MethodPut:
		blacklists := s.shieldManager.GetBlacklists()
		stamp := time.Now().Format(time.RFC3339)
		for i := range blacklists {
			blacklists[i].LastUpdateTime = stamp
		}
		s.shieldManager.UpdateBlacklist(blacklists)
		reloadRules()
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// handleShieldHosts manages custom hosts entries via /api/shield/hosts.
//
//	GET    {"hosts": [{"domain", "ip"}, ...], "hostsCount": n}; entries are
//	       sorted by domain so the listing is stable between calls
//	POST   add {"ip", "domain"} (both required)
//	DELETE remove {"domain"} (required)
//
// Shield-manager failures are reported as 500 with the error text.
func (s *Server) handleShieldHosts(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	switch r.Method {
	case http.MethodPost:
		var req struct {
			IP     string `json:"ip"`
			Domain string `json:"domain"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if req.IP == "" || req.Domain == "" {
			http.Error(w, "IP and Domain are required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.AddHostsEntry(req.IP, req.Domain); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	case http.MethodDelete:
		var req struct {
			Domain string `json:"domain"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		// Match the POST validation instead of passing "" to the shield manager.
		if req.Domain == "" {
			http.Error(w, "Domain is required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.RemoveHostsEntry(req.Domain); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	case http.MethodGet:
		hosts := s.shieldManager.GetAllHosts()
		hostsCount := s.shieldManager.GetHostsCount()

		// Convert to an array for the front end.
		hostsList := make([]map[string]string, 0, len(hosts))
		for domain, ip := range hosts {
			hostsList = append(hostsList, map[string]string{
				"domain": domain,
				"ip":     ip,
			})
		}
		// Map iteration order is random; sort for a deterministic listing.
		sort.Slice(hostsList, func(i, j int) bool {
			return hostsList[i]["domain"] < hostsList[j]["domain"]
		})

		json.NewEncoder(w).Encode(map[string]interface{}{
			"hosts":      hostsList,
			"hostsCount": hostsCount,
		})

	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// handleQuery serves GET /api/query?domain=...
//
// It reports the shield manager's block details for the domain, plus a
// "timestamp" field. The domain parameter is required (400 otherwise).
func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	domain := r.URL.Query().Get("domain")
	if domain == "" {
		http.Error(w, "Domain parameter is required", http.StatusBadRequest)
		return
	}

	details := s.shieldManager.CheckDomainBlockDetails(domain)

	// Build the response in a fresh map instead of writing into the returned
	// one. Assigning into a nil map would panic, and the returned map may be
	// shared with the shield manager.
	resp := make(map[string]interface{}, len(details)+1)
	for k, v := range details {
		resp[k] = v
	}
	resp["timestamp"] = time.Now()

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(resp)
}
// handleStatus serves GET /api/status: a flat summary of the DNS server's
// counters together with its start time and uptime.
//
// uptime is a time.Duration and so JSON-encodes as an integer number of
// nanoseconds. startTime and timestamp encode as RFC 3339 strings.
// avgResponseTime is reported unrounded.
func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	snapshot := s.dnsServer.GetStats()
	startedAt := s.dnsServer.GetStartTime()
	running := time.Since(startedAt)

	payload := map[string]interface{}{
		"status":          "running",
		"queries":         snapshot.Queries,
		"blocked":         snapshot.Blocked,
		"allowed":         snapshot.Allowed,
		"errors":          snapshot.Errors,
		"lastQuery":       snapshot.LastQuery,
		"avgResponseTime": snapshot.AvgResponseTime,
		"activeIPs":       len(snapshot.SourceIPs),
		"startTime":       startedAt,
		"uptime":          running,
		"cpuUsage":        snapshot.CpuUsage,
		"timestamp":       time.Now(),
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(payload)
}
// saveConfigToFile writes cfg to filePath as indented JSON with mode 0644,
// replacing any existing file.
//
// The parameter was renamed from "config" so it no longer shadows the config
// package inside the function. Errors are wrapped with the step that failed.
func saveConfigToFile(cfg *config.Config, filePath string) error {
	data, err := json.MarshalIndent(cfg, "", " ")
	if err != nil {
		return fmt.Errorf("marshal config: %w", err)
	}
	if err := ioutil.WriteFile(filePath, data, 0644); err != nil {
		return fmt.Errorf("write config %q: %w", filePath, err)
	}
	return nil
}
// handleConfig serves /api/config.
//
// GET returns the shield settings (blockMethod, customBlockIP, blacklists,
// updateInterval). POST updates any of them; empty or zero request fields
// leave the current value unchanged.
//
// The whole POST request is validated before anything is applied, so a
// rejected request (400) leaves the configuration untouched. Previously
// BlockMethod was changed before the customIP and blacklist checks ran. After
// applying, the configuration is written to ./config.json. A write failure is
// only logged, because the in-memory update already succeeded.
func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	switch r.Method {
	case http.MethodGet:
		current := map[string]interface{}{
			"shield": map[string]interface{}{
				"blockMethod":    s.globalConfig.Shield.BlockMethod,
				"customBlockIP":  s.globalConfig.Shield.CustomBlockIP,
				"blacklists":     s.globalConfig.Shield.Blacklists,
				"updateInterval": s.globalConfig.Shield.UpdateInterval,
			},
		}
		json.NewEncoder(w).Encode(current)

	case http.MethodPost:
		var req struct {
			Shield struct {
				BlockMethod    string                  `json:"blockMethod"`
				CustomBlockIP  string                  `json:"customBlockIP"`
				Blacklists     []config.BlacklistEntry `json:"blacklists"`
				UpdateInterval int                     `json:"updateInterval"`
			} `json:"shield"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "无效的请求体", http.StatusBadRequest)
			return
		}

		// ---- Phase 1: validate everything; no state is modified here. ----

		// effectiveMethod is the block method in force once this request is applied.
		effectiveMethod := s.globalConfig.Shield.BlockMethod
		if req.Shield.BlockMethod != "" {
			validMethods := map[string]bool{
				"NXDOMAIN": true,
				"refused":  true,
				"emptyIP":  true,
				"customIP": true,
			}
			if !validMethods[req.Shield.BlockMethod] {
				http.Error(w, "无效的屏蔽方法", http.StatusBadRequest)
				return
			}
			effectiveMethod = req.Shield.BlockMethod
		}

		// Switching to customIP requires an IP in the same request.
		if req.Shield.BlockMethod == "customIP" && req.Shield.CustomBlockIP == "" {
			http.Error(w, "自定义IP不能为空", http.StatusBadRequest)
			return
		}
		// Validate the IP whenever it will actually be used for blocking. This
		// also covers changing only the IP while customIP is already active.
		if effectiveMethod == "customIP" && req.Shield.CustomBlockIP != "" && !isValidIP(req.Shield.CustomBlockIP) {
			http.Error(w, "无效的IP地址", http.StatusBadRequest)
			return
		}

		for i, bl := range req.Shield.Blacklists {
			if bl.URL == "" {
				http.Error(w, fmt.Sprintf("黑名单URL不能为空索引: %d", i), http.StatusBadRequest)
				return
			}
			if !strings.HasPrefix(bl.URL, "http://") && !strings.HasPrefix(bl.URL, "https://") {
				http.Error(w, fmt.Sprintf("黑名单URL必须以http://或https://开头,索引: %d", i), http.StatusBadRequest)
				return
			}
		}

		// ---- Phase 2: apply. ----

		if req.Shield.BlockMethod != "" {
			s.globalConfig.Shield.BlockMethod = req.Shield.BlockMethod
		}
		if req.Shield.CustomBlockIP != "" {
			s.globalConfig.Shield.CustomBlockIP = req.Shield.CustomBlockIP
		}

		if req.Shield.Blacklists != nil {
			s.globalConfig.Shield.Blacklists = req.Shield.Blacklists
			s.shieldManager.UpdateBlacklist(req.Shield.Blacklists)
			if err := s.shieldManager.LoadRules(); err != nil {
				logger.Error("重新加载规则失败", "error", err)
			}
		}

		if req.Shield.UpdateInterval > 0 {
			s.globalConfig.Shield.UpdateInterval = req.Shield.UpdateInterval
			// Restart the auto-update loop so the new interval takes effect.
			s.shieldManager.StopAutoUpdate()
			s.shieldManager.StartAutoUpdate()
		}

		if err := saveConfigToFile(s.globalConfig, "./config.json"); err != nil {
			// Logged only: the in-memory configuration is already updated.
			logger.Error("保存配置到文件失败", "error", err)
		}

		json.NewEncoder(w).Encode(map[string]interface{}{
			"success": true,
			"message": "配置已更新",
		})

	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// isValidIP 简单验证IP地址格式
func isValidIP(ip string) bool {
// 简单的IPv4地址验证
parts := strings.Split(ip, ".")
if len(parts) != 4 {
return false
}
for _, part := range parts {
// 检查是否为数字
for _, char := range part {
if char < '0' || char > '9' {
return false
}
}
// 检查数字范围
var num int
if _, err := fmt.Sscanf(part, "%d", &num); err != nil || num < 0 || num > 255 {
return false
}
}
return true
}