更新Swagger API文档

This commit is contained in:
Alex Yang
2025-11-29 19:30:47 +08:00
parent aea162a616
commit ca876a8951
39 changed files with 16612 additions and 2003 deletions

View File

@@ -5,9 +5,12 @@ import (
"fmt"
"io/ioutil"
"net/http"
"sort"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"dns-server/config"
"dns-server/dns"
"dns-server/logger"
@@ -21,16 +24,37 @@ type Server struct {
dnsServer *dns.Server
shieldManager *shield.ShieldManager
server *http.Server
// WebSocket相关字段
upgrader websocket.Upgrader
clients map[*websocket.Conn]bool
clientsMutex sync.Mutex
broadcastChan chan []byte
}
// NewServer 创建HTTP服务器实例
func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager *shield.ShieldManager) *Server {
return &Server{
server := &Server{
globalConfig: globalConfig,
config: &globalConfig.HTTP,
dnsServer: dnsServer,
shieldManager: shieldManager,
upgrader: websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
// 允许所有CORS请求
CheckOrigin: func(r *http.Request) bool {
return true
},
},
clients: make(map[*websocket.Conn]bool),
broadcastChan: make(chan []byte, 100),
}
// 启动广播协程
go server.startBroadcastLoop()
return server
}
// Start 启动HTTP服务器
@@ -51,6 +75,11 @@ func (s *Server) Start() error {
mux.HandleFunc("/api/top-resolved", s.handleTopResolvedDomains)
mux.HandleFunc("/api/recent-blocked", s.handleRecentBlockedDomains)
mux.HandleFunc("/api/hourly-stats", s.handleHourlyStats)
mux.HandleFunc("/api/daily-stats", s.handleDailyStats)
mux.HandleFunc("/api/monthly-stats", s.handleMonthlyStats)
mux.HandleFunc("/api/query/type", s.handleQueryTypeStats)
// WebSocket端点
mux.HandleFunc("/ws/stats", s.handleWebSocketStats)
}
// 静态文件服务(可后续添加前端界面)
@@ -85,16 +114,218 @@ func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
dnsStats := s.dnsServer.GetStats()
shieldStats := s.shieldManager.GetStats()
// 获取最常用查询类型(如果有)
topQueryType := "-"
maxCount := int64(0)
if len(dnsStats.QueryTypes) > 0 {
for queryType, count := range dnsStats.QueryTypes {
if count > maxCount {
maxCount = count
topQueryType = queryType
}
}
}
// 获取活跃来源IP数量
activeIPCount := len(dnsStats.SourceIPs)
// 格式化平均响应时间为两位小数
formattedResponseTime := float64(int(dnsStats.AvgResponseTime*100)) / 100
// 构建响应数据,确保所有字段都反映服务器的真实状态
stats := map[string]interface{}{
"dns": dnsStats,
"shield": shieldStats,
"time": time.Now(),
"dns": map[string]interface{}{
"Queries": dnsStats.Queries,
"Blocked": dnsStats.Blocked,
"Allowed": dnsStats.Allowed,
"Errors": dnsStats.Errors,
"LastQuery": dnsStats.LastQuery,
"AvgResponseTime": formattedResponseTime,
"TotalResponseTime": dnsStats.TotalResponseTime,
"QueryTypes": dnsStats.QueryTypes,
"SourceIPs": dnsStats.SourceIPs,
"CpuUsage": dnsStats.CpuUsage,
},
"shield": shieldStats,
"topQueryType": topQueryType,
"activeIPs": activeIPCount,
"avgResponseTime": formattedResponseTime,
"cpuUsage": dnsStats.CpuUsage,
"time": time.Now(),
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(stats)
}
// WebSocket相关方法
// handleWebSocketStats 处理WebSocket连接用于实时推送统计数据
func (s *Server) handleWebSocketStats(w http.ResponseWriter, r *http.Request) {
// 升级HTTP连接为WebSocket
conn, err := s.upgrader.Upgrade(w, r, nil)
if err != nil {
logger.Error(fmt.Sprintf("WebSocket升级失败: %v", err))
return
}
defer conn.Close()
// 将新客户端添加到客户端列表
s.clientsMutex.Lock()
s.clients[conn] = true
clientCount := len(s.clients)
s.clientsMutex.Unlock()
logger.Info(fmt.Sprintf("新WebSocket客户端连接当前连接数: %d", clientCount))
// 发送初始数据
if err := s.sendInitialStats(conn); err != nil {
logger.Error(fmt.Sprintf("发送初始数据失败: %v", err))
return
}
// 定期发送更新数据
ticker := time.NewTicker(500 * time.Millisecond) // 每500ms检查一次数据变化
defer ticker.Stop()
// 最后一次发送的数据快照,用于检测变化
var lastStats map[string]interface{}
// 保持连接并定期发送数据
for {
select {
case <-ticker.C:
// 获取最新统计数据
currentStats := s.buildStatsData()
// 检查数据是否有变化
if !s.areStatsEqual(lastStats, currentStats) {
// 数据有变化,发送更新
data, err := json.Marshal(map[string]interface{}{
"type": "stats_update",
"data": currentStats,
"time": time.Now(),
})
if err != nil {
logger.Error(fmt.Sprintf("序列化统计数据失败: %v", err))
continue
}
if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
logger.Error(fmt.Sprintf("发送WebSocket消息失败: %v", err))
return
}
// 更新最后发送的数据
lastStats = currentStats
}
case <-r.Context().Done():
// 客户端断开连接
s.clientsMutex.Lock()
delete(s.clients, conn)
clientCount := len(s.clients)
s.clientsMutex.Unlock()
logger.Info(fmt.Sprintf("WebSocket客户端断开连接当前连接数: %d", clientCount))
return
}
}
}
// sendInitialStats 发送初始统计数据
func (s *Server) sendInitialStats(conn *websocket.Conn) error {
stats := s.buildStatsData()
data, err := json.Marshal(map[string]interface{}{
"type": "initial_data",
"data": stats,
"time": time.Now(),
})
if err != nil {
return err
}
return conn.WriteMessage(websocket.TextMessage, data)
}
// buildStatsData 构建统计数据
func (s *Server) buildStatsData() map[string]interface{} {
dnsStats := s.dnsServer.GetStats()
shieldStats := s.shieldManager.GetStats()
// 获取最常用查询类型
topQueryType := "-"
maxCount := int64(0)
if len(dnsStats.QueryTypes) > 0 {
for queryType, count := range dnsStats.QueryTypes {
if count > maxCount {
maxCount = count
topQueryType = queryType
}
}
}
// 获取活跃来源IP数量
activeIPCount := len(dnsStats.SourceIPs)
// 格式化平均响应时间
formattedResponseTime := float64(int(dnsStats.AvgResponseTime*100)) / 100
return map[string]interface{}{
"dns": map[string]interface{}{
"Queries": dnsStats.Queries,
"Blocked": dnsStats.Blocked,
"Allowed": dnsStats.Allowed,
"Errors": dnsStats.Errors,
"LastQuery": dnsStats.LastQuery,
"AvgResponseTime": formattedResponseTime,
"TotalResponseTime": dnsStats.TotalResponseTime,
"QueryTypes": dnsStats.QueryTypes,
"SourceIPs": dnsStats.SourceIPs,
"CpuUsage": dnsStats.CpuUsage,
},
"shield": shieldStats,
"topQueryType": topQueryType,
"activeIPs": activeIPCount,
"avgResponseTime": formattedResponseTime,
"cpuUsage": dnsStats.CpuUsage,
}
}
// areStatsEqual 检查两次统计数据是否相等(用于检测变化)
func (s *Server) areStatsEqual(stats1, stats2 map[string]interface{}) bool {
if stats1 == nil || stats2 == nil {
return false
}
// 只比较关键数值,避免频繁更新
if dns1, ok1 := stats1["dns"].(map[string]interface{}); ok1 {
if dns2, ok2 := stats2["dns"].(map[string]interface{}); ok2 {
// 检查主要计数器
if dns1["Queries"] != dns2["Queries"] ||
dns1["Blocked"] != dns2["Blocked"] ||
dns1["Allowed"] != dns2["Allowed"] ||
dns1["Errors"] != dns2["Errors"] {
return false
}
}
}
return true
}
// startBroadcastLoop 启动广播循环
func (s *Server) startBroadcastLoop() {
for message := range s.broadcastChan {
s.clientsMutex.Lock()
for client := range s.clients {
if err := client.WriteMessage(websocket.TextMessage, message); err != nil {
logger.Error(fmt.Sprintf("广播消息失败: %v", err))
client.Close()
delete(s.clients, client)
}
}
s.clientsMutex.Unlock()
}
}
// handleTopBlockedDomains 处理TOP屏蔽域名请求
func (s *Server) handleTopBlockedDomains(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
@@ -191,35 +422,117 @@ func (s *Server) handleHourlyStats(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(result)
}
// handleDailyStats 处理每日统计数据请求
func (s *Server) handleDailyStats(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 获取每日统计数据
dailyStats := s.dnsServer.GetDailyStats()
// 生成过去7天的时间标签
labels := make([]string, 7)
data := make([]int64, 7)
now := time.Now()
for i := 6; i >= 0; i-- {
t := now.AddDate(0, 0, -i)
key := t.Format("2006-01-02")
labels[6-i] = t.Format("01-02")
data[6-i] = dailyStats[key]
}
result := map[string]interface{}{
"labels": labels,
"data": data,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(result)
}
// handleMonthlyStats 处理每月统计数据请求
func (s *Server) handleMonthlyStats(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 获取每日统计数据用于30天视图
dailyStats := s.dnsServer.GetDailyStats()
// 生成过去30天的时间标签
labels := make([]string, 30)
data := make([]int64, 30)
now := time.Now()
for i := 29; i >= 0; i-- {
t := now.AddDate(0, 0, -i)
key := t.Format("2006-01-02")
labels[29-i] = t.Format("01-02")
data[29-i] = dailyStats[key]
}
result := map[string]interface{}{
"labels": labels,
"data": data,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(result)
}
// handleQueryTypeStats 处理查询类型统计请求
func (s *Server) handleQueryTypeStats(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 获取DNS统计数据
dnsStats := s.dnsServer.GetStats()
// 转换为前端需要的格式
result := make([]map[string]interface{}, 0, len(dnsStats.QueryTypes))
for queryType, count := range dnsStats.QueryTypes {
result = append(result, map[string]interface{}{
"type": queryType,
"count": count,
})
}
// 按计数降序排序
sort.Slice(result, func(i, j int) bool {
return result[i]["count"].(int64) > result[j]["count"].(int64)
})
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(result)
}
// handleShield 处理屏蔽规则管理请求
func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
// 返回屏蔽规则的基本配置信息
// 返回屏蔽规则的基本配置信息和统计数据,不返回完整规则列表
switch r.Method {
case http.MethodGet:
// 获取规则统计信息
stats := s.shieldManager.GetStats()
shieldInfo := map[string]interface{}{
"updateInterval": s.globalConfig.Shield.UpdateInterval,
"blockMethod": s.globalConfig.Shield.BlockMethod,
"blacklistCount": len(s.globalConfig.Shield.Blacklists),
"updateInterval": s.globalConfig.Shield.UpdateInterval,
"blockMethod": s.globalConfig.Shield.BlockMethod,
"blacklistCount": len(s.globalConfig.Shield.Blacklists),
"domainRulesCount": stats["domainRules"],
"domainExceptionsCount": stats["domainExceptions"],
"regexRulesCount": stats["regexRules"],
"regexExceptionsCount": stats["regexExceptions"],
"hostsRulesCount": stats["hostsRules"],
}
json.NewEncoder(w).Encode(shieldInfo)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
// 处理远程黑名单管理子路由
if strings.HasPrefix(r.URL.Path, "/shield/blacklists") {
s.handleShieldBlacklists(w, r)
return
}
switch r.Method {
case http.MethodGet:
// 获取完整规则列表
rules := s.shieldManager.GetRules()
json.NewEncoder(w).Encode(rules)
case http.MethodPost:
// 添加屏蔽规则
var req struct {
@@ -237,7 +550,7 @@ func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) {
}
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
return
case http.MethodDelete:
// 删除屏蔽规则
var req struct {
@@ -255,7 +568,7 @@ func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) {
}
json.NewEncoder(w).Encode(map[string]string{"status": "success"})
return
case http.MethodPut:
// 重新加载规则
if err := s.shieldManager.LoadRules(); err != nil {
@@ -263,9 +576,10 @@ func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) {
return
}
json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "规则重新加载成功"})
return
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
}
@@ -506,12 +820,25 @@ func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
}
stats := s.dnsServer.GetStats()
// 使用服务器的实际启动时间计算准确的运行时间
serverStartTime := s.dnsServer.GetStartTime()
uptime := time.Since(serverStartTime)
// 构建包含所有真实服务器统计数据的响应
status := map[string]interface{}{
"status": "running",
"queries": stats.Queries,
"lastQuery": stats.LastQuery,
"uptime": time.Since(stats.LastQuery),
"timestamp": time.Now(),
"status": "running",
"queries": stats.Queries,
"blocked": stats.Blocked,
"allowed": stats.Allowed,
"errors": stats.Errors,
"lastQuery": stats.LastQuery,
"avgResponseTime": stats.AvgResponseTime,
"activeIPs": len(stats.SourceIPs),
"startTime": serverStartTime,
"uptime": uptime,
"cpuUsage": stats.CpuUsage,
"timestamp": time.Now(),
}
w.Header().Set("Content-Type", "application/json")