增加websocket,数据实时显示

This commit is contained in:
Alex Yang
2025-11-26 01:11:37 +08:00
parent 63154085f7
commit 54dbb024e1
5 changed files with 447 additions and 3 deletions

View File

@@ -7,8 +7,10 @@ import (
"net/http"
"sort"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"dns-server/config"
"dns-server/dns"
"dns-server/logger"
@@ -22,16 +24,37 @@ type Server struct {
dnsServer *dns.Server
shieldManager *shield.ShieldManager
server *http.Server
// WebSocket相关字段
upgrader websocket.Upgrader
clients map[*websocket.Conn]bool
clientsMutex sync.Mutex
broadcastChan chan []byte
}
// NewServer 创建HTTP服务器实例
func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager *shield.ShieldManager) *Server {
return &Server{
server := &Server{
globalConfig: globalConfig,
config: &globalConfig.HTTP,
dnsServer: dnsServer,
shieldManager: shieldManager,
upgrader: websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
// 允许所有CORS请求
CheckOrigin: func(r *http.Request) bool {
return true
},
},
clients: make(map[*websocket.Conn]bool),
broadcastChan: make(chan []byte, 100),
}
// 启动广播协程
go server.startBroadcastLoop()
return server
}
// Start 启动HTTP服务器
@@ -55,6 +78,8 @@ func (s *Server) Start() error {
mux.HandleFunc("/api/daily-stats", s.handleDailyStats)
mux.HandleFunc("/api/monthly-stats", s.handleMonthlyStats)
mux.HandleFunc("/api/query/type", s.handleQueryTypeStats)
// WebSocket端点
mux.HandleFunc("/ws/stats", s.handleWebSocketStats)
}
// 静态文件服务(可后续添加前端界面)
@@ -133,6 +158,174 @@ func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(stats)
}
// WebSocket相关方法
// handleWebSocketStats 处理WebSocket连接用于实时推送统计数据
func (s *Server) handleWebSocketStats(w http.ResponseWriter, r *http.Request) {
// 升级HTTP连接为WebSocket
conn, err := s.upgrader.Upgrade(w, r, nil)
if err != nil {
logger.Error(fmt.Sprintf("WebSocket升级失败: %v", err))
return
}
defer conn.Close()
// 将新客户端添加到客户端列表
s.clientsMutex.Lock()
s.clients[conn] = true
clientCount := len(s.clients)
s.clientsMutex.Unlock()
logger.Info(fmt.Sprintf("新WebSocket客户端连接当前连接数: %d", clientCount))
// 发送初始数据
if err := s.sendInitialStats(conn); err != nil {
logger.Error(fmt.Sprintf("发送初始数据失败: %v", err))
return
}
// 定期发送更新数据
ticker := time.NewTicker(500 * time.Millisecond) // 每500ms检查一次数据变化
defer ticker.Stop()
// 最后一次发送的数据快照,用于检测变化
var lastStats map[string]interface{}
// 保持连接并定期发送数据
for {
select {
case <-ticker.C:
// 获取最新统计数据
currentStats := s.buildStatsData()
// 检查数据是否有变化
if !s.areStatsEqual(lastStats, currentStats) {
// 数据有变化,发送更新
data, err := json.Marshal(map[string]interface{}{
"type": "stats_update",
"data": currentStats,
"time": time.Now(),
})
if err != nil {
logger.Error(fmt.Sprintf("序列化统计数据失败: %v", err))
continue
}
if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
logger.Error(fmt.Sprintf("发送WebSocket消息失败: %v", err))
return
}
// 更新最后发送的数据
lastStats = currentStats
}
case <-r.Context().Done():
// 客户端断开连接
s.clientsMutex.Lock()
delete(s.clients, conn)
clientCount := len(s.clients)
s.clientsMutex.Unlock()
logger.Info(fmt.Sprintf("WebSocket客户端断开连接当前连接数: %d", clientCount))
return
}
}
}
// sendInitialStats 发送初始统计数据
func (s *Server) sendInitialStats(conn *websocket.Conn) error {
stats := s.buildStatsData()
data, err := json.Marshal(map[string]interface{}{
"type": "initial_data",
"data": stats,
"time": time.Now(),
})
if err != nil {
return err
}
return conn.WriteMessage(websocket.TextMessage, data)
}
// buildStatsData 构建统计数据
func (s *Server) buildStatsData() map[string]interface{} {
dnsStats := s.dnsServer.GetStats()
shieldStats := s.shieldManager.GetStats()
// 获取最常用查询类型
topQueryType := "-"
maxCount := int64(0)
if len(dnsStats.QueryTypes) > 0 {
for queryType, count := range dnsStats.QueryTypes {
if count > maxCount {
maxCount = count
topQueryType = queryType
}
}
}
// 获取活跃来源IP数量
activeIPCount := len(dnsStats.SourceIPs)
// 格式化平均响应时间
formattedResponseTime := float64(int(dnsStats.AvgResponseTime*100)) / 100
return map[string]interface{}{
"dns": map[string]interface{}{
"Queries": dnsStats.Queries,
"Blocked": dnsStats.Blocked,
"Allowed": dnsStats.Allowed,
"Errors": dnsStats.Errors,
"LastQuery": dnsStats.LastQuery,
"AvgResponseTime": formattedResponseTime,
"TotalResponseTime": dnsStats.TotalResponseTime,
"QueryTypes": dnsStats.QueryTypes,
"SourceIPs": dnsStats.SourceIPs,
"CpuUsage": dnsStats.CpuUsage,
},
"shield": shieldStats,
"topQueryType": topQueryType,
"activeIPs": activeIPCount,
"avgResponseTime": formattedResponseTime,
"cpuUsage": dnsStats.CpuUsage,
}
}
// areStatsEqual 检查两次统计数据是否相等(用于检测变化)
func (s *Server) areStatsEqual(stats1, stats2 map[string]interface{}) bool {
if stats1 == nil || stats2 == nil {
return false
}
// 只比较关键数值,避免频繁更新
if dns1, ok1 := stats1["dns"].(map[string]interface{}); ok1 {
if dns2, ok2 := stats2["dns"].(map[string]interface{}); ok2 {
// 检查主要计数器
if dns1["Queries"] != dns2["Queries"] ||
dns1["Blocked"] != dns2["Blocked"] ||
dns1["Allowed"] != dns2["Allowed"] ||
dns1["Errors"] != dns2["Errors"] {
return false
}
}
}
return true
}
// startBroadcastLoop 启动广播循环
func (s *Server) startBroadcastLoop() {
for message := range s.broadcastChan {
s.clientsMutex.Lock()
for client := range s.clients {
if err := client.WriteMessage(websocket.TextMessage, message); err != nil {
logger.Error(fmt.Sprintf("广播消息失败: %v", err))
client.Close()
delete(s.clients, client)
}
}
s.clientsMutex.Unlock()
}
}
// handleTopBlockedDomains 处理TOP屏蔽域名请求
func (s *Server) handleTopBlockedDomains(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {