更新Swagger API文档
This commit is contained in:
389
http/server.go
389
http/server.go
@@ -5,9 +5,12 @@ import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"dns-server/config"
|
||||
"dns-server/dns"
|
||||
"dns-server/logger"
|
||||
@@ -21,16 +24,37 @@ type Server struct {
|
||||
dnsServer *dns.Server
|
||||
shieldManager *shield.ShieldManager
|
||||
server *http.Server
|
||||
|
||||
// WebSocket相关字段
|
||||
upgrader websocket.Upgrader
|
||||
clients map[*websocket.Conn]bool
|
||||
clientsMutex sync.Mutex
|
||||
broadcastChan chan []byte
|
||||
}
|
||||
|
||||
// NewServer 创建HTTP服务器实例
|
||||
func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager *shield.ShieldManager) *Server {
|
||||
return &Server{
|
||||
server := &Server{
|
||||
globalConfig: globalConfig,
|
||||
config: &globalConfig.HTTP,
|
||||
dnsServer: dnsServer,
|
||||
shieldManager: shieldManager,
|
||||
upgrader: websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
// 允许所有CORS请求
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
},
|
||||
clients: make(map[*websocket.Conn]bool),
|
||||
broadcastChan: make(chan []byte, 100),
|
||||
}
|
||||
|
||||
// 启动广播协程
|
||||
go server.startBroadcastLoop()
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
// Start 启动HTTP服务器
|
||||
@@ -51,6 +75,11 @@ func (s *Server) Start() error {
|
||||
mux.HandleFunc("/api/top-resolved", s.handleTopResolvedDomains)
|
||||
mux.HandleFunc("/api/recent-blocked", s.handleRecentBlockedDomains)
|
||||
mux.HandleFunc("/api/hourly-stats", s.handleHourlyStats)
|
||||
mux.HandleFunc("/api/daily-stats", s.handleDailyStats)
|
||||
mux.HandleFunc("/api/monthly-stats", s.handleMonthlyStats)
|
||||
mux.HandleFunc("/api/query/type", s.handleQueryTypeStats)
|
||||
// WebSocket端点
|
||||
mux.HandleFunc("/ws/stats", s.handleWebSocketStats)
|
||||
}
|
||||
|
||||
// 静态文件服务(可后续添加前端界面)
|
||||
@@ -85,16 +114,218 @@ func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
dnsStats := s.dnsServer.GetStats()
|
||||
shieldStats := s.shieldManager.GetStats()
|
||||
|
||||
// 获取最常用查询类型(如果有)
|
||||
topQueryType := "-"
|
||||
maxCount := int64(0)
|
||||
if len(dnsStats.QueryTypes) > 0 {
|
||||
for queryType, count := range dnsStats.QueryTypes {
|
||||
if count > maxCount {
|
||||
maxCount = count
|
||||
topQueryType = queryType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取活跃来源IP数量
|
||||
activeIPCount := len(dnsStats.SourceIPs)
|
||||
|
||||
// 格式化平均响应时间为两位小数
|
||||
formattedResponseTime := float64(int(dnsStats.AvgResponseTime*100)) / 100
|
||||
|
||||
// 构建响应数据,确保所有字段都反映服务器的真实状态
|
||||
stats := map[string]interface{}{
|
||||
"dns": dnsStats,
|
||||
"shield": shieldStats,
|
||||
"time": time.Now(),
|
||||
"dns": map[string]interface{}{
|
||||
"Queries": dnsStats.Queries,
|
||||
"Blocked": dnsStats.Blocked,
|
||||
"Allowed": dnsStats.Allowed,
|
||||
"Errors": dnsStats.Errors,
|
||||
"LastQuery": dnsStats.LastQuery,
|
||||
"AvgResponseTime": formattedResponseTime,
|
||||
"TotalResponseTime": dnsStats.TotalResponseTime,
|
||||
"QueryTypes": dnsStats.QueryTypes,
|
||||
"SourceIPs": dnsStats.SourceIPs,
|
||||
"CpuUsage": dnsStats.CpuUsage,
|
||||
},
|
||||
"shield": shieldStats,
|
||||
"topQueryType": topQueryType,
|
||||
"activeIPs": activeIPCount,
|
||||
"avgResponseTime": formattedResponseTime,
|
||||
"cpuUsage": dnsStats.CpuUsage,
|
||||
"time": time.Now(),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(stats)
|
||||
}
|
||||
|
||||
// WebSocket相关方法
|
||||
|
||||
// handleWebSocketStats 处理WebSocket连接,用于实时推送统计数据
|
||||
func (s *Server) handleWebSocketStats(w http.ResponseWriter, r *http.Request) {
|
||||
// 升级HTTP连接为WebSocket
|
||||
conn, err := s.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("WebSocket升级失败: %v", err))
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// 将新客户端添加到客户端列表
|
||||
s.clientsMutex.Lock()
|
||||
s.clients[conn] = true
|
||||
clientCount := len(s.clients)
|
||||
s.clientsMutex.Unlock()
|
||||
|
||||
logger.Info(fmt.Sprintf("新WebSocket客户端连接,当前连接数: %d", clientCount))
|
||||
|
||||
// 发送初始数据
|
||||
if err := s.sendInitialStats(conn); err != nil {
|
||||
logger.Error(fmt.Sprintf("发送初始数据失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 定期发送更新数据
|
||||
ticker := time.NewTicker(500 * time.Millisecond) // 每500ms检查一次数据变化
|
||||
defer ticker.Stop()
|
||||
|
||||
// 最后一次发送的数据快照,用于检测变化
|
||||
var lastStats map[string]interface{}
|
||||
|
||||
// 保持连接并定期发送数据
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// 获取最新统计数据
|
||||
currentStats := s.buildStatsData()
|
||||
|
||||
// 检查数据是否有变化
|
||||
if !s.areStatsEqual(lastStats, currentStats) {
|
||||
// 数据有变化,发送更新
|
||||
data, err := json.Marshal(map[string]interface{}{
|
||||
"type": "stats_update",
|
||||
"data": currentStats,
|
||||
"time": time.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("序列化统计数据失败: %v", err))
|
||||
continue
|
||||
}
|
||||
|
||||
if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
logger.Error(fmt.Sprintf("发送WebSocket消息失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 更新最后发送的数据
|
||||
lastStats = currentStats
|
||||
}
|
||||
case <-r.Context().Done():
|
||||
// 客户端断开连接
|
||||
s.clientsMutex.Lock()
|
||||
delete(s.clients, conn)
|
||||
clientCount := len(s.clients)
|
||||
s.clientsMutex.Unlock()
|
||||
logger.Info(fmt.Sprintf("WebSocket客户端断开连接,当前连接数: %d", clientCount))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendInitialStats 发送初始统计数据
|
||||
func (s *Server) sendInitialStats(conn *websocket.Conn) error {
|
||||
stats := s.buildStatsData()
|
||||
data, err := json.Marshal(map[string]interface{}{
|
||||
"type": "initial_data",
|
||||
"data": stats,
|
||||
"time": time.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
|
||||
// buildStatsData 构建统计数据
|
||||
func (s *Server) buildStatsData() map[string]interface{} {
|
||||
dnsStats := s.dnsServer.GetStats()
|
||||
shieldStats := s.shieldManager.GetStats()
|
||||
|
||||
// 获取最常用查询类型
|
||||
topQueryType := "-"
|
||||
maxCount := int64(0)
|
||||
if len(dnsStats.QueryTypes) > 0 {
|
||||
for queryType, count := range dnsStats.QueryTypes {
|
||||
if count > maxCount {
|
||||
maxCount = count
|
||||
topQueryType = queryType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取活跃来源IP数量
|
||||
activeIPCount := len(dnsStats.SourceIPs)
|
||||
|
||||
// 格式化平均响应时间
|
||||
formattedResponseTime := float64(int(dnsStats.AvgResponseTime*100)) / 100
|
||||
|
||||
return map[string]interface{}{
|
||||
"dns": map[string]interface{}{
|
||||
"Queries": dnsStats.Queries,
|
||||
"Blocked": dnsStats.Blocked,
|
||||
"Allowed": dnsStats.Allowed,
|
||||
"Errors": dnsStats.Errors,
|
||||
"LastQuery": dnsStats.LastQuery,
|
||||
"AvgResponseTime": formattedResponseTime,
|
||||
"TotalResponseTime": dnsStats.TotalResponseTime,
|
||||
"QueryTypes": dnsStats.QueryTypes,
|
||||
"SourceIPs": dnsStats.SourceIPs,
|
||||
"CpuUsage": dnsStats.CpuUsage,
|
||||
},
|
||||
"shield": shieldStats,
|
||||
"topQueryType": topQueryType,
|
||||
"activeIPs": activeIPCount,
|
||||
"avgResponseTime": formattedResponseTime,
|
||||
"cpuUsage": dnsStats.CpuUsage,
|
||||
}
|
||||
}
|
||||
|
||||
// areStatsEqual 检查两次统计数据是否相等(用于检测变化)
|
||||
func (s *Server) areStatsEqual(stats1, stats2 map[string]interface{}) bool {
|
||||
if stats1 == nil || stats2 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 只比较关键数值,避免频繁更新
|
||||
if dns1, ok1 := stats1["dns"].(map[string]interface{}); ok1 {
|
||||
if dns2, ok2 := stats2["dns"].(map[string]interface{}); ok2 {
|
||||
// 检查主要计数器
|
||||
if dns1["Queries"] != dns2["Queries"] ||
|
||||
dns1["Blocked"] != dns2["Blocked"] ||
|
||||
dns1["Allowed"] != dns2["Allowed"] ||
|
||||
dns1["Errors"] != dns2["Errors"] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// startBroadcastLoop 启动广播循环
|
||||
func (s *Server) startBroadcastLoop() {
|
||||
for message := range s.broadcastChan {
|
||||
s.clientsMutex.Lock()
|
||||
for client := range s.clients {
|
||||
if err := client.WriteMessage(websocket.TextMessage, message); err != nil {
|
||||
logger.Error(fmt.Sprintf("广播消息失败: %v", err))
|
||||
client.Close()
|
||||
delete(s.clients, client)
|
||||
}
|
||||
}
|
||||
s.clientsMutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// handleTopBlockedDomains 处理TOP屏蔽域名请求
|
||||
func (s *Server) handleTopBlockedDomains(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
@@ -191,35 +422,117 @@ func (s *Server) handleHourlyStats(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
|
||||
// handleDailyStats 处理每日统计数据请求
|
||||
func (s *Server) handleDailyStats(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取每日统计数据
|
||||
dailyStats := s.dnsServer.GetDailyStats()
|
||||
|
||||
// 生成过去7天的时间标签
|
||||
labels := make([]string, 7)
|
||||
data := make([]int64, 7)
|
||||
now := time.Now()
|
||||
|
||||
for i := 6; i >= 0; i-- {
|
||||
t := now.AddDate(0, 0, -i)
|
||||
key := t.Format("2006-01-02")
|
||||
labels[6-i] = t.Format("01-02")
|
||||
data[6-i] = dailyStats[key]
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"labels": labels,
|
||||
"data": data,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
|
||||
// handleMonthlyStats 处理每月统计数据请求
|
||||
func (s *Server) handleMonthlyStats(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取每日统计数据(用于30天视图)
|
||||
dailyStats := s.dnsServer.GetDailyStats()
|
||||
|
||||
// 生成过去30天的时间标签
|
||||
labels := make([]string, 30)
|
||||
data := make([]int64, 30)
|
||||
now := time.Now()
|
||||
|
||||
for i := 29; i >= 0; i-- {
|
||||
t := now.AddDate(0, 0, -i)
|
||||
key := t.Format("2006-01-02")
|
||||
labels[29-i] = t.Format("01-02")
|
||||
data[29-i] = dailyStats[key]
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"labels": labels,
|
||||
"data": data,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
|
||||
// handleQueryTypeStats 处理查询类型统计请求
|
||||
func (s *Server) handleQueryTypeStats(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取DNS统计数据
|
||||
dnsStats := s.dnsServer.GetStats()
|
||||
|
||||
// 转换为前端需要的格式
|
||||
result := make([]map[string]interface{}, 0, len(dnsStats.QueryTypes))
|
||||
for queryType, count := range dnsStats.QueryTypes {
|
||||
result = append(result, map[string]interface{}{
|
||||
"type": queryType,
|
||||
"count": count,
|
||||
})
|
||||
}
|
||||
|
||||
// 按计数降序排序
|
||||
sort.Slice(result, func(i, j int) bool {
|
||||
return result[i]["count"].(int64) > result[j]["count"].(int64)
|
||||
})
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
|
||||
// handleShield 处理屏蔽规则管理请求
|
||||
func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// 返回屏蔽规则的基本配置信息
|
||||
// 返回屏蔽规则的基本配置信息和统计数据,不返回完整规则列表
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
// 获取规则统计信息
|
||||
stats := s.shieldManager.GetStats()
|
||||
shieldInfo := map[string]interface{}{
|
||||
"updateInterval": s.globalConfig.Shield.UpdateInterval,
|
||||
"blockMethod": s.globalConfig.Shield.BlockMethod,
|
||||
"blacklistCount": len(s.globalConfig.Shield.Blacklists),
|
||||
"updateInterval": s.globalConfig.Shield.UpdateInterval,
|
||||
"blockMethod": s.globalConfig.Shield.BlockMethod,
|
||||
"blacklistCount": len(s.globalConfig.Shield.Blacklists),
|
||||
"domainRulesCount": stats["domainRules"],
|
||||
"domainExceptionsCount": stats["domainExceptions"],
|
||||
"regexRulesCount": stats["regexRules"],
|
||||
"regexExceptionsCount": stats["regexExceptions"],
|
||||
"hostsRulesCount": stats["hostsRules"],
|
||||
}
|
||||
json.NewEncoder(w).Encode(shieldInfo)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// 处理远程黑名单管理子路由
|
||||
if strings.HasPrefix(r.URL.Path, "/shield/blacklists") {
|
||||
s.handleShieldBlacklists(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
// 获取完整规则列表
|
||||
rules := s.shieldManager.GetRules()
|
||||
json.NewEncoder(w).Encode(rules)
|
||||
|
||||
case http.MethodPost:
|
||||
// 添加屏蔽规则
|
||||
var req struct {
|
||||
@@ -237,7 +550,7 @@ func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||||
|
||||
return
|
||||
case http.MethodDelete:
|
||||
// 删除屏蔽规则
|
||||
var req struct {
|
||||
@@ -255,7 +568,7 @@ func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
|
||||
|
||||
return
|
||||
case http.MethodPut:
|
||||
// 重新加载规则
|
||||
if err := s.shieldManager.LoadRules(); err != nil {
|
||||
@@ -263,9 +576,10 @@ func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "规则重新加载成功"})
|
||||
|
||||
return
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -506,12 +820,25 @@ func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
stats := s.dnsServer.GetStats()
|
||||
|
||||
// 使用服务器的实际启动时间计算准确的运行时间
|
||||
serverStartTime := s.dnsServer.GetStartTime()
|
||||
uptime := time.Since(serverStartTime)
|
||||
|
||||
// 构建包含所有真实服务器统计数据的响应
|
||||
status := map[string]interface{}{
|
||||
"status": "running",
|
||||
"queries": stats.Queries,
|
||||
"lastQuery": stats.LastQuery,
|
||||
"uptime": time.Since(stats.LastQuery),
|
||||
"timestamp": time.Now(),
|
||||
"status": "running",
|
||||
"queries": stats.Queries,
|
||||
"blocked": stats.Blocked,
|
||||
"allowed": stats.Allowed,
|
||||
"errors": stats.Errors,
|
||||
"lastQuery": stats.LastQuery,
|
||||
"avgResponseTime": stats.AvgResponseTime,
|
||||
"activeIPs": len(stats.SourceIPs),
|
||||
"startTime": serverStartTime,
|
||||
"uptime": uptime,
|
||||
"cpuUsage": stats.CpuUsage,
|
||||
"timestamp": time.Now(),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
Reference in New Issue
Block a user