增加websocket,数据实时显示
This commit is contained in:
195
http/server.go
195
http/server.go
@@ -7,8 +7,10 @@ import (
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"dns-server/config"
|
||||
"dns-server/dns"
|
||||
"dns-server/logger"
|
||||
@@ -22,16 +24,37 @@ type Server struct {
|
||||
dnsServer *dns.Server
|
||||
shieldManager *shield.ShieldManager
|
||||
server *http.Server
|
||||
|
||||
// WebSocket相关字段
|
||||
upgrader websocket.Upgrader
|
||||
clients map[*websocket.Conn]bool
|
||||
clientsMutex sync.Mutex
|
||||
broadcastChan chan []byte
|
||||
}
|
||||
|
||||
// NewServer 创建HTTP服务器实例
|
||||
func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager *shield.ShieldManager) *Server {
|
||||
return &Server{
|
||||
server := &Server{
|
||||
globalConfig: globalConfig,
|
||||
config: &globalConfig.HTTP,
|
||||
dnsServer: dnsServer,
|
||||
shieldManager: shieldManager,
|
||||
upgrader: websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
// 允许所有CORS请求
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
},
|
||||
clients: make(map[*websocket.Conn]bool),
|
||||
broadcastChan: make(chan []byte, 100),
|
||||
}
|
||||
|
||||
// 启动广播协程
|
||||
go server.startBroadcastLoop()
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
// Start 启动HTTP服务器
|
||||
@@ -55,6 +78,8 @@ func (s *Server) Start() error {
|
||||
mux.HandleFunc("/api/daily-stats", s.handleDailyStats)
|
||||
mux.HandleFunc("/api/monthly-stats", s.handleMonthlyStats)
|
||||
mux.HandleFunc("/api/query/type", s.handleQueryTypeStats)
|
||||
// WebSocket端点
|
||||
mux.HandleFunc("/ws/stats", s.handleWebSocketStats)
|
||||
}
|
||||
|
||||
// 静态文件服务(可后续添加前端界面)
|
||||
@@ -133,6 +158,174 @@ func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(stats)
|
||||
}
|
||||
|
||||
// WebSocket相关方法
|
||||
|
||||
// handleWebSocketStats 处理WebSocket连接,用于实时推送统计数据
|
||||
func (s *Server) handleWebSocketStats(w http.ResponseWriter, r *http.Request) {
|
||||
// 升级HTTP连接为WebSocket
|
||||
conn, err := s.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("WebSocket升级失败: %v", err))
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// 将新客户端添加到客户端列表
|
||||
s.clientsMutex.Lock()
|
||||
s.clients[conn] = true
|
||||
clientCount := len(s.clients)
|
||||
s.clientsMutex.Unlock()
|
||||
|
||||
logger.Info(fmt.Sprintf("新WebSocket客户端连接,当前连接数: %d", clientCount))
|
||||
|
||||
// 发送初始数据
|
||||
if err := s.sendInitialStats(conn); err != nil {
|
||||
logger.Error(fmt.Sprintf("发送初始数据失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 定期发送更新数据
|
||||
ticker := time.NewTicker(500 * time.Millisecond) // 每500ms检查一次数据变化
|
||||
defer ticker.Stop()
|
||||
|
||||
// 最后一次发送的数据快照,用于检测变化
|
||||
var lastStats map[string]interface{}
|
||||
|
||||
// 保持连接并定期发送数据
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// 获取最新统计数据
|
||||
currentStats := s.buildStatsData()
|
||||
|
||||
// 检查数据是否有变化
|
||||
if !s.areStatsEqual(lastStats, currentStats) {
|
||||
// 数据有变化,发送更新
|
||||
data, err := json.Marshal(map[string]interface{}{
|
||||
"type": "stats_update",
|
||||
"data": currentStats,
|
||||
"time": time.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("序列化统计数据失败: %v", err))
|
||||
continue
|
||||
}
|
||||
|
||||
if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
logger.Error(fmt.Sprintf("发送WebSocket消息失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 更新最后发送的数据
|
||||
lastStats = currentStats
|
||||
}
|
||||
case <-r.Context().Done():
|
||||
// 客户端断开连接
|
||||
s.clientsMutex.Lock()
|
||||
delete(s.clients, conn)
|
||||
clientCount := len(s.clients)
|
||||
s.clientsMutex.Unlock()
|
||||
logger.Info(fmt.Sprintf("WebSocket客户端断开连接,当前连接数: %d", clientCount))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendInitialStats 发送初始统计数据
|
||||
func (s *Server) sendInitialStats(conn *websocket.Conn) error {
|
||||
stats := s.buildStatsData()
|
||||
data, err := json.Marshal(map[string]interface{}{
|
||||
"type": "initial_data",
|
||||
"data": stats,
|
||||
"time": time.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
|
||||
// buildStatsData 构建统计数据
|
||||
func (s *Server) buildStatsData() map[string]interface{} {
|
||||
dnsStats := s.dnsServer.GetStats()
|
||||
shieldStats := s.shieldManager.GetStats()
|
||||
|
||||
// 获取最常用查询类型
|
||||
topQueryType := "-"
|
||||
maxCount := int64(0)
|
||||
if len(dnsStats.QueryTypes) > 0 {
|
||||
for queryType, count := range dnsStats.QueryTypes {
|
||||
if count > maxCount {
|
||||
maxCount = count
|
||||
topQueryType = queryType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取活跃来源IP数量
|
||||
activeIPCount := len(dnsStats.SourceIPs)
|
||||
|
||||
// 格式化平均响应时间
|
||||
formattedResponseTime := float64(int(dnsStats.AvgResponseTime*100)) / 100
|
||||
|
||||
return map[string]interface{}{
|
||||
"dns": map[string]interface{}{
|
||||
"Queries": dnsStats.Queries,
|
||||
"Blocked": dnsStats.Blocked,
|
||||
"Allowed": dnsStats.Allowed,
|
||||
"Errors": dnsStats.Errors,
|
||||
"LastQuery": dnsStats.LastQuery,
|
||||
"AvgResponseTime": formattedResponseTime,
|
||||
"TotalResponseTime": dnsStats.TotalResponseTime,
|
||||
"QueryTypes": dnsStats.QueryTypes,
|
||||
"SourceIPs": dnsStats.SourceIPs,
|
||||
"CpuUsage": dnsStats.CpuUsage,
|
||||
},
|
||||
"shield": shieldStats,
|
||||
"topQueryType": topQueryType,
|
||||
"activeIPs": activeIPCount,
|
||||
"avgResponseTime": formattedResponseTime,
|
||||
"cpuUsage": dnsStats.CpuUsage,
|
||||
}
|
||||
}
|
||||
|
||||
// areStatsEqual 检查两次统计数据是否相等(用于检测变化)
|
||||
func (s *Server) areStatsEqual(stats1, stats2 map[string]interface{}) bool {
|
||||
if stats1 == nil || stats2 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 只比较关键数值,避免频繁更新
|
||||
if dns1, ok1 := stats1["dns"].(map[string]interface{}); ok1 {
|
||||
if dns2, ok2 := stats2["dns"].(map[string]interface{}); ok2 {
|
||||
// 检查主要计数器
|
||||
if dns1["Queries"] != dns2["Queries"] ||
|
||||
dns1["Blocked"] != dns2["Blocked"] ||
|
||||
dns1["Allowed"] != dns2["Allowed"] ||
|
||||
dns1["Errors"] != dns2["Errors"] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// startBroadcastLoop 启动广播循环
|
||||
func (s *Server) startBroadcastLoop() {
|
||||
for message := range s.broadcastChan {
|
||||
s.clientsMutex.Lock()
|
||||
for client := range s.clients {
|
||||
if err := client.WriteMessage(websocket.TextMessage, message); err != nil {
|
||||
logger.Error(fmt.Sprintf("广播消息失败: %v", err))
|
||||
client.Close()
|
||||
delete(s.clients, client)
|
||||
}
|
||||
}
|
||||
s.clientsMutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// handleTopBlockedDomains 处理TOP屏蔽域名请求
|
||||
func (s *Server) handleTopBlockedDomains(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
|
||||
Reference in New Issue
Block a user