Files
dns-server/dns/http/server.go
T
Alex Yang cdac4fcf43 update
2026-01-16 11:09:11 +08:00

1671 lines
50 KiB
Go

package http
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"sort"
"strings"
"sync"
"time"
"dns-server/config"
"dns-server/dns"
"dns-server/gfw"
"dns-server/logger"
"dns-server/shield"
"github.com/gorilla/websocket"
)
// Server is the HTTP console server. It holds the configuration, the DNS,
// shield and GFW-list components its handlers use, the underlying
// http.Server, the in-memory session table, and the WebSocket client registry.
type Server struct {
globalConfig *config.Config
config *config.HTTPConfig
dnsServer *dns.Server
shieldManager *shield.ShieldManager
gfwManager *gfw.GFWListManager
server *http.Server
// Session management fields.
sessions map[string]time.Time // maps a session ID to its expiry time
sessionsMutex sync.Mutex // guards the sessions map
sessionTTL time.Duration // lifetime granted to a session
// WebSocket fields.
upgrader websocket.Upgrader
clients map[*websocket.Conn]bool
clientsMutex sync.Mutex
broadcastChan chan []byte
}
// NewServer builds the HTTP console server around the given configuration and
// components, then starts its two background goroutines (the WebSocket
// broadcast loop and the session clean-up loop) before returning it.
func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager *shield.ShieldManager, gfwManager *gfw.GFWListManager) *Server {
	// Every origin is accepted for WebSocket upgrades.
	acceptAnyOrigin := func(r *http.Request) bool {
		return true
	}

	srv := new(Server)
	srv.globalConfig = globalConfig
	srv.config = &globalConfig.HTTP
	srv.dnsServer = dnsServer
	srv.shieldManager = shieldManager
	srv.gfwManager = gfwManager

	// WebSocket state.
	srv.upgrader = websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		CheckOrigin:     acceptAnyOrigin,
	}
	srv.clients = make(map[*websocket.Conn]bool)
	srv.broadcastChan = make(chan []byte, 100)

	// Session state: sessions are valid for 24 hours.
	srv.sessions = make(map[string]time.Time)
	srv.sessionTTL = 24 * time.Hour

	// Background workers: broadcast fan-out first, then session clean-up.
	go srv.startBroadcastLoop()
	go srv.cleanupSessionsLoop()

	return srv
}
// Start registers every console route on a fresh ServeMux and then serves HTTP
// on the configured host and port. It blocks inside ListenAndServe and returns
// the error that call returns.
//
// Routes:
//   - /login and /login.html are reachable without a session.
//   - When EnableAPI is set, /api/login and /api/logout are open; every other
//     /api/... endpoint and /ws/stats are wrapped in loginRequired.
//   - /api/ (anything not matched by a more specific pattern) serves files from
//     ./static/api, /tracker/ serves ./tracker, and / serves ./static; all three
//     require a session.
//
// The blacklist handler is registered both for the exact path
// /api/shield/blacklists and for the subtree /api/shield/blacklists/. ServeMux
// treats a pattern without a trailing slash as an exact match, so without the
// subtree pattern the /api/shield/blacklists/{name} and .../{name}/update
// requests that handleShieldBlacklists parses would be answered by the static
// file server instead.
func (s *Server) Start() error {
	mux := http.NewServeMux()

	// disableCaching sets the headers that tell browsers not to cache a response.
	disableCaching := func(w http.ResponseWriter) {
		header := w.Header()
		header.Set("Cache-Control", "no-cache, no-store, must-revalidate")
		header.Set("Pragma", "no-cache")
		header.Set("Expires", "Thu, 01 Jan 1970 00:00:00 GMT")
	}

	// /login needs no authentication; it just redirects to the login page.
	mux.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
		http.Redirect(w, r, "/login.html", http.StatusFound)
	})

	if s.config.EnableAPI {
		// Login and logout endpoints are reachable without a session.
		mux.HandleFunc("/api/login", s.handleLogin)
		mux.HandleFunc("/api/logout", s.handleLogout)

		// /api redirects to the API index page served from ./static/api.
		mux.HandleFunc("/api", s.loginRequired(func(w http.ResponseWriter, r *http.Request) {
			http.Redirect(w, r, "/api/index.html", http.StatusMovedPermanently)
		}))

		// GET-only listings of local and remote shield rules.
		mux.HandleFunc("/api/shield/localrules", s.loginRequired(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Content-Type", "application/json")
			if r.Method == http.MethodGet {
				localRules := s.shieldManager.GetLocalRules()
				json.NewEncoder(w).Encode(localRules)
				return
			}
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		}))
		mux.HandleFunc("/api/shield/remoterules", s.loginRequired(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Content-Type", "application/json")
			if r.Method == http.MethodGet {
				remoteRules := s.shieldManager.GetRemoteRules()
				json.NewEncoder(w).Encode(remoteRules)
				return
			}
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		}))

		// Every endpoint below requires a valid session.
		protected := []struct {
			pattern string
			handler http.HandlerFunc
		}{
			{"/api/change-password", s.handleChangePassword},
			{"/api/stats", s.handleStats},
			{"/api/shield", s.handleShield},
			{"/api/shield/hosts", s.handleShieldHosts},
			{"/api/shield/blacklists", s.handleShieldBlacklists},
			// Subtree pattern so /api/shield/blacklists/{name}[/update] reaches the handler.
			{"/api/shield/blacklists/", s.handleShieldBlacklists},
			{"/api/query", s.handleQuery},
			{"/api/status", s.handleStatus},
			{"/api/config", s.handleConfig},
			{"/api/config/restart", s.handleRestart},
			// Statistics endpoints.
			{"/api/top-blocked", s.handleTopBlockedDomains},
			{"/api/top-resolved", s.handleTopResolvedDomains},
			{"/api/top-clients", s.handleTopClients},
			{"/api/top-domains", s.handleTopDomains},
			{"/api/recent-blocked", s.handleRecentBlockedDomains},
			{"/api/hourly-stats", s.handleHourlyStats},
			{"/api/daily-stats", s.handleDailyStats},
			{"/api/monthly-stats", s.handleMonthlyStats},
			{"/api/query/type", s.handleQueryTypeStats},
			// Query-log endpoints.
			{"/api/logs/stats", s.handleLogsStats},
			{"/api/logs/query", s.handleLogsQuery},
			{"/api/logs/count", s.handleLogsCount},
			// WebSocket endpoint.
			{"/ws/stats", s.handleWebSocketStats},
		}
		for _, route := range protected {
			mux.HandleFunc(route.pattern, s.loginRequired(route.handler))
		}

		// Anything else under /api/ is a static file from ./static/api.
		apiFileServer := http.FileServer(http.Dir("./static/api"))
		mux.Handle("/api/", s.loginRequired(http.StripPrefix("/api", apiFileServer).ServeHTTP))
	}

	// login.html is served without a session and with caching disabled.
	mux.HandleFunc("/login.html", func(w http.ResponseWriter, r *http.Request) {
		disableCaching(w)
		http.ServeFile(w, r, "./static/login.html")
	})

	// Static files from ./tracker, session required, caching disabled.
	trackerFileServer := http.FileServer(http.Dir("./tracker"))
	mux.HandleFunc("/tracker/", s.loginRequired(func(w http.ResponseWriter, r *http.Request) {
		disableCaching(w)
		http.StripPrefix("/tracker", trackerFileServer).ServeHTTP(w, r)
	}))

	// Every remaining path is a static file from ./static, session required.
	fileServer := http.FileServer(http.Dir("./static"))
	mux.HandleFunc("/", s.loginRequired(func(w http.ResponseWriter, r *http.Request) {
		disableCaching(w)
		http.StripPrefix("/", fileServer).ServeHTTP(w, r)
	}))

	s.server = &http.Server{
		Addr:         fmt.Sprintf("%s:%d", s.config.Host, s.config.Port),
		Handler:      mux,
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}
	logger.Info(fmt.Sprintf("HTTP控制台服务器启动,监听地址: %s:%d", s.config.Host, s.config.Port))
	return s.server.ListenAndServe()
}
// Stop closes the underlying http.Server if Start has created one, then logs
// that the console server has stopped.
func (s *Server) Stop() {
	if srv := s.server; srv != nil {
		srv.Close()
	}
	logger.Info("HTTP控制台服务器已停止")
}
// handleStats answers GET /api/stats with the combined DNS and shield
// statistics as JSON. Any other method gets 405.
//
// The payload is the map produced by buildStatsData (the same snapshot that is
// pushed over the WebSocket) plus a "time" field holding the current time, so
// the REST and WebSocket views cannot drift apart.
// @Summary 获取系统统计信息
// @Description 获取DNS服务器和Shield的统计信息
// @Tags stats
// @Accept json
// @Produce json
// @Success 200 {object} map[string]interface{} "统计信息"
// @Failure 500 {object} map[string]string "服务器内部错误"
// @Router /api/stats [get]
func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// buildStatsData returns a freshly allocated map, so adding a key is safe.
	stats := s.buildStatsData()
	stats["time"] = time.Now()

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(stats); err != nil {
		logger.Error(fmt.Sprintf("编码统计信息响应失败: %v", err))
	}
}
// WebSocket support.

// handleWebSocketStats upgrades the request to a WebSocket and pushes
// statistics to the client in real time.
//
// After the upgrade the connection is added to s.clients and an
// "initial_data" message is sent. Every 500ms the handler then builds a new
// snapshot and sends a "stats_update" message only when areStatsEqual reports
// a change since the last message sent.
//
// The connection is removed from s.clients on every exit path, and a
// background goroutine reads and discards incoming frames so that a close sent
// by the peer ends the handler promptly instead of waiting for a failed write.
//
// NOTE(review): startBroadcastLoop also writes to connections stored in
// s.clients; if broadcasts are ever sent, those writes are not serialized with
// the writes made here — confirm whether a per-connection write lock is needed.
func (s *Server) handleWebSocketStats(w http.ResponseWriter, r *http.Request) {
	// Upgrade the HTTP connection to a WebSocket.
	conn, err := s.upgrader.Upgrade(w, r, nil)
	if err != nil {
		logger.Error(fmt.Sprintf("WebSocket升级失败: %v", err))
		return
	}
	defer conn.Close()

	// Register the new client.
	s.clientsMutex.Lock()
	s.clients[conn] = true
	clientCount := len(s.clients)
	s.clientsMutex.Unlock()
	logger.Info(fmt.Sprintf("新WebSocket客户端连接,当前连接数: %d", clientCount))

	// Unregister on every return path. This runs before the deferred
	// conn.Close above, which in turn unblocks the reader goroutine below.
	defer func() {
		s.clientsMutex.Lock()
		delete(s.clients, conn)
		remaining := len(s.clients)
		s.clientsMutex.Unlock()
		logger.Info(fmt.Sprintf("WebSocket客户端断开连接,当前连接数: %d", remaining))
	}()

	// Read and discard whatever the peer sends. Reading is what lets the
	// connection process control frames, and a read error is the signal that
	// the peer has gone away.
	peerGone := make(chan struct{})
	go func() {
		defer close(peerGone)
		for {
			if _, _, readErr := conn.ReadMessage(); readErr != nil {
				return
			}
		}
	}()

	// Send the initial snapshot.
	if err := s.sendInitialStats(conn); err != nil {
		logger.Error(fmt.Sprintf("发送初始数据失败: %v", err))
		return
	}

	// Check for changed statistics every 500ms.
	ticker := time.NewTicker(500 * time.Millisecond)
	defer ticker.Stop()

	// Snapshot that was last sent; nil until the first update goes out, which
	// makes the first tick always send.
	var lastStats map[string]interface{}

	for {
		select {
		case <-ticker.C:
			currentStats := s.buildStatsData()
			if s.areStatsEqual(lastStats, currentStats) {
				continue
			}
			data, err := json.Marshal(map[string]interface{}{
				"type": "stats_update",
				"data": currentStats,
				"time": time.Now(),
			})
			if err != nil {
				logger.Error(fmt.Sprintf("序列化统计数据失败: %v", err))
				continue
			}
			if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
				logger.Error(fmt.Sprintf("发送WebSocket消息失败: %v", err))
				return
			}
			lastStats = currentStats
		case <-peerGone:
			// The peer closed the connection or the read failed.
			return
		case <-r.Context().Done():
			// The request context was cancelled.
			return
		}
	}
}
// sendInitialStats writes one "initial_data" text message to conn, carrying
// the current statistics snapshot and the current time. It returns the JSON
// marshalling error or the write error, whichever occurs.
func (s *Server) sendInitialStats(conn *websocket.Conn) error {
	envelope := map[string]interface{}{
		"type": "initial_data",
		"data": s.buildStatsData(),
		"time": time.Now(),
	}

	payload, err := json.Marshal(envelope)
	if err != nil {
		return err
	}
	return conn.WriteMessage(websocket.TextMessage, payload)
}
// buildStatsData assembles the statistics snapshot sent to WebSocket clients.
//
// The returned map contains the raw DNS counters under "dns", the shield
// statistics under "shield", and a handful of derived top-level values: the
// query type with the highest count ("-" when there is none), the number of
// distinct source IPs, the average response time and the DNSSEC usage
// percentage, both truncated to two decimals.
func (s *Server) buildStatsData() map[string]interface{} {
	dnsStats := s.dnsServer.GetStats()
	shieldStats := s.shieldManager.GetStats()

	// Most frequent query type; ranging over an empty map leaves the default.
	mostUsedType := "-"
	var highest int64
	for name, hits := range dnsStats.QueryTypes {
		if hits > highest {
			highest, mostUsedType = hits, name
		}
	}

	// Average response time, truncated (not rounded) to two decimals.
	avgResponse := float64(int(dnsStats.AvgResponseTime*100)) / 100

	// Share of queries that used DNSSEC, in percent; zero when nothing was queried.
	var dnssecPercent float64
	if dnsStats.Queries > 0 {
		dnssecPercent = float64(dnsStats.DNSSECQueries) / float64(dnsStats.Queries) * 100
	}

	dnsSection := map[string]interface{}{
		"Queries":           dnsStats.Queries,
		"Blocked":           dnsStats.Blocked,
		"Allowed":           dnsStats.Allowed,
		"Errors":            dnsStats.Errors,
		"LastQuery":         dnsStats.LastQuery,
		"AvgResponseTime":   avgResponse,
		"TotalResponseTime": dnsStats.TotalResponseTime,
		"QueryTypes":        dnsStats.QueryTypes,
		"SourceIPs":         dnsStats.SourceIPs,
		"CpuUsage":          dnsStats.CpuUsage,
		"DNSSECQueries":     dnsStats.DNSSECQueries,
		"DNSSECSuccess":     dnsStats.DNSSECSuccess,
		"DNSSECFailed":      dnsStats.DNSSECFailed,
		"DNSSECEnabled":     dnsStats.DNSSECEnabled,
	}

	snapshot := make(map[string]interface{}, 11)
	snapshot["dns"] = dnsSection
	snapshot["shield"] = shieldStats
	snapshot["topQueryType"] = mostUsedType
	snapshot["activeIPs"] = len(dnsStats.SourceIPs)
	snapshot["avgResponseTime"] = avgResponse
	snapshot["cpuUsage"] = dnsStats.CpuUsage
	snapshot["dnssecEnabled"] = dnsStats.DNSSECEnabled
	snapshot["dnssecQueries"] = dnsStats.DNSSECQueries
	snapshot["dnssecSuccess"] = dnsStats.DNSSECSuccess
	snapshot["dnssecFailed"] = dnsStats.DNSSECFailed
	snapshot["dnssecUsage"] = float64(int(dnssecPercent*100)) / 100 // two decimals
	return snapshot
}
// areStatsEqual reports whether two statistics snapshots should be treated as
// unchanged. Only the Queries, Blocked, Allowed and Errors counters of the
// "dns" section are compared, so other fields changing does not count.
//
// A nil snapshot on either side is never equal. If either snapshot has no
// "dns" section of the expected map type, nothing is compared and the
// snapshots are reported as equal.
func (s *Server) areStatsEqual(stats1, stats2 map[string]interface{}) bool {
	if stats1 == nil || stats2 == nil {
		return false
	}

	before, ok := stats1["dns"].(map[string]interface{})
	if !ok {
		return true
	}
	after, ok := stats2["dns"].(map[string]interface{})
	if !ok {
		return true
	}

	// Main counters, checked in a fixed order.
	for _, counter := range [...]string{"Queries", "Blocked", "Allowed", "Errors"} {
		if before[counter] != after[counter] {
			return false
		}
	}
	return true
}
// startBroadcastLoop forwards every message received on broadcastChan to all
// registered WebSocket clients, holding clientsMutex while it writes. A client
// whose write fails is closed and dropped from the registry. The loop ends
// when broadcastChan is closed.
func (s *Server) startBroadcastLoop() {
	for {
		payload, open := <-s.broadcastChan
		if !open {
			return
		}

		s.clientsMutex.Lock()
		for conn := range s.clients {
			err := conn.WriteMessage(websocket.TextMessage, payload)
			if err == nil {
				continue
			}
			logger.Error(fmt.Sprintf("广播消息失败: %v", err))
			conn.Close()
			delete(s.clients, conn)
		}
		s.clientsMutex.Unlock()
	}
}
// cleanupSessionsLoop runs forever: it sleeps for an hour, then removes every
// session whose expiry time has passed, and repeats.
func (s *Server) cleanupSessionsLoop() {
	// purgeExpired drops all expired sessions while holding sessionsMutex.
	purgeExpired := func() {
		s.sessionsMutex.Lock()
		defer s.sessionsMutex.Unlock()

		cutoff := time.Now()
		for id, expiresAt := range s.sessions {
			if cutoff.After(expiresAt) {
				delete(s.sessions, id)
			}
		}
	}

	for {
		time.Sleep(1 * time.Hour)
		purgeExpired()
	}
}
// isAuthenticated reports whether the request carries a "session_id" cookie
// that names a live session. An expired session is deleted on the spot; a
// live one has its expiry pushed forward by sessionTTL.
func (s *Server) isAuthenticated(r *http.Request) bool {
	cookie, err := r.Cookie("session_id")
	if err != nil {
		return false
	}
	id := cookie.Value

	s.sessionsMutex.Lock()
	defer s.sessionsMutex.Unlock()

	expiresAt, known := s.sessions[id]
	switch {
	case !known:
		return false
	case time.Now().After(expiresAt):
		// Expired: forget the session.
		delete(s.sessions, id)
		return false
	}

	// Sliding expiry: each authenticated request renews the session.
	s.sessions[id] = time.Now().Add(s.sessionTTL)
	return true
}
// loginRequired wraps next so that it only runs for authenticated requests.
//
// Requests for /login and /api/login always pass through without a session
// check. Unauthenticated requests under /api/ or /ws/ receive a 401 with a
// JSON error body; any other unauthenticated request is redirected to /login.
func (s *Server) loginRequired(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path

		// The login endpoints are let through before the session is looked at.
		loginEndpoint := path == "/login" || path == "/api/login"
		if loginEndpoint || s.isAuthenticated(r) {
			next.ServeHTTP(w, r)
			return
		}

		// Not authenticated: API and WebSocket callers get a JSON 401.
		apiCall := strings.HasPrefix(path, "/api/") || strings.HasPrefix(path, "/ws/")
		if !apiCall {
			http.Redirect(w, r, "/login", http.StatusFound)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusUnauthorized)
		json.NewEncoder(w).Encode(map[string]string{"error": "未授权访问"})
	}
}
// handleTopBlockedDomains answers GET requests with a JSON array of
// {"domain", "count"} objects for the most-blocked domains. It passes 50 to
// GetTopBlockedDomains (presumably the maximum number of entries — confirm
// against that method). Other methods get 405.
func (s *Server) handleTopBlockedDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	blocked := s.dnsServer.GetTopBlockedDomains(50)

	// Reshape into the objects the front end expects.
	rows := make([]map[string]interface{}, 0, len(blocked))
	for _, entry := range blocked {
		row := map[string]interface{}{
			"domain": entry.Domain,
			"count":  entry.Count,
		}
		rows = append(rows, row)
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(rows)
}
// handleTopResolvedDomains answers GET requests with a JSON array of
// {"domain", "count"} objects for the most-resolved domains, requesting 10
// from GetTopResolvedDomains. Other methods get 405.
func (s *Server) handleTopResolvedDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	resolved := s.dnsServer.GetTopResolvedDomains(10)

	// Reshape into the objects the front end expects.
	rows := make([]map[string]interface{}, 0, len(resolved))
	for _, entry := range resolved {
		rows = append(rows, map[string]interface{}{
			"domain": entry.Domain,
			"count":  entry.Count,
		})
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(rows)
}
// handleRecentBlockedDomains answers GET requests with a JSON array of
// {"domain", "time"} objects for recently blocked domains, requesting 10 from
// GetRecentBlockedDomains. "time" is LastSeen, read as Unix seconds and
// formatted as HH:MM:SS. Other methods get 405.
func (s *Server) handleRecentBlockedDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	const clockLayout = "15:04:05"
	recent := s.dnsServer.GetRecentBlockedDomains(10)

	// Reshape into the objects the front end expects.
	rows := make([]map[string]interface{}, 0, len(recent))
	for _, entry := range recent {
		seenAt := time.Unix(entry.LastSeen, 0)
		rows = append(rows, map[string]interface{}{
			"domain": entry.Domain,
			"time":   seenAt.Format(clockLayout),
		})
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(rows)
}
// handleHourlyStats answers GET requests with per-hour counts for the last 24
// hours, oldest first, as {"labels": [...], "data": [...]}. Counts are looked
// up in GetHourlyStats under keys formatted "2006-01-02-15"; hours without an
// entry yield zero. Labels read "HH:00". Other methods get 405.
func (s *Server) handleHourlyStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	const hours = 24
	perHour := s.dnsServer.GetHourlyStats()
	now := time.Now()

	labels := make([]string, hours)
	data := make([]int64, hours)
	for slot := 0; slot < hours; slot++ {
		// Slot 0 is 23 hours ago; the last slot is the current hour.
		moment := now.Add(time.Duration(slot-(hours-1)) * time.Hour)
		labels[slot] = moment.Format("15:00")
		data[slot] = perHour[moment.Format("2006-01-02-15")]
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]interface{}{
		"labels": labels,
		"data":   data,
	})
}
// handleDailyStats answers GET requests with per-day counts for the last 7
// days, oldest first, as {"labels": [...], "data": [...]}. Counts are looked up
// in GetDailyStats under keys formatted "2006-01-02"; days without an entry
// yield zero. Labels read "MM-DD". Other methods get 405.
func (s *Server) handleDailyStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	const days = 7
	perDay := s.dnsServer.GetDailyStats()

	labels := make([]string, days)
	data := make([]int64, days)
	now := time.Now()
	for slot := 0; slot < days; slot++ {
		// Slot 0 is six days ago; the last slot is today.
		day := now.AddDate(0, 0, slot-(days-1))
		labels[slot] = day.Format("01-02")
		data[slot] = perDay[day.Format("2006-01-02")]
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]interface{}{
		"labels": labels,
		"data":   data,
	})
}
// handleMonthlyStats answers GET requests with per-day counts for the last 30
// days, oldest first, as {"labels": [...], "data": [...]}. It reads the same
// GetDailyStats data as the 7-day view, keyed "2006-01-02"; days without an
// entry yield zero. Labels read "MM-DD". Other methods get 405.
func (s *Server) handleMonthlyStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	const days = 30
	perDay := s.dnsServer.GetDailyStats()

	labels := make([]string, days)
	data := make([]int64, days)
	now := time.Now()
	for slot := 0; slot < days; slot++ {
		// Slot 0 is 29 days ago; the last slot is today.
		day := now.AddDate(0, 0, slot-(days-1))
		labels[slot] = day.Format("01-02")
		data[slot] = perDay[day.Format("2006-01-02")]
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]interface{}{
		"labels": labels,
		"data":   data,
	})
}
// handleQueryTypeStats answers GET requests with a JSON array of
// {"type", "count"} objects, one per DNS query type, sorted by count in
// descending order. Other methods get 405.
func (s *Server) handleQueryTypeStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	perType := s.dnsServer.GetStats().QueryTypes

	// Reshape into the objects the front end expects.
	rows := make([]map[string]interface{}, 0, len(perType))
	for name, hits := range perType {
		rows = append(rows, map[string]interface{}{
			"type":  name,
			"count": hits,
		})
	}

	// Highest count first.
	countOf := func(idx int) int64 { return rows[idx]["count"].(int64) }
	sort.Slice(rows, func(a, b int) bool {
		return countOf(a) > countOf(b)
	})

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(rows)
}
// handleTopClients answers GET requests with a JSON array of
// {"ip", "count", "lastSeen"} objects for the busiest clients, requesting 10
// from GetTopClients. Other methods get 405.
func (s *Server) handleTopClients(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	topClients := s.dnsServer.GetTopClients(10)

	// Reshape into the objects the front end expects.
	rows := make([]map[string]interface{}, 0, len(topClients))
	for _, entry := range topClients {
		rows = append(rows, map[string]interface{}{
			"ip":       entry.IP,
			"count":    entry.Count,
			"lastSeen": entry.LastSeen,
		})
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(rows)
}
// handleTopDomains answers GET requests with a merged ranking of blocked and
// resolved domains as a JSON array of {"domain", "count", "dnssec"} objects.
//
// It requests 50 entries each from GetTopBlockedDomains and
// GetTopResolvedDomains, sums the counts of domains that appear in both, sorts
// by total count in descending order and returns at most 50 rows. For a domain
// present in both lists the DNSSEC flag of the resolved entry wins, because
// the resolved list is processed second. Other methods get 405.
func (s *Server) handleTopDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	const limit = 50
	blocked := s.dnsServer.GetTopBlockedDomains(limit)
	resolved := s.dnsServer.GetTopResolvedDomains(limit)

	// Merge both lists, de-duplicating by domain name.
	totals := make(map[string]int64)
	signed := make(map[string]bool)
	for _, entry := range blocked {
		totals[entry.Domain] += entry.Count
		signed[entry.Domain] = entry.DNSSEC
	}
	for _, entry := range resolved {
		totals[entry.Domain] += entry.Count
		signed[entry.Domain] = entry.DNSSEC
	}

	// Flatten into rows; a domain missing from signed reads as false.
	rows := make([]map[string]interface{}, 0, len(totals))
	for name, total := range totals {
		rows = append(rows, map[string]interface{}{
			"domain": name,
			"count":  total,
			"dnssec": signed[name],
		})
	}

	// Highest total first.
	sort.Slice(rows, func(a, b int) bool {
		left := rows[a]["count"].(int64)
		right := rows[b]["count"].(int64)
		return left > right
	})

	// Cap the response size.
	if len(rows) > limit {
		rows = rows[:limit]
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(rows)
}
// handleShield serves /api/shield.
//
//   - GET returns shield settings and rule counts, or the full rule set when
//     the query string contains all=true.
//   - POST adds the rule given as {"rule": "..."} and clears the DNS cache.
//   - DELETE removes the rule given as {"rule": "..."} and clears the DNS cache.
//   - PUT reloads all rules.
//   - Any other method gets 405.
//
// A body that cannot be decoded yields 400; an error from the shield manager
// yields 500 with the error text.
// @Summary 管理Shield配置
// @Description 获取或更新Shield的配置信息
// @Tags shield
// @Accept json
// @Produce json
// @Param config body map[string]interface{} false "Shield配置信息"
// @Success 200 {object} map[string]interface{} "配置信息"
// @Failure 400 {object} map[string]string "请求参数错误"
// @Failure 500 {object} map[string]string "服务器内部错误"
// @Router /api/shield [get]
// @Router /api/shield [post]
func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	out := json.NewEncoder(w)

	// readRule decodes {"rule": "..."} from the body; on failure it has
	// already written the 400 response and returns ok == false.
	readRule := func() (rule string, ok bool) {
		var body struct {
			Rule string `json:"rule"`
		}
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return "", false
		}
		return body.Rule, true
	}

	switch r.Method {
	case http.MethodGet:
		// all=true asks for the complete rule data instead of the summary.
		if r.URL.Query().Get("all") == "true" {
			out.Encode(s.shieldManager.GetRules())
			return
		}
		counts := s.shieldManager.GetStats()
		shieldCfg := s.globalConfig.Shield
		out.Encode(map[string]interface{}{
			"updateInterval":        shieldCfg.UpdateInterval,
			"blockMethod":           shieldCfg.BlockMethod,
			"blacklistCount":        len(shieldCfg.Blacklists),
			"domainRulesCount":      counts["domainRules"],
			"domainExceptionsCount": counts["domainExceptions"],
			"regexRulesCount":       counts["regexRules"],
			"regexExceptionsCount":  counts["regexExceptions"],
			"hostsRulesCount":       counts["hostsRules"],
		})

	case http.MethodPost:
		rule, ok := readRule()
		if !ok {
			return
		}
		if err := s.shieldManager.AddRule(rule); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		// Clear the DNS cache so the new rule takes effect immediately.
		s.dnsServer.DnsCache.Clear()
		out.Encode(map[string]string{"status": "success"})

	case http.MethodDelete:
		rule, ok := readRule()
		if !ok {
			return
		}
		if err := s.shieldManager.RemoveRule(rule); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		// Clear the DNS cache so the removal takes effect immediately.
		s.dnsServer.DnsCache.Clear()
		out.Encode(map[string]string{"status": "success"})

	case http.MethodPut:
		if err := s.shieldManager.LoadRules(); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		out.Encode(map[string]string{"status": "success", "message": "规则重新加载成功"})

	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// handleShieldBlacklists serves the remote-blacklist management endpoints.
//
//   - POST on a path containing "/update" (.../blacklists/{id}/update) stamps
//     the matching blacklist with the current time, saves, and reloads rules.
//   - DELETE .../blacklists/{id} removes every blacklist whose URL or Name
//     equals {id} and saves. It reports success even when nothing matched.
//   - GET returns the blacklist list.
//   - POST adds {"name", "url"} after rejecting empty fields, duplicate URLs
//     and URLs that checkURLExists rejects; then saves and reloads rules.
//   - PUT replaces the whole list (including enabled flags), saves and
//     reloads rules.
//
// Failures are reported as {"status": "error", "message": ...} in the body;
// the handler never sets a non-200 status itself. Every change is pushed to the shield
// manager, mirrored into the global config and written to config.json.
// @Summary 管理黑名单
// @Description 处理黑名单的CRUD操作,包括获取列表、添加、更新和删除黑名单
// @Tags shield
// @Accept json
// @Produce json
// @Param name path string false "黑名单名称(用于删除操作)"
// @Param blacklist body map[string]interface{} false "黑名单信息(用于添加/更新操作)"
// @Success 200 {object} map[string]interface{} "操作成功"
// @Failure 400 {object} map[string]string "请求参数错误"
// @Failure 404 {object} map[string]string "黑名单不存在"
// @Failure 500 {object} map[string]string "服务器内部错误"
// @Router /api/shield/blacklists [get]
// @Router /api/shield/blacklists [post]
// @Router /api/shield/blacklists [put]
// @Router /api/shield/blacklists/{name} [delete]
func (s *Server) handleShieldBlacklists(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	// reply writes payload as the JSON response body.
	reply := func(payload interface{}) {
		if err := json.NewEncoder(w).Encode(payload); err != nil {
			logger.Error(fmt.Sprintf("写入黑名单响应失败: %v", err))
		}
	}
	// fail writes the standard error envelope.
	fail := func(message string) {
		reply(map[string]string{"status": "error", "message": message})
	}
	// succeed writes the standard success envelope.
	succeed := func() {
		reply(map[string]string{"status": "success"})
	}
	// persist installs lists in the shield manager and the global config and
	// saves the config file. On a save failure it has already written the
	// error response and returns false.
	persist := func(lists []config.BlacklistEntry) bool {
		s.shieldManager.UpdateBlacklist(lists)
		s.globalConfig.Shield.Blacklists = lists
		if err := saveConfigToFile(s.globalConfig, "config.json"); err != nil {
			logger.Error("保存配置文件失败", "error", err)
			fail("保存配置失败")
			return false
		}
		return true
	}
	// reloadRules reloads the rule set; a failure is logged and the request
	// still succeeds, because the list itself has already been saved.
	reloadRules := func() {
		if err := s.shieldManager.LoadRules(); err != nil {
			logger.Error(fmt.Sprintf("重新加载规则失败: %v", err))
		}
	}

	parts := strings.Split(r.URL.Path, "/")

	// Refresh a single blacklist: POST .../blacklists/{id}/update.
	if strings.Contains(r.URL.Path, "/update") && r.Method == http.MethodPost {
		// The identifier is the segment right after "blacklists".
		var target string
		for i, part := range parts {
			if part == "blacklists" && i+1 < len(parts) && parts[i+1] != "update" {
				target = parts[i+1]
				break
			}
		}
		if target == "" {
			fail("黑名单标识不能为空")
			return
		}

		blacklists := s.shieldManager.GetBlacklists()
		targetIndex := -1
		for i, list := range blacklists {
			if list.URL == target || list.Name == target {
				targetIndex = i
				break
			}
		}
		if targetIndex == -1 {
			fail("黑名单不存在")
			return
		}

		// Stamp the entry with the time of this refresh.
		blacklists[targetIndex].LastUpdateTime = time.Now().Format(time.RFC3339)
		if !persist(blacklists) {
			return
		}
		// Reload so the latest remote rules are fetched.
		reloadRules()
		succeed()
		return
	}

	// Delete a blacklist: DELETE .../blacklists/{id}.
	if len(parts) > 4 && parts[3] == "blacklists" && parts[4] != "" && r.Method == http.MethodDelete {
		id := parts[4]
		var remaining []config.BlacklistEntry
		for _, list := range s.shieldManager.GetBlacklists() {
			if list.URL != id && list.Name != id {
				remaining = append(remaining, list)
			}
		}
		if !persist(remaining) {
			return
		}
		succeed()
		return
	}

	switch r.Method {
	case http.MethodGet:
		// List the remote blacklists.
		reply(s.shieldManager.GetBlacklists())

	case http.MethodPost:
		// Add a remote blacklist.
		var req struct {
			Name string `json:"name"`
			URL  string `json:"url"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			fail("无效的请求体")
			return
		}
		if req.Name == "" || req.URL == "" {
			fail("名称和URL不能为空")
			return
		}

		blacklists := s.shieldManager.GetBlacklists()
		for _, list := range blacklists {
			if list.URL == req.URL {
				fail("黑名单URL已存在")
				return
			}
		}
		// Refuse URLs that cannot be reached.
		if !checkURLExists(req.URL) {
			fail("URL不存在或无法访问")
			return
		}

		blacklists = append(blacklists, config.BlacklistEntry{
			Name:    req.Name,
			URL:     req.URL,
			Enabled: true,
		})
		if !persist(blacklists) {
			return
		}
		// Reload so the rules of the new blacklist are fetched.
		reloadRules()
		succeed()

	case http.MethodPut:
		// Replace the whole list, including enabled/disabled state.
		// encoding/json matches object keys to these names case-insensitively,
		// so both "Name" and "name" style payloads decode.
		var updated []struct {
			Name           string `json:"name"`
			URL            string `json:"url"`
			Enabled        bool   `json:"enabled"`
			LastUpdateTime string `json:"lastUpdateTime"`
		}
		if err := json.NewDecoder(r.Body).Decode(&updated); err != nil {
			fail("无效的请求体")
			return
		}

		// Left nil when the payload is empty, as before.
		var replacement []config.BlacklistEntry
		for _, entry := range updated {
			replacement = append(replacement, config.BlacklistEntry{
				Name:           entry.Name,
				URL:            entry.URL,
				Enabled:        entry.Enabled,
				LastUpdateTime: entry.LastUpdateTime,
			})
		}
		if !persist(replacement) {
			return
		}
		reloadRules()
		succeed()

	default:
		fail("Method not allowed")
	}
}
// handleShieldHosts serves the hosts-entry management endpoint.
//
//   - GET returns {"hosts": [{"domain", "ip"}...], "hostsCount": n}.
//   - POST adds the entry given as {"ip", "domain"}; both are required.
//   - DELETE removes the entry given as {"domain"}.
//   - Any other method gets 405.
//
// A body that cannot be decoded or a missing field yields 400; an error from
// the shield manager yields 500 with the error text.
func (s *Server) handleShieldHosts(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	ok := map[string]string{"status": "success"}

	switch r.Method {
	case http.MethodGet:
		entries := s.shieldManager.GetAllHosts()
		total := s.shieldManager.GetHostsCount()

		// Flatten the map into a list for easier display in the front end.
		list := make([]map[string]string, 0, len(entries))
		for name, addr := range entries {
			list = append(list, map[string]string{
				"domain": name,
				"ip":     addr,
			})
		}
		json.NewEncoder(w).Encode(map[string]interface{}{
			"hosts":      list,
			"hostsCount": total,
		})

	case http.MethodPost:
		var body struct {
			IP     string `json:"ip"`
			Domain string `json:"domain"`
		}
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if body.IP == "" || body.Domain == "" {
			http.Error(w, "IP and Domain are required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.AddHostsEntry(body.IP, body.Domain); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(ok)

	case http.MethodDelete:
		var body struct {
			Domain string `json:"domain"`
		}
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.RemoveHostsEntry(body.Domain); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(ok)

	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// handleQuery answers GET requests with the shield's block details for the
// domain named in the "domain" query parameter, plus a "timestamp" field
// holding the current time. A missing parameter yields 400; other methods
// get 405.
func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	name := r.URL.Query().Get("domain")
	if name == "" {
		http.Error(w, "Domain parameter is required", http.StatusBadRequest)
		return
	}

	// Look up the block details and stamp them with the lookup time.
	details := s.shieldManager.CheckDomainBlockDetails(name)
	details["timestamp"] = time.Now()

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(details)
}
// handleStatus answers GET requests with a JSON status summary: the DNS
// counters, the number of distinct source IPs, CPU usage, the DNS server's
// start time, the uptime in milliseconds measured from that start time, and
// the current time. "status" is always "running". Other methods get 405.
func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	counters := s.dnsServer.GetStats()

	// Uptime is measured from the DNS server's recorded start time.
	startedAt := s.dnsServer.GetStartTime()
	running := time.Since(startedAt)

	summary := make(map[string]interface{}, 12)
	summary["status"] = "running"
	summary["queries"] = counters.Queries
	summary["blocked"] = counters.Blocked
	summary["allowed"] = counters.Allowed
	summary["errors"] = counters.Errors
	summary["lastQuery"] = counters.LastQuery
	summary["avgResponseTime"] = counters.AvgResponseTime
	summary["activeIPs"] = len(counters.SourceIPs)
	summary["startTime"] = startedAt
	summary["uptime"] = running.Milliseconds() // milliseconds, for the front end
	summary["cpuUsage"] = counters.CpuUsage
	summary["timestamp"] = time.Now()

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(summary)
}
// saveConfigToFile serialises config as indented JSON and writes it to
// filePath with mode 0644. It returns the marshalling error or the write
// error, whichever occurs.
func saveConfigToFile(config *config.Config, filePath string) error {
	encoded, marshalErr := json.MarshalIndent(config, "", " ")
	if marshalErr == nil {
		return ioutil.WriteFile(filePath, encoded, 0644)
	}
	return marshalErr
}
// handleConfig serves the console configuration endpoint.
//
// GET returns a JSON snapshot of the shield, GFWList, DNS and HTTP settings.
//
// POST decodes a JSON update, validates it, applies the supplied fields to the
// in-memory configuration, reloads the affected rule sets and writes the
// configuration to ./config.json. A failure to reload rules or to write the
// file is logged only; the request still reports success because the
// in-memory configuration has been updated.
//
// All validation runs before the first field is assigned, so a request that is
// rejected with 400 leaves the running configuration untouched.
//
// Any other method gets 405.
func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	switch r.Method {
	case http.MethodGet:
		// Snapshot of the current settings, blacklist sources included.
		current := map[string]interface{}{
			"Shield": map[string]interface{}{
				"blockMethod":    s.globalConfig.Shield.BlockMethod,
				"customBlockIP":  s.globalConfig.Shield.CustomBlockIP,
				"blacklists":     s.globalConfig.Shield.Blacklists,
				"updateInterval": s.globalConfig.Shield.UpdateInterval,
			},
			"GFWList": map[string]interface{}{
				"ip":      s.globalConfig.GFWList.IP,
				"content": s.globalConfig.GFWList.Content,
			},
			"DNSServer": map[string]interface{}{
				"port":                  s.globalConfig.DNS.Port,
				"UpstreamServers":       s.globalConfig.DNS.UpstreamDNS,
				"DNSSECUpstreamServers": s.globalConfig.DNS.DNSSECUpstreamDNS,
				"saveInterval":          s.globalConfig.DNS.SaveInterval,
				"enableIPv6":            s.globalConfig.DNS.EnableIPv6,
			},
			"HTTPServer": map[string]interface{}{
				"port": s.globalConfig.HTTP.Port,
			},
		}
		json.NewEncoder(w).Encode(current)

	case http.MethodPost:
		var req struct {
			DNSServer struct {
				Port                  int      `json:"port"`
				UpstreamServers       []string `json:"upstreamServers"`
				DnssecUpstreamServers []string `json:"dnssecUpstreamServers"`
				Timeout               int      `json:"timeout"` // decoded but not applied below
				SaveInterval          int      `json:"saveInterval"`
				EnableIPv6            bool     `json:"enableIPv6"`
			} `json:"dnsserver"`
			HTTPServer struct {
				Port int `json:"port"`
			} `json:"httpserver"`
			Shield struct {
				BlockMethod    string                  `json:"blockMethod"`
				CustomBlockIP  string                  `json:"customBlockIP"`
				Blacklists     []config.BlacklistEntry `json:"blacklists"`
				UpdateInterval int                     `json:"updateInterval"`
			} `json:"shield"`
			GFWList struct {
				IP      string `json:"ip"`
				Content string `json:"content"`
			} `json:"gfwList"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "无效的请求体", http.StatusBadRequest)
			return
		}

		// ---- Validation: nothing is assigned until every check has passed ----

		// Ports <= 0 mean "not supplied"; anything above the TCP/UDP range
		// could never be bound and must not be persisted.
		const maxPort = 65535
		if req.DNSServer.Port > maxPort {
			http.Error(w, "无效的DNS端口", http.StatusBadRequest)
			return
		}
		if req.HTTPServer.Port > maxPort {
			http.Error(w, "无效的HTTP端口", http.StatusBadRequest)
			return
		}

		switch req.Shield.BlockMethod {
		case "":
			// Not supplied: the current block method is kept.
		case "NXDOMAIN", "refused", "emptyIP":
			// Valid, no extra data required.
		case "customIP":
			// customIP needs a usable address in the same request.
			if req.Shield.CustomBlockIP == "" {
				http.Error(w, "自定义IP不能为空", http.StatusBadRequest)
				return
			}
			if !isValidIP(req.Shield.CustomBlockIP) {
				http.Error(w, "无效的IP地址", http.StatusBadRequest)
				return
			}
		default:
			http.Error(w, "无效的屏蔽方法", http.StatusBadRequest)
			return
		}

		// Every blacklist source must be an absolute http(s) URL.
		for i, bl := range req.Shield.Blacklists {
			if bl.URL == "" {
				http.Error(w, fmt.Sprintf("黑名单URL不能为空,索引: %d", i), http.StatusBadRequest)
				return
			}
			if !strings.HasPrefix(bl.URL, "http://") && !strings.HasPrefix(bl.URL, "https://") {
				http.Error(w, fmt.Sprintf("黑名单URL必须以http://或https://开头,索引: %d", i), http.StatusBadRequest)
				return
			}
		}

		// ---- Apply: DNS settings ----
		if req.DNSServer.Port > 0 {
			s.globalConfig.DNS.Port = req.DNSServer.Port
		}
		if len(req.DNSServer.UpstreamServers) > 0 {
			s.globalConfig.DNS.UpstreamDNS = req.DNSServer.UpstreamServers
		}
		if len(req.DNSServer.DnssecUpstreamServers) > 0 {
			s.globalConfig.DNS.DNSSECUpstreamDNS = req.DNSServer.DnssecUpstreamServers
		}
		if req.DNSServer.SaveInterval > 0 {
			s.globalConfig.DNS.SaveInterval = req.DNSServer.SaveInterval
		}
		// A bool has no "not supplied" state, so this is overwritten on every
		// update; clients must always send the intended value.
		s.globalConfig.DNS.EnableIPv6 = req.DNSServer.EnableIPv6

		// ---- Apply: HTTP settings ----
		if req.HTTPServer.Port > 0 {
			s.globalConfig.HTTP.Port = req.HTTPServer.Port
		}

		// ---- Apply: shield settings ----
		if req.Shield.BlockMethod != "" {
			s.globalConfig.Shield.BlockMethod = req.Shield.BlockMethod
		}
		if req.Shield.CustomBlockIP != "" {
			s.globalConfig.Shield.CustomBlockIP = req.Shield.CustomBlockIP
		}

		// ---- Apply: GFWList settings ----
		// Both fields are overwritten unconditionally, so omitting them in
		// the request clears them.
		s.globalConfig.GFWList.IP = req.GFWList.IP
		s.globalConfig.GFWList.Content = req.GFWList.Content
		if s.gfwManager != nil {
			if err := s.gfwManager.LoadRules(); err != nil {
				logger.Error("重新加载GFWList规则失败", "error", err)
			}
		}

		// ---- Apply: blacklist sources ----
		// A nil slice means "not supplied"; an empty non-nil slice replaces
		// the list with nothing.
		if req.Shield.Blacklists != nil {
			s.globalConfig.Shield.Blacklists = req.Shield.Blacklists
			s.shieldManager.UpdateBlacklist(req.Shield.Blacklists)
			if err := s.shieldManager.LoadRules(); err != nil {
				logger.Error("重新加载规则失败", "error", err)
			}
		}

		// ---- Apply: update interval ----
		if req.Shield.UpdateInterval > 0 {
			s.globalConfig.Shield.UpdateInterval = req.Shield.UpdateInterval
			// Restart the auto-update task after the interval changes.
			s.shieldManager.StopAutoUpdate()
			s.shieldManager.StartAutoUpdate()
		}

		// Persist the result. A write failure is logged but not reported to
		// the client because the in-memory configuration is already updated.
		if err := saveConfigToFile(s.globalConfig, "./config.json"); err != nil {
			logger.Error("保存配置到文件失败", "error", err)
		}

		json.NewEncoder(w).Encode(map[string]interface{}{
			"success": true,
			"message": "配置已更新",
		})

	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// isValidIP reports whether ip is a dotted-quad IPv4 address in canonical
// decimal form: exactly four octets separated by '.', each made of one to
// three ASCII digits, with a value of 0-255 and no leading zero. So
// "10.0.0.1" is accepted while "010.0.0.1", "1.2.3" and "256.0.0.1" are not.
//
// Leading zeros are rejected because such octets are ambiguous (decimal vs.
// octal) and Go's net.ParseIP refuses them. IPv6 addresses are not supported
// and yield false.
func isValidIP(ip string) bool {
	octets := strings.Split(ip, ".")
	if len(octets) != 4 {
		return false
	}
	for _, octet := range octets {
		// One to three characters; this also bounds the accumulator below.
		if len(octet) == 0 || len(octet) > 3 {
			return false
		}
		if len(octet) > 1 && octet[0] == '0' {
			return false
		}
		value := 0
		// Byte-wise scan: any non-ASCII byte falls outside '0'..'9'.
		for i := 0; i < len(octet); i++ {
			c := octet[i]
			if c < '0' || c > '9' {
				return false
			}
			value = value*10 + int(c-'0')
		}
		if value > 255 {
			return false
		}
	}
	return true
}
// checkURLExists reports whether a HEAD request to url completes with a
// status code in the 2xx or 3xx range. The request uses a client with a
// five-second timeout; any request error is reported as false.
func checkURLExists(url string) bool {
	probe := http.Client{Timeout: 5 * time.Second}

	// HEAD is enough to learn the status without asking for the body.
	resp, err := probe.Head(url)
	if err != nil {
		return false
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 {
		return false
	}
	return resp.StatusCode < 400
}
// handleLogsStats serves GET requests for query-log statistics. It encodes
// the value returned by the DNS server's GetQueryStats as JSON; non-GET
// requests get 405.
func (s *Server) handleLogsStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	summary := s.dnsServer.GetQueryStats()

	w.Header().Set("Content-Type", "application/json")
	enc := json.NewEncoder(w)
	enc.Encode(summary)
}
// handleLogsQuery serves GET requests for a page of query-log entries.
//
// Query parameters: "limit" (default 100), "offset" (default 0), "sort",
// "direction", "result" and "search". All of them are passed to the DNS
// server's GetQueryLogs and its result is encoded as JSON. The error from
// parsing "limit"/"offset" is ignored, so a value that does not parse keeps
// the default. Non-GET requests get 405.
func (s *Server) handleLogsQuery(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	params := r.URL.Query()

	// Paging defaults: 100 entries starting at the beginning.
	limit, offset := 100, 0

	sortField := params.Get("sort")
	sortDirection := params.Get("direction")
	resultFilter := params.Get("result")
	searchTerm := params.Get("search")

	if raw := params.Get("limit"); raw != "" {
		fmt.Sscanf(raw, "%d", &limit)
	}
	if raw := params.Get("offset"); raw != "" {
		fmt.Sscanf(raw, "%d", &offset)
	}

	entries := s.dnsServer.GetQueryLogs(limit, offset, sortField, sortDirection, resultFilter, searchTerm)

	w.Header().Set("Content-Type", "application/json")
	enc := json.NewEncoder(w)
	enc.Encode(entries)
}
// handleLogsCount serves GET requests for the total number of query-log
// entries. The response is a JSON object with a single "count" field taken
// from the DNS server's GetQueryLogsCount; non-GET requests get 405.
func (s *Server) handleLogsCount(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	total := s.dnsServer.GetQueryLogsCount()
	payload := map[string]int{"count": total}

	w.Header().Set("Content-Type", "application/json")
	enc := json.NewEncoder(w)
	enc.Encode(payload)
}
// handleRestart serves POST requests that restart the DNS service.
//
// It stops the DNS server, reloads the shield rules, starts the DNS server
// again in a separate goroutine and restarts the shield auto-update task.
// The success response is written without waiting for the start result;
// reload and start failures are only logged. Non-POST requests get 405.
func (s *Server) handleRestart(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	logger.Info("收到重启服务请求")

	// Stop serving DNS before the rules are reloaded.
	s.dnsServer.Stop()

	reloadErr := s.shieldManager.LoadRules()
	if reloadErr != nil {
		logger.Error("重新加载屏蔽规则失败", "error", reloadErr)
	}

	// Bring the DNS server back up in the background.
	startDNS := func() {
		startErr := s.dnsServer.Start()
		if startErr != nil {
			logger.Error("DNS服务器重启失败", "error", startErr)
		}
	}
	go startDNS()

	// Restart the periodic rule update task.
	s.shieldManager.StopAutoUpdate()
	s.shieldManager.StartAutoUpdate()

	w.Header().Set("Content-Type", "application/json")
	reply := map[string]string{"status": "success", "message": "服务已重启"}
	json.NewEncoder(w).Encode(reply)

	logger.Info("服务重启成功")
}
// handleLogin serves POST requests that authenticate a console user.
//
// The JSON body carries "username" and "password". When they match the
// configured credentials, a session is recorded with an expiry of sessionTTL
// from now and its ID is returned in the "session_id" cookie. A malformed or
// oversized body gets 400, wrong credentials get 401, non-POST requests
// get 405.
//
// NOTE(review): the session ID is built from the current time and the session
// count, so it is guessable. It should come from crypto/rand, which this file
// does not import yet. The credential check is also a plain string comparison.
func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// fail writes a JSON error object with the given status code.
	fail := func(status int, msg string) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	// This endpoint is reachable without authentication, so cap the body size.
	r.Body = http.MaxBytesReader(w, r.Body, 1<<20)

	var loginData struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}
	if err := json.NewDecoder(r.Body).Decode(&loginData); err != nil {
		fail(http.StatusBadRequest, "无效的请求体")
		return
	}

	if loginData.Username != s.config.Username || loginData.Password != s.config.Password {
		fail(http.StatusUnauthorized, "用户名或密码错误")
		return
	}

	now := time.Now()
	expiresAt := now.Add(s.sessionTTL)

	// The session map is shared with other handlers and the cleanup loop, so
	// both the len() read and the insert happen under the mutex.
	s.sessionsMutex.Lock()
	sessionID := fmt.Sprintf("%d_%d", now.UnixNano(), len(s.sessions))
	s.sessions[sessionID] = expiresAt
	s.sessionsMutex.Unlock()

	http.SetCookie(w, &http.Cookie{
		Name:     "session_id",
		Value:    sessionID,
		Path:     "/",
		Expires:  expiresAt,
		HttpOnly: true,
		// Mark the cookie Secure whenever the request itself arrived over
		// TLS; plain-HTTP deployments keep working as before.
		Secure: r.TLS != nil,
		// Keep the cookie off cross-site subrequests.
		SameSite: http.SameSiteLaxMode,
	})

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "登录成功"})
	logger.Info(fmt.Sprintf("用户 %s 登录成功", loginData.Username))
}
// handleLogout serves POST requests that end the current console session.
//
// When the request carries a "session_id" cookie, the matching entry is
// removed from the session map. In every case the cookie is overwritten with
// an empty, already-expired one and a success message is returned. Non-POST
// requests get 405.
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Drop the server-side session if the client sent its ID.
	if sessionCookie, err := r.Cookie("session_id"); err == nil {
		s.sessionsMutex.Lock()
		delete(s.sessions, sessionCookie.Value)
		s.sessionsMutex.Unlock()
	}

	// Expire the cookie on the client.
	http.SetCookie(w, &http.Cookie{
		Name:     "session_id",
		Value:    "",
		Path:     "/",
		Expires:  time.Unix(0, 0),
		HttpOnly: true,
		Secure:   false,
	})

	w.Header().Set("Content-Type", "application/json")
	reply := map[string]string{"status": "success", "message": "注销成功"}
	json.NewEncoder(w).Encode(reply)

	logger.Info("用户注销成功")
}
// handleChangePassword serves POST requests that change the console password.
//
// The JSON body carries "currentPassword" and "newPassword". The current
// password must match the configured one (401 otherwise) and the new password
// must not be empty (400). The new password is stored in the configuration and
// written to ./config.json. If the write fails, the previous password is
// restored and 500 is returned, so the running server and the client agree on
// which password is valid. Non-POST requests get 405.
func (s *Server) handleChangePassword(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// fail writes a JSON error object with the given status code.
	fail := func(status int, msg string) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	var changePasswordData struct {
		CurrentPassword string `json:"currentPassword"`
		NewPassword     string `json:"newPassword"`
	}
	if err := json.NewDecoder(r.Body).Decode(&changePasswordData); err != nil {
		fail(http.StatusBadRequest, "无效的请求体")
		return
	}

	if changePasswordData.CurrentPassword != s.config.Password {
		fail(http.StatusUnauthorized, "当前密码错误")
		return
	}

	// An empty (or omitted) new password would leave the console with a
	// blank password.
	if changePasswordData.NewPassword == "" {
		fail(http.StatusBadRequest, "新密码不能为空")
		return
	}

	previous := s.config.Password
	s.config.Password = changePasswordData.NewPassword

	if err := saveConfigToFile(s.globalConfig, "./config.json"); err != nil {
		// Roll back so the in-memory password matches what is on disk and
		// what the client is told.
		s.config.Password = previous
		logger.Error("保存配置文件失败", "error", err)
		fail(http.StatusInternalServerError, "保存密码失败")
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "密码修改成功"})
	logger.Info("密码修改成功")
}