Files
dns-server/http/server.go
T
Alex Yang f9e2e5a6bc update
2026-04-12 21:40:22 +08:00

3097 lines
91 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package http
import (
"bytes"
"encoding/csv"
"encoding/json"
"fmt"
"net/http"
"os"
"sort"
"strings"
"sync"
"time"
"dns-server/config"
"dns-server/dns"
"dns-server/gfw"
"dns-server/log"
"dns-server/logger"
"dns-server/shield"
"dns-server/threat"
"github.com/gorilla/websocket"
)
// CacheEntry is a single cached value plus the bookkeeping used for expiry
// and eviction by QueryCache and StatsCache.
type CacheEntry struct {
	data      interface{} // cached payload
	timestamp time.Time   // creation time; the entry expires ttl after this
	hits      int         // successful Get count; the entry with the fewest hits is evicted first
}

// QueryCache is a size-bounded, TTL-based cache for query results. When it is
// full, inserting a new key evicts the entry with the fewest hits (ties go to
// the oldest entry). Despite the evictLRU name, this is least-frequently-used
// eviction, not true LRU.
type QueryCache struct {
	data    map[string]*CacheEntry // cache key -> entry
	mutex   sync.RWMutex           // guards data and the entries' hit counters
	maxSize int                    // maximum number of entries
	ttl     time.Duration          // entry lifetime, measured from insertion
}

// StatsCache is the statistics counterpart of QueryCache. It stores
// map[string]interface{} payloads and also remembers the payload passed to
// the most recent Set.
type StatsCache struct {
	data      map[string]*CacheEntry // cache key -> entry
	mutex     sync.RWMutex           // guards every field below
	maxSize   int                    // maximum number of entries
	ttl       time.Duration          // entry lifetime, measured from insertion
	lastStats map[string]interface{} // last stored payload (intended for incremental updates)
}

// NewQueryCache creates a QueryCache that holds at most maxSize entries, each
// expiring ttl after insertion, and starts its background cleanup goroutine.
// Nothing stops that goroutine; it lives for the rest of the process.
func NewQueryCache(maxSize int, ttl time.Duration) *QueryCache {
	cache := &QueryCache{
		data:    make(map[string]*CacheEntry),
		maxSize: maxSize,
		ttl:     ttl,
	}
	go cache.startCleanupLoop()
	return cache
}

// NewStatsCache creates a StatsCache that holds at most maxSize entries, each
// expiring ttl after insertion, and starts its background cleanup goroutine.
// Nothing stops that goroutine; it lives for the rest of the process.
func NewStatsCache(maxSize int, ttl time.Duration) *StatsCache {
	cache := &StatsCache{
		data:      make(map[string]*CacheEntry),
		maxSize:   maxSize,
		ttl:       ttl,
		lastStats: make(map[string]interface{}),
	}
	go cache.startCleanupLoop()
	return cache
}

// Get returns the value cached under key and true, or nil and false when the
// key is absent or expired. An expired entry is removed, and a hit increments
// the entry's hit counter.
func (c *QueryCache) Get(key string) (interface{}, bool) {
	// One exclusive critical section is used because a hit mutates the entry
	// and an expiry deletes it. Splitting this into RLock and then Lock would
	// let a concurrent Set replace the entry in between, and the fresh value
	// would then be deleted.
	c.mutex.Lock()
	defer c.mutex.Unlock()

	entry, found := c.data[key]
	if !found {
		return nil, false
	}
	if time.Since(entry.timestamp) > c.ttl {
		delete(c.data, key)
		return nil, false
	}
	entry.hits++
	return entry.data, true
}

// Set stores data under key with a fresh timestamp and a zero hit count. When
// a new key is inserted into a full cache, one entry is evicted first.
func (c *QueryCache) Set(key string, data interface{}) {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	// Only make room for a new key. Overwriting an existing key does not grow
	// the cache, so it must not evict an unrelated entry.
	if _, exists := c.data[key]; !exists && len(c.data) >= c.maxSize {
		c.evictLRU()
	}
	c.data[key] = &CacheEntry{
		data:      data,
		timestamp: time.Now(),
		hits:      0,
	}
}

// Delete removes key from the cache. It is a no-op when key is absent.
func (c *QueryCache) Delete(key string) {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	delete(c.data, key)
}

// Clear removes every entry.
func (c *QueryCache) Clear() {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	c.data = make(map[string]*CacheEntry)
}

// evictLRU removes the entry with the fewest hits, breaking ties in favour of
// the oldest entry. The caller must hold c.mutex for writing.
func (c *QueryCache) evictLRU() {
	var victimKey string
	var victim *CacheEntry
	for key, entry := range c.data {
		if victim == nil || entry.hits < victim.hits ||
			(entry.hits == victim.hits && entry.timestamp.Before(victim.timestamp)) {
			victimKey, victim = key, entry
		}
	}
	// victim is used as the "found" flag (not victimKey != ""), so an entry
	// stored under the empty key can still be evicted.
	if victim != nil {
		delete(c.data, victimKey)
	}
}

// startCleanupLoop purges expired entries every ttl/2, forever.
func (c *QueryCache) startCleanupLoop() {
	interval := c.ttl / 2
	if interval <= 0 {
		// time.NewTicker panics on a non-positive duration, e.g. when ttl is 0.
		interval = time.Second
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for range ticker.C {
		c.cleanupExpired()
	}
}

// cleanupExpired removes every entry older than ttl.
func (c *QueryCache) cleanupExpired() {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	now := time.Now()
	for key, entry := range c.data {
		if now.Sub(entry.timestamp) > c.ttl {
			delete(c.data, key)
		}
	}
}

// StatsCache methods.

// Get returns the stats map cached under key and true, or nil and false when
// the key is absent, expired, or holds a non-map payload. An expired entry is
// removed, and a hit increments the entry's hit counter.
func (c *StatsCache) Get(key string) (map[string]interface{}, bool) {
	// A single exclusive critical section is used for the same reason as in
	// QueryCache.Get.
	c.mutex.Lock()
	defer c.mutex.Unlock()

	entry, found := c.data[key]
	if !found {
		return nil, false
	}
	if time.Since(entry.timestamp) > c.ttl {
		delete(c.data, key)
		return nil, false
	}
	entry.hits++
	if data, ok := entry.data.(map[string]interface{}); ok {
		return data, true
	}
	return nil, false
}

// Set stores data under key and also records it as the last stats snapshot.
// When a new key is inserted into a full cache, one entry is evicted first.
// The map is stored without copying.
func (c *StatsCache) Set(key string, data map[string]interface{}) {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	// Only make room for a new key; an overwrite must not evict another entry.
	if _, exists := c.data[key]; !exists && len(c.data) >= c.maxSize {
		c.evictLRU()
	}
	c.data[key] = &CacheEntry{
		data:      data,
		timestamp: time.Now(),
		hits:      0,
	}
	c.lastStats = data
}

// GetLastStats returns the payload passed to the most recent Set, or an empty
// map after construction or Clear. The map is shared, not copied.
func (c *StatsCache) GetLastStats() map[string]interface{} {
	c.mutex.RLock()
	defer c.mutex.RUnlock()
	return c.lastStats
}

// Delete removes key from the cache. It is a no-op when key is absent.
func (c *StatsCache) Delete(key string) {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	delete(c.data, key)
}

// Clear removes every entry and resets the last stats snapshot.
func (c *StatsCache) Clear() {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	c.data = make(map[string]*CacheEntry)
	c.lastStats = make(map[string]interface{})
}

// evictLRU removes the entry with the fewest hits, breaking ties in favour of
// the oldest entry. The caller must hold c.mutex for writing.
func (c *StatsCache) evictLRU() {
	var victimKey string
	var victim *CacheEntry
	for key, entry := range c.data {
		if victim == nil || entry.hits < victim.hits ||
			(entry.hits == victim.hits && entry.timestamp.Before(victim.timestamp)) {
			victimKey, victim = key, entry
		}
	}
	// victim (not victimKey != "") is the "found" flag, so the empty key is evictable.
	if victim != nil {
		delete(c.data, victimKey)
	}
}

// startCleanupLoop purges expired entries every ttl/2, forever.
func (c *StatsCache) startCleanupLoop() {
	interval := c.ttl / 2
	if interval <= 0 {
		// time.NewTicker panics on a non-positive duration, e.g. when ttl is 0.
		interval = time.Second
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for range ticker.C {
		c.cleanupExpired()
	}
}

// cleanupExpired removes every entry older than ttl.
func (c *StatsCache) cleanupExpired() {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	now := time.Now()
	for key, entry := range c.data {
		if now.Sub(entry.timestamp) > c.ttl {
			delete(c.data, key)
		}
	}
}
// ClearQueryCache drops every cached query result. It does nothing when the
// server has no query cache.
func (s *Server) ClearQueryCache() {
	if s.queryCache == nil {
		return
	}
	s.queryCache.Clear()
}
// ClearStatsCache drops every cached statistics entry, together with the
// remembered last-stats snapshot. It does nothing when the server has no
// stats cache.
func (s *Server) ClearStatsCache() {
	if s.statsCache == nil {
		return
	}
	s.statsCache.Clear()
}
// ClearAllCache empties the query-result cache and then the statistics cache.
func (s *Server) ClearAllCache() {
	for _, clearFn := range []func(){s.ClearQueryCache, s.ClearStatsCache} {
		clearFn()
	}
}
// Server is the HTTP console server. It serves the static web UI, the JSON
// management API and the /ws/stats WebSocket feed for the DNS server and the
// managers it was constructed with.
type Server struct {
	globalConfig      *config.Config
	config            *config.HTTPConfig
	dnsServer         *dns.Server
	shieldManager     *shield.ShieldManager
	gfwManager        *gfw.GFWListManager
	domainInfoManager *shield.DomainInfoManager
	server            *http.Server
	// Session management.
	sessions      map[string]time.Time // session ID -> expiry time
	sessionsMutex sync.Mutex           // guards sessions
	sessionTTL    time.Duration        // lifetime granted (and re-granted on each authenticated request)
	// WebSocket state.
	upgrader      websocket.Upgrader
	clients       map[*websocket.Conn]bool // connected WebSocket clients; guarded by clientsMutex
	clientsMutex  sync.Mutex
	broadcastChan chan []byte // messages that startBroadcastLoop fans out to every client
	// Response caching.
	queryCache   *QueryCache   // query result cache
	statsCache   *StatsCache   // statistics cache
	cacheEnabled bool          // whether caching is enabled
	cacheTTL     time.Duration // cache entry lifetime
	cacheMaxSize int           // maximum number of cache entries
}
// NewServer builds the HTTP console server around the given DNS server and
// managers. It also creates and starts a DomainInfoManager from
// globalConfig.DomainInfo, and it launches the WebSocket broadcast goroutine
// and the hourly session-cleanup goroutine. Nothing in this constructor stops
// those goroutines.
func NewServer(globalConfig *config.Config, dnsServer *dns.Server, shieldManager *shield.ShieldManager, gfwManager *gfw.GFWListManager) *Server {
	domainInfoManager := shield.NewDomainInfoManager(&globalConfig.DomainInfo)

	server := &Server{
		globalConfig:      globalConfig,
		config:            &globalConfig.HTTP,
		dnsServer:         dnsServer,
		shieldManager:     shieldManager,
		gfwManager:        gfwManager,
		domainInfoManager: domainInfoManager,
		upgrader: websocket.Upgrader{
			ReadBufferSize:  1024,
			WriteBufferSize: 1024,
			// CheckOrigin is deliberately left nil so gorilla/websocket applies
			// its same-origin default: an Origin header, when present, must match
			// the request Host. /ws/stats is authenticated only by the session
			// cookie. Accepting every origin would let any third-party page open
			// the socket with the user's cookie and read the stream (cross-site
			// WebSocket hijacking). Clients that send no Origin header (non-browser
			// tools) are still accepted.
		},
		clients:       make(map[*websocket.Conn]bool),
		broadcastChan: make(chan []byte, 100),
		// Sessions live for 24 hours, renewed on each authenticated request.
		sessions:   make(map[string]time.Time),
		sessionTTL: 24 * time.Hour,
		// Query cache: at most 100 entries, 5 s TTL. Stats cache: at most 10 entries, 2 s TTL.
		queryCache:   NewQueryCache(100, 5*time.Second),
		statsCache:   NewStatsCache(10, 2*time.Second),
		cacheEnabled: true,
		cacheTTL:     5 * time.Second,
		cacheMaxSize: 100,
	}

	domainInfoManager.Start()
	go server.startBroadcastLoop()
	go server.cleanupSessionsLoop()
	return server
}
// Start registers every route and then blocks serving HTTP on
// config.Host:config.Port. It returns the error from ListenAndServe, which is
// http.ErrServerClosed once Stop has closed the server.
//
// Route layout:
//   - /login and /login.html are public; /login redirects to /login.html.
//   - When config.EnableAPI is set, /api/login and /api/logout are public.
//     Every other /api/... route, plus /ws/stats, requires a session.
//     Unmatched /api/... paths fall through to files under ./static/api.
//   - /tracker/... serves ./tracker, and every other path serves ./static.
//     Both require a session and are sent with browser caching disabled.
func (s *Server) Start() error {
	mux := http.NewServeMux()

	// disableBrowserCache sets headers that stop browsers from caching the
	// response, so UI changes are picked up immediately.
	disableBrowserCache := func(w http.ResponseWriter) {
		header := w.Header()
		header.Set("Cache-Control", "no-cache, no-store, must-revalidate")
		header.Set("Pragma", "no-cache")
		header.Set("Expires", "Thu, 01 Jan 1970 00:00:00 GMT")
	}

	// /login is public and only bounces the browser to the static login page.
	mux.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
		http.Redirect(w, r, "/login.html", http.StatusFound)
	})

	if s.config.EnableAPI {
		// Endpoints reachable without a session.
		mux.HandleFunc("/api/login", s.handleLogin)
		mux.HandleFunc("/api/logout", s.handleLogout)

		// getOnlyJSON serves fetch()'s result as JSON for GET and answers any
		// other method with 405.
		getOnlyJSON := func(fetch func() interface{}) http.HandlerFunc {
			return func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				if r.Method != http.MethodGet {
					http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
					return
				}
				json.NewEncoder(w).Encode(fetch())
			}
		}

		// Static API documentation files live under ./static/api.
		apiFiles := http.StripPrefix("/api", http.FileServer(http.Dir("./static/api")))

		// Session-protected endpoints. ServeMux picks the most specific pattern
		// regardless of registration order. The catch-all "/api/" file server is
		// listed last only for readability.
		protected := []struct {
			pattern string
			handler http.HandlerFunc
		}{
			{"/api/change-password", s.handleChangePassword},
			{"/api", func(w http.ResponseWriter, r *http.Request) {
				http.Redirect(w, r, "/api/index.html", http.StatusMovedPermanently)
			}},
			{"/api/stats", s.handleStats},
			{"/api/shield", s.handleShield},
			{"/api/shield/localrules", getOnlyJSON(func() interface{} { return s.shieldManager.GetLocalRules() })},
			{"/api/shield/remoterules", getOnlyJSON(func() interface{} { return s.shieldManager.GetRemoteRules() })},
			{"/api/shield/hosts", s.handleShieldHosts},
			{"/api/shield/blacklists", s.handleShieldBlacklists},
			// Legacy query endpoint (kept for backward compatibility) and the RESTful one.
			{"/api/query", s.handleQuery},
			{"/api/domains/", s.handleDomainQuery},
			{"/api/status", s.handleStatus},
			{"/api/config", s.handleConfig},
			{"/api/config/restart", s.handleRestart},
			// Statistics.
			{"/api/top-blocked", s.handleTopBlockedDomains},
			{"/api/top-resolved", s.handleTopResolvedDomains},
			{"/api/top-clients", s.handleTopClients},
			{"/api/top-domains", s.handleTopDomains},
			{"/api/recent-blocked", s.handleRecentBlockedDomains},
			{"/api/hourly-stats", s.handleHourlyStats},
			{"/api/daily-stats", s.handleDailyStats},
			{"/api/monthly-stats", s.handleMonthlyStats},
			{"/api/query/type", s.handleQueryTypeStats},
			// Log statistics and archive management.
			{"/api/logs/stats", s.handleLogsStats},
			{"/api/logs/query", s.handleLogsQuery},
			{"/api/logs/count", s.handleLogsCount},
			{"/api/logs/archives", s.handleArchiveList},
			{"/api/logs/archive-cleanup", s.handleArchiveCleanup},
			// Domain information lists.
			{"/api/domain-info", s.handleDomainInfoList},
			{"/api/domain-info/update", s.handleDomainInfoUpdate},
			{"/api/domain-info/update/{type}", s.handleDomainInfoUpdateByType},
			{"/api/domain-info/status", s.handleDomainInfoStatus},
			{"/api/domain-info/refresh", s.handleDomainInfoRefresh},
			{"/api/domain-info/add", s.handleDomainInfoAdd},
			{"/api/domain-info/remove", s.handleDomainInfoRemove},
			// Threat lookup, alerts and threat-domain management.
			{"/api/threat", s.handleThreatQuery},
			{"/api/threat/batch", s.handleThreatBatch},
			{"/api/alert", s.handleAlert},
			{"/api/alert/resolve", s.handleAlertResolve},
			{"/api/threat/domain", s.handleThreatDomain},
			// Live statistics feed.
			{"/ws/stats", s.handleWebSocketStats},
			// Fallback: static files under ./static/api.
			{"/api/", apiFiles.ServeHTTP},
		}
		for _, route := range protected {
			mux.HandleFunc(route.pattern, s.loginRequired(route.handler))
		}
	}

	// The login page itself must stay reachable without a session.
	mux.HandleFunc("/login.html", func(w http.ResponseWriter, r *http.Request) {
		disableBrowserCache(w)
		http.ServeFile(w, r, "./static/login.html")
	})

	trackerFiles := http.StripPrefix("/tracker", http.FileServer(http.Dir("./tracker")))
	mux.HandleFunc("/tracker/", s.loginRequired(func(w http.ResponseWriter, r *http.Request) {
		disableBrowserCache(w)
		trackerFiles.ServeHTTP(w, r)
	}))

	// Everything else is a session-protected static file from ./static.
	staticFiles := http.StripPrefix("/", http.FileServer(http.Dir("./static")))
	mux.HandleFunc("/", s.loginRequired(func(w http.ResponseWriter, r *http.Request) {
		disableBrowserCache(w)
		staticFiles.ServeHTTP(w, r)
	}))

	s.server = &http.Server{
		Addr:         fmt.Sprintf("%s:%d", s.config.Host, s.config.Port),
		Handler:      mux,
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}

	logger.Info(fmt.Sprintf("HTTP控制台服务器启动,监听地址: %s:%d", s.config.Host, s.config.Port))
	return s.server.ListenAndServe()
}
// Stop immediately closes the HTTP server's listeners and connections (via
// http.Server.Close, not a graceful Shutdown) and logs that the console has
// stopped. It is safe to call before Start has created the server. A failure
// to close is logged rather than silently dropped.
func (s *Server) Stop() {
	if s.server != nil {
		if err := s.server.Close(); err != nil {
			logger.Error(fmt.Sprintf("关闭HTTP控制台服务器失败: %v", err))
		}
	}
	logger.Info("HTTP控制台服务器已停止")
}
// handleStats serves GET /api/stats with a JSON snapshot of DNS and Shield
// statistics. The payload is exactly what buildStatsData produces (the same
// data the WebSocket feed pushes), plus a "time" field holding the server's
// current time.
// @Summary 获取系统统计信息
// @Description 获取DNS服务器和Shield的统计信息
// @Tags stats
// @Accept json
// @Produce json
// @Success 200 {object} map[string]interface{} "统计信息"
// @Failure 500 {object} map[string]string "服务器内部错误"
// @Router /api/stats [get]
func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Reuse the shared builder instead of duplicating it, so the REST and
	// WebSocket payloads cannot drift apart.
	stats := s.buildStatsData()
	stats["time"] = time.Now()

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(stats)
}
// WebSocket methods.

// handleWebSocketStats upgrades the request to a WebSocket and streams
// statistics. It sends an "initial_data" message, then checks every 500 ms
// and sends a "stats_update" message whenever areStatsEqual reports a
// change. It returns when the peer disconnects, a write fails, or the
// request context ends.
//
// Concurrency notes:
//   - The connection is registered in s.clients only after the initial
//     send, so startBroadcastLoop cannot write to it at the same time.
//   - Later writes hold s.clientsMutex, the same lock startBroadcastLoop holds
//     while writing, because gorilla/websocket allows only one writer per
//     connection at a time.
//   - A reader goroutine drains incoming frames. This processes close and ping
//     control messages and detects disconnects; r.Context() is not reliably
//     cancelled for a hijacked connection.
func (s *Server) handleWebSocketStats(w http.ResponseWriter, r *http.Request) {
	const writeWait = 10 * time.Second

	conn, err := s.upgrader.Upgrade(w, r, nil)
	if err != nil {
		logger.Error(fmt.Sprintf("WebSocket升级失败: %v", err))
		return
	}
	defer conn.Close()

	// Send the initial snapshot while no other goroutine can see this conn.
	_ = conn.SetWriteDeadline(time.Now().Add(writeWait)) // a failure here just makes the write below fail
	if err := s.sendInitialStats(conn); err != nil {
		logger.Error(fmt.Sprintf("发送初始数据失败: %v", err))
		return
	}

	s.clientsMutex.Lock()
	s.clients[conn] = true
	clientCount := len(s.clients)
	s.clientsMutex.Unlock()
	logger.Info(fmt.Sprintf("新WebSocket客户端连接,当前连接数: %d", clientCount))

	// Always unregister, whatever the reason for leaving the loop. Without
	// this, failed connections would stay in s.clients forever.
	defer func() {
		s.clientsMutex.Lock()
		delete(s.clients, conn)
		remaining := len(s.clients)
		s.clientsMutex.Unlock()
		logger.Info(fmt.Sprintf("WebSocket客户端断开连接,当前连接数: %d", remaining))
	}()

	// Drain incoming frames. ReadMessage fails once the peer goes away or the
	// deferred conn.Close runs, which ends this goroutine.
	peerGone := make(chan struct{})
	go func() {
		defer close(peerGone)
		for {
			if _, _, err := conn.ReadMessage(); err != nil {
				return
			}
		}
	}()

	ticker := time.NewTicker(500 * time.Millisecond) // poll for changes every 500 ms
	defer ticker.Stop()

	// The last snapshot actually sent. It is nil until the first update, so
	// the first tick always sends.
	var lastStats map[string]interface{}

	for {
		select {
		case <-ticker.C:
			currentStats := s.buildStatsData()
			if s.areStatsEqual(lastStats, currentStats) {
				continue
			}
			data, err := json.Marshal(map[string]interface{}{
				"type": "stats_update",
				"data": currentStats,
				"time": time.Now(),
			})
			if err != nil {
				logger.Error(fmt.Sprintf("序列化统计数据失败: %v", err))
				continue
			}
			s.clientsMutex.Lock()
			_ = conn.SetWriteDeadline(time.Now().Add(writeWait)) // a failure here just makes the write below fail
			err = conn.WriteMessage(websocket.TextMessage, data)
			s.clientsMutex.Unlock()
			if err != nil {
				logger.Error(fmt.Sprintf("发送WebSocket消息失败: %v", err))
				return
			}
			lastStats = currentStats
		case <-peerGone:
			return
		case <-r.Context().Done():
			return
		}
	}
}
// sendInitialStats writes one "initial_data" text message containing the
// current buildStatsData snapshot and a timestamp to conn. It returns any
// marshalling or write error.
func (s *Server) sendInitialStats(conn *websocket.Conn) error {
	payload := map[string]interface{}{
		"type": "initial_data",
		"data": s.buildStatsData(),
		"time": time.Now(),
	}
	encoded, err := json.Marshal(payload)
	if err != nil {
		return err
	}
	return conn.WriteMessage(websocket.TextMessage, encoded)
}
// buildStatsData assembles the statistics payload shared by the REST and
// WebSocket endpoints. The payload contains:
//   - "dns": the raw DNS counters.
//   - "shield": the Shield stats.
//   - Derived values: topQueryType, activeIPs, the average response time
//     truncated to two decimals, and the DNSSEC usage percentage.
func (s *Server) buildStatsData() map[string]interface{} {
	dnsStats := s.dnsServer.GetStats()
	shieldStats := s.shieldManager.GetStats()

	// Most frequent query type ("-" when none has a positive count). Ties go to
	// the lexicographically smallest type name, so the value does not flip
	// between calls because of random map iteration order.
	topQueryType := "-"
	maxCount := int64(0)
	for queryType, count := range dnsStats.QueryTypes {
		if count > maxCount || (count > 0 && count == maxCount && queryType < topQueryType) {
			maxCount = count
			topQueryType = queryType
		}
	}

	// Number of distinct source IPs.
	activeIPCount := len(dnsStats.SourceIPs)

	// The average response time is truncated (not rounded) to two decimals.
	formattedResponseTime := float64(int(dnsStats.AvgResponseTime*100)) / 100

	// DNSSEC usage as a percentage of all queries.
	dnssecUsage := float64(0)
	if dnsStats.Queries > 0 {
		dnssecUsage = float64(dnsStats.DNSSECQueries) / float64(dnsStats.Queries) * 100
	}

	return map[string]interface{}{
		"dns": map[string]interface{}{
			"Queries":           dnsStats.Queries,
			"Blocked":           dnsStats.Blocked,
			"Allowed":           dnsStats.Allowed,
			"Errors":            dnsStats.Errors,
			"LastQuery":         dnsStats.LastQuery,
			"AvgResponseTime":   formattedResponseTime,
			"TotalResponseTime": dnsStats.TotalResponseTime,
			"QueryTypes":        dnsStats.QueryTypes,
			"SourceIPs":         dnsStats.SourceIPs,
			"CpuUsage":          dnsStats.CpuUsage,
			"DNSSECQueries":     dnsStats.DNSSECQueries,
			"DNSSECSuccess":     dnsStats.DNSSECSuccess,
			"DNSSECFailed":      dnsStats.DNSSECFailed,
			"DNSSECEnabled":     dnsStats.DNSSECEnabled,
		},
		"shield":          shieldStats,
		"topQueryType":    topQueryType,
		"activeIPs":       activeIPCount,
		"avgResponseTime": formattedResponseTime,
		"cpuUsage":        dnsStats.CpuUsage,
		"dnssecEnabled":   dnsStats.DNSSECEnabled,
		"dnssecQueries":   dnsStats.DNSSECQueries,
		"dnssecSuccess":   dnsStats.DNSSECSuccess,
		"dnssecFailed":    dnsStats.DNSSECFailed,
		"dnssecUsage":     float64(int(dnssecUsage*100)) / 100, // truncated to two decimals
	}
}
// areStatsEqual reports whether two buildStatsData snapshots should be treated
// as unchanged.
//
// Only the main DNS counters are compared: Queries, Blocked, Allowed, Errors
// and AvgResponseTime. This keeps minor fluctuations from triggering pushes.
// A nil snapshot never compares equal. If either side lacks a "dns" map, the
// snapshots are considered equal.
func (s *Server) areStatsEqual(stats1, stats2 map[string]interface{}) bool {
	if stats1 == nil || stats2 == nil {
		return false
	}
	dns1, ok := stats1["dns"].(map[string]interface{})
	if !ok {
		return true
	}
	dns2, ok := stats2["dns"].(map[string]interface{})
	if !ok {
		return true
	}
	for _, key := range [...]string{"Queries", "Blocked", "Allowed", "Errors", "AvgResponseTime"} {
		if dns1[key] != dns2[key] {
			return false
		}
	}
	return true
}
// startBroadcastLoop sends every message received on s.broadcastChan to all
// registered WebSocket clients until the channel is closed. A client whose
// write fails is closed and unregistered.
//
// Writes happen while holding s.clientsMutex, which also serializes them with
// the writes in handleWebSocketStats. Each write therefore gets a deadline, so
// a stalled peer cannot hold the lock (and block every other client) forever.
func (s *Server) startBroadcastLoop() {
	const writeWait = 10 * time.Second

	for message := range s.broadcastChan {
		s.clientsMutex.Lock()
		for client := range s.clients {
			// A failed SetWriteDeadline only means the following write fails too.
			_ = client.SetWriteDeadline(time.Now().Add(writeWait))
			if err := client.WriteMessage(websocket.TextMessage, message); err != nil {
				logger.Error(fmt.Sprintf("广播消息失败: %v", err))
				client.Close()
				delete(s.clients, client) // deleting during range is allowed in Go
			}
		}
		s.clientsMutex.Unlock()
	}
}
// cleanupSessionsLoop removes expired sessions once an hour, forever.
// isAuthenticated also drops expired sessions lazily when they are presented.
// This loop reclaims the ones that are never presented again.
func (s *Server) cleanupSessionsLoop() {
	ticker := time.NewTicker(time.Hour)
	defer ticker.Stop()
	for range ticker.C {
		s.sessionsMutex.Lock()
		now := time.Now()
		for id, expiry := range s.sessions {
			if expiry.Before(now) {
				delete(s.sessions, id)
			}
		}
		s.sessionsMutex.Unlock()
	}
}
// isAuthenticated reports whether the request carries a "session_id" cookie
// naming a known, unexpired session. An expired session is deleted. A valid
// one has its expiry pushed out to now + sessionTTL (a sliding expiration).
func (s *Server) isAuthenticated(r *http.Request) bool {
	cookie, err := r.Cookie("session_id")
	if err != nil {
		return false
	}
	id := cookie.Value

	s.sessionsMutex.Lock()
	defer s.sessionsMutex.Unlock()

	expiry, known := s.sessions[id]
	switch {
	case !known:
		return false
	case time.Now().After(expiry):
		delete(s.sessions, id)
		return false
	}

	s.sessions[id] = time.Now().Add(s.sessionTTL)
	return true
}
// loginRequired wraps next so that it only runs for authenticated requests.
// /login and /api/login always pass through. For other unauthenticated
// requests:
//   - Paths under /api/ or /ws/ get a 401 with a JSON error body.
//   - Every other path is redirected to /login.
func (s *Server) loginRequired(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		isPublic := path == "/login" || path == "/api/login"

		// isAuthenticated is skipped for public paths, so they never renew a session.
		if isPublic || s.isAuthenticated(r) {
			next.ServeHTTP(w, r)
			return
		}

		isAPI := strings.HasPrefix(path, "/api/") || strings.HasPrefix(path, "/ws/")
		if !isAPI {
			http.Redirect(w, r, "/login", http.StatusFound)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusUnauthorized)
		json.NewEncoder(w).Encode(map[string]string{"error": "未授权访问"})
	}
}
// handleTopBlockedDomains serves GET /api/top-blocked. It returns up to 50
// entries from GetTopBlockedDomains as a JSON array of {"domain", "count"}.
// The original note says the underlying data covers the last 30 days; that
// is not verifiable from this handler.
func (s *Server) handleTopBlockedDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	top := s.dnsServer.GetTopBlockedDomains(50)
	payload := make([]map[string]interface{}, 0, len(top))
	for _, entry := range top {
		payload = append(payload, map[string]interface{}{
			"domain": entry.Domain,
			"count":  entry.Count,
		})
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(payload)
}
// handleTopResolvedDomains serves GET /api/top-resolved. It returns up to 10
// entries from GetTopResolvedDomains as a JSON array of {"domain", "count"}.
func (s *Server) handleTopResolvedDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	top := s.dnsServer.GetTopResolvedDomains(10)
	payload := make([]map[string]interface{}, 0, len(top))
	for _, entry := range top {
		payload = append(payload, map[string]interface{}{
			"domain": entry.Domain,
			"count":  entry.Count,
		})
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(payload)
}
// handleRecentBlockedDomains serves GET /api/recent-blocked. It returns up to
// 10 entries from GetRecentBlockedDomains as a JSON array of {"domain", "time"}.
// "time" is the entry's LastSeen Unix timestamp (seconds) formatted as
// HH:MM:SS in the server's local time zone.
func (s *Server) handleRecentBlockedDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	recent := s.dnsServer.GetRecentBlockedDomains(10)
	payload := make([]map[string]interface{}, 0, len(recent))
	for _, entry := range recent {
		seenAt := time.Unix(entry.LastSeen, 0)
		payload = append(payload, map[string]interface{}{
			"domain": entry.Domain,
			"time":   seenAt.Format("15:04:05"),
		})
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(payload)
}
// handleHourlyStats serves GET /api/hourly-stats. It returns the values from
// GetHourlyStats for the 24 hourly slots ending at the current hour, oldest
// first, as {"labels": ["HH:00", ...], "data": [...]}. Hours absent from the
// map report zero.
func (s *Server) handleHourlyStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	const window = 24
	hourly := s.dnsServer.GetHourlyStats()
	now := time.Now()

	labels := make([]string, 0, window)
	data := make([]int64, 0, window)
	for hoursBack := window - 1; hoursBack >= 0; hoursBack-- {
		slot := now.Add(-time.Duration(hoursBack) * time.Hour)
		labels = append(labels, slot.Format("15:00"))
		data = append(data, hourly[slot.Format("2006-01-02-15")])
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]interface{}{
		"labels": labels,
		"data":   data,
	})
}
// handleDailyStats serves GET /api/daily-stats. It returns the values from
// GetDailyStats for the last 7 calendar days including today, oldest first,
// as {"labels": ["MM-DD", ...], "data": [...]}. Days absent from the map
// report zero.
func (s *Server) handleDailyStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	daily := s.dnsServer.GetDailyStats()

	const days = 7
	labels := make([]string, days)
	data := make([]int64, days)
	today := time.Now()
	for idx := range labels {
		day := today.AddDate(0, 0, idx-(days-1))
		labels[idx] = day.Format("01-02")
		data[idx] = daily[day.Format("2006-01-02")]
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]interface{}{
		"labels": labels,
		"data":   data,
	})
}
// handleMonthlyStats serves GET /api/monthly-stats. It returns the values from
// GetDailyStats for the last 30 calendar days including today, oldest first,
// as {"labels": ["MM-DD", ...], "data": [...]}. Days absent from the map
// report zero.
func (s *Server) handleMonthlyStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	daily := s.dnsServer.GetDailyStats()

	const days = 30
	labels := make([]string, days)
	data := make([]int64, days)
	today := time.Now()
	for idx := range labels {
		day := today.AddDate(0, 0, idx-(days-1))
		labels[idx] = day.Format("01-02")
		data[idx] = daily[day.Format("2006-01-02")]
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]interface{}{
		"labels": labels,
		"data":   data,
	})
}
// handleQueryTypeStats serves GET /api/query/type. It returns the per-type
// query counts as a JSON array of {"type", "count"}, sorted by count
// descending. Equal counts are ordered by type name, so repeated requests
// return a stable order instead of one driven by random map iteration.
func (s *Server) handleQueryTypeStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	counts := s.dnsServer.GetStats().QueryTypes

	types := make([]string, 0, len(counts))
	for queryType := range counts {
		types = append(types, queryType)
	}
	sort.Slice(types, func(i, j int) bool {
		ci, cj := counts[types[i]], counts[types[j]]
		if ci != cj {
			return ci > cj
		}
		return types[i] < types[j]
	})

	result := make([]map[string]interface{}, 0, len(types))
	for _, queryType := range types {
		result = append(result, map[string]interface{}{
			"type":  queryType,
			"count": counts[queryType],
		})
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(result)
}
// handleTopClients serves GET /api/top-clients. It returns up to 10 entries
// from GetTopClients as a JSON array of {"ip", "count", "lastSeen"}.
func (s *Server) handleTopClients(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	top := s.dnsServer.GetTopClients(10)
	payload := make([]map[string]interface{}, 0, len(top))
	for _, c := range top {
		payload = append(payload, map[string]interface{}{
			"ip":       c.IP,
			"count":    c.Count,
			"lastSeen": c.LastSeen,
		})
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(payload)
}
// handleTopDomains serves GET /api/top-domains. It merges the top 50 blocked
// and the top 50 resolved domains, summing counts for domains that appear in
// both, and returns at most 50 entries as a JSON array of
// {"domain", "count", "dnssec"}.
//
// Ordering and DNSSEC rules:
//   - Entries are sorted by count descending; equal counts are ordered by
//     domain name. Without the tie-break, random map iteration would change
//     which domains survive the cut at 50 from one request to the next.
//   - When a domain appears in both lists, its DNSSEC flag comes from the
//     resolved list (the later write).
func (s *Server) handleTopDomains(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	const limit = 50

	blockedDomains := s.dnsServer.GetTopBlockedDomains(limit)
	resolvedDomains := s.dnsServer.GetTopResolvedDomains(limit)

	// Merge both lists, summing counts per domain.
	counts := make(map[string]int64)
	dnssec := make(map[string]bool)
	for _, domain := range blockedDomains {
		counts[domain.Domain] += domain.Count
		dnssec[domain.Domain] = domain.DNSSEC
	}
	for _, domain := range resolvedDomains {
		counts[domain.Domain] += domain.Count
		dnssec[domain.Domain] = domain.DNSSEC
	}

	names := make([]string, 0, len(counts))
	for name := range counts {
		names = append(names, name)
	}
	sort.Slice(names, func(i, j int) bool {
		ci, cj := counts[names[i]], counts[names[j]]
		if ci != cj {
			return ci > cj
		}
		return names[i] < names[j]
	})
	if len(names) > limit {
		names = names[:limit]
	}

	domainList := make([]map[string]interface{}, 0, len(names))
	for _, name := range names {
		domainList = append(domainList, map[string]interface{}{
			"domain": name,
			"count":  counts[name],
			"dnssec": dnssec[name],
		})
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(domainList)
}
// handleShield manages Shield rules.
//   - GET returns a configuration/rule-count summary, or the full rule set
//     when called with ?all=true.
//   - POST adds the rule given in {"rule": "..."}.
//   - DELETE removes the rule given in {"rule": "..."}.
//   - PUT reloads all rules.
//
// After any successful rule change (add, remove or reload), the DNS cache is
// cleared so cached answers computed under the old rules stop being served.
// An empty rule is rejected with 400.
// @Summary 管理Shield配置
// @Description 获取或更新Shield的配置信息
// @Tags shield
// @Accept json
// @Produce json
// @Param config body map[string]interface{} false "Shield配置信息"
// @Success 200 {object} map[string]interface{} "配置信息"
// @Failure 400 {object} map[string]string "请求参数错误"
// @Failure 500 {object} map[string]string "服务器内部错误"
// @Router /api/shield [get]
// @Router /api/shield [post]
// @Router /api/shield [delete]
// @Router /api/shield [put]
func (s *Server) handleShield(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	switch r.Method {
	case http.MethodGet:
		// ?all=true returns the complete rule data instead of the summary.
		if r.URL.Query().Get("all") == "true" {
			json.NewEncoder(w).Encode(s.shieldManager.GetRules())
			return
		}
		stats := s.shieldManager.GetStats()
		json.NewEncoder(w).Encode(map[string]interface{}{
			"updateInterval":        s.globalConfig.Shield.UpdateInterval,
			"blockMethod":           s.globalConfig.Shield.BlockMethod,
			"blacklistCount":        len(s.globalConfig.Shield.Blacklists),
			"domainRulesCount":      stats["domainRules"],
			"domainExceptionsCount": stats["domainExceptions"],
			"regexRulesCount":       stats["regexRules"],
			"regexExceptionsCount":  stats["regexExceptions"],
			"hostsRulesCount":       stats["hostsRules"],
		})

	case http.MethodPost, http.MethodDelete:
		var req struct {
			Rule string `json:"rule"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if strings.TrimSpace(req.Rule) == "" {
			http.Error(w, "Rule must not be empty", http.StatusBadRequest)
			return
		}

		var err error
		if r.Method == http.MethodPost {
			err = s.shieldManager.AddRule(req.Rule)
		} else {
			err = s.shieldManager.RemoveRule(req.Rule)
		}
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		// Clear the DNS cache so the rule change takes effect immediately.
		s.dnsServer.DnsCache.Clear()
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	case http.MethodPut:
		if err := s.shieldManager.LoadRules(); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		// A reload can change blocking decisions just like add and remove, so
		// cached answers must be dropped here too.
		s.dnsServer.DnsCache.Clear()
		json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "规则重新加载成功"})

	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// handleShieldBlacklists handles CRUD operations on remote blacklist subscriptions.
//
// Dispatch is by path shape and method:
//   - POST .../blacklists/{name-or-url}/update: stamp LastUpdateTime on one list and reload rules
//   - DELETE .../blacklists/{name-or-url}: remove matching lists and reload rules
//   - GET: return the current blacklist entries
//   - POST: add a list ({"name","url"}) after checking that the URL is reachable
//   - PUT: replace the whole list, including enabled flags
//
// Every mutation is pushed to the shield manager, mirrored into the global
// config and persisted to config.ini. Application-level failures are reported
// as {"status":"error","message":...} with HTTP 200 (kept as-is for existing
// clients); only unsupported methods get a 405 status.
//
// @Summary 管理黑名单
// @Description 处理黑名单的CRUD操作,包括获取列表、添加、更新和删除黑名单
// @Tags shield
// @Accept json
// @Produce json
// @Param name path string false "黑名单名称(用于删除操作)"
// @Param blacklist body map[string]interface{} false "黑名单信息(用于添加/更新操作)"
// @Success 200 {object} map[string]interface{} "操作成功"
// @Failure 400 {object} map[string]string "请求参数错误"
// @Failure 404 {object} map[string]string "黑名单不存在"
// @Failure 500 {object} map[string]string "服务器内部错误"
// @Router /api/shield/blacklists [get]
// @Router /api/shield/blacklists [post]
// @Router /api/shield/blacklists [put]
// @Router /api/shield/blacklists/{name} [delete]
// @Router /api/shield/blacklists/{name}/update [post]
func (s *Server) handleShieldBlacklists(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	// reply writes {"status": status} plus "message" when message is non-empty.
	reply := func(status, message string) {
		body := map[string]string{"status": status}
		if message != "" {
			body["message"] = message
		}
		json.NewEncoder(w).Encode(body)
	}

	// applyBlacklists pushes lists to the shield manager, mirrors them into the
	// global config and saves config.ini. On a save failure it writes the error
	// reply itself and returns false.
	applyBlacklists := func(lists []config.BlacklistEntry) bool {
		s.shieldManager.UpdateBlacklist(lists)
		s.globalConfig.Shield.Blacklists = lists
		if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
			logger.Error("保存配置文件失败", "error", err)
			reply("error", "保存配置失败")
			return false
		}
		return true
	}

	// reloadRules reloads all shield rules so blacklist changes take effect.
	// A failure is logged; the client still receives a success reply because the
	// configuration itself was saved.
	reloadRules := func() {
		if err := s.shieldManager.LoadRules(); err != nil {
			logger.Error("重新加载规则失败", "error", err)
		}
	}

	// Refresh a single blacklist: POST .../blacklists/{name-or-url}/update.
	if strings.Contains(r.URL.Path, "/update") && r.Method == http.MethodPost {
		segments := strings.Split(r.URL.Path, "/")
		var target string
		for i, seg := range segments {
			if seg == "blacklists" && i+1 < len(segments) && segments[i+1] != "update" {
				target = segments[i+1]
				break
			}
		}
		if target == "" {
			reply("error", "黑名单标识不能为空")
			return
		}

		lists := s.shieldManager.GetBlacklists()
		idx := -1
		for i, list := range lists {
			if list.URL == target || list.Name == target {
				idx = i
				break
			}
		}
		if idx == -1 {
			reply("error", "黑名单不存在")
			return
		}

		lists[idx].LastUpdateTime = time.Now().Format(time.RFC3339)
		if !applyBlacklists(lists) {
			return
		}
		// Reload so the latest remote rules are fetched.
		reloadRules()
		reply("success", "")
		return
	}

	// Delete: DELETE /api/shield/blacklists/{name-or-url}.
	parts := strings.Split(r.URL.Path, "/")
	if r.Method == http.MethodDelete && len(parts) > 4 && parts[3] == "blacklists" && parts[4] != "" {
		id := parts[4]
		var kept []config.BlacklistEntry
		for _, list := range s.shieldManager.GetBlacklists() {
			if list.URL != id && list.Name != id {
				kept = append(kept, list)
			}
		}
		if !applyBlacklists(kept) {
			return
		}
		// Reload so rules from the removed list stop applying, as the add/update paths do.
		reloadRules()
		reply("success", "")
		return
	}

	switch r.Method {
	case http.MethodGet:
		json.NewEncoder(w).Encode(s.shieldManager.GetBlacklists())

	case http.MethodPost:
		// Add a remote blacklist.
		var req struct {
			Name string `json:"name"`
			URL  string `json:"url"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			reply("error", "无效的请求体")
			return
		}
		if req.Name == "" || req.URL == "" {
			reply("error", "名称和URL不能为空")
			return
		}

		lists := s.shieldManager.GetBlacklists()
		for _, list := range lists {
			if list.URL == req.URL {
				reply("error", "黑名单URL已存在")
				return
			}
		}
		if !checkURLExists(req.URL) {
			reply("error", "URL不存在或无法访问")
			return
		}

		lists = append(lists, config.BlacklistEntry{
			Name:    req.Name,
			URL:     req.URL,
			Enabled: true,
		})
		if !applyBlacklists(lists) {
			return
		}
		// Reload so the newly added remote rules are fetched.
		reloadRules()
		reply("success", "")

	case http.MethodPut:
		// Replace the whole list (including enabled/disabled state). A single
		// json key per field is enough: encoding/json matches keys
		// case-insensitively, so "name", "enabled", ... still decode.
		var incoming []struct {
			Name           string `json:"Name"`
			URL            string `json:"URL"`
			Enabled        bool   `json:"Enabled"`
			LastUpdateTime string `json:"LastUpdateTime"`
		}
		if err := json.NewDecoder(r.Body).Decode(&incoming); err != nil {
			reply("error", "无效的请求体")
			return
		}

		var lists []config.BlacklistEntry
		for _, entry := range incoming {
			lists = append(lists, config.BlacklistEntry{
				Name:           entry.Name,
				URL:            entry.URL,
				Enabled:        entry.Enabled,
				LastUpdateTime: entry.LastUpdateTime,
			})
		}
		if !applyBlacklists(lists) {
			return
		}
		reloadRules()
		reply("success", "")

	default:
		w.WriteHeader(http.StatusMethodNotAllowed)
		reply("error", "Method not allowed")
	}
}
// handleShieldHosts manages custom hosts entries (domain/IP pairs).
//
//   - GET returns {"hosts": [{"domain","ip"}, ...], "hostsCount": n}, sorted by domain.
//   - POST {"ip","domain"} adds an entry; both fields are required (400 otherwise).
//   - DELETE {"domain"} removes an entry; domain is required (400 otherwise).
//
// Shield-manager failures are returned as 500 with the error text.
func (s *Server) handleShieldHosts(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	switch r.Method {
	case http.MethodPost:
		var req struct {
			IP     string `json:"ip"`
			Domain string `json:"domain"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if req.IP == "" || req.Domain == "" {
			http.Error(w, "IP and Domain are required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.AddHostsEntry(req.IP, req.Domain); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	case http.MethodDelete:
		var req struct {
			Domain string `json:"domain"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		// Validate like POST does instead of passing an empty domain to the manager.
		if req.Domain == "" {
			http.Error(w, "Domain is required", http.StatusBadRequest)
			return
		}
		if err := s.shieldManager.RemoveHostsEntry(req.Domain); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	case http.MethodGet:
		hosts := s.shieldManager.GetAllHosts()
		hostsCount := s.shieldManager.GetHostsCount()

		// Flatten to an array for the frontend.
		hostsList := make([]map[string]string, 0, len(hosts))
		for domain, ip := range hosts {
			hostsList = append(hostsList, map[string]string{
				"domain": domain,
				"ip":     ip,
			})
		}
		// Map iteration order is random; sort so the list is stable between requests.
		sort.Slice(hostsList, func(i, j int) bool {
			return hostsList[i]["domain"] < hostsList[j]["domain"]
		})

		json.NewEncoder(w).Encode(map[string]interface{}{
			"hosts":      hostsList,
			"hostsCount": hostsCount,
		})

	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// handleQuery is the legacy GET endpoint reporting a domain's block status.
//
// The domain comes from the "domain" query parameter. A missing or empty value
// yields 400 with {"error": ...}. On success, the details map from
// CheckDomainBlockDetails is returned with an added "timestamp" entry.
// Kept for backward compatibility; see handleDomainQuery for the path-based form.
func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	target := r.URL.Query().Get("domain")
	w.Header().Set("Content-Type", "application/json")
	if target == "" {
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(map[string]string{"error": "需要提供domain参数"})
		return
	}

	details := s.shieldManager.CheckDomainBlockDetails(target)
	details["timestamp"] = time.Now()
	json.NewEncoder(w).Encode(details)
}
// handleDomainQuery serves GET /api/domains/{domain}, the path-based variant of
// handleQuery.
//
// The domain is the fourth "/"-separated path segment. If it is missing or empty
// the response is 400 with {"error": ...}. Otherwise the body is
// {"domain", "status" (the details' "blocked" value), "timestamp", "details"}.
func (s *Server) handleDomainQuery(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// segments: ["", "api", "domains", "{domain}", ...]
	segments := strings.Split(r.URL.Path, "/")
	var target string
	if len(segments) >= 4 {
		target = segments[3]
	}

	w.Header().Set("Content-Type", "application/json")
	if target == "" {
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(map[string]string{"error": "需要提供domain参数"})
		return
	}

	details := s.shieldManager.CheckDomainBlockDetails(target)
	json.NewEncoder(w).Encode(map[string]interface{}{
		"domain":    target,
		"status":    details["blocked"],
		"timestamp": time.Now(),
		"details":   details,
	})
}
// statusProcessStartTime is captured once, during package initialization at
// program start-up, and is used by handleStatus as the service start time.
var statusProcessStartTime = time.Now()

// handleStatus returns a snapshot of DNS server statistics for GET requests.
//
// "startTime" is statusProcessStartTime and "uptime" is the elapsed time since
// then in milliseconds. Both previously reported "now" and 0 on every request.
// They measure the lifetime of the process, so restarting only the DNS
// listener does not reset them.
func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	stats := s.dnsServer.GetStats()
	uptime := time.Since(statusProcessStartTime)

	status := map[string]interface{}{
		"status":          "running",
		"queries":         stats.Queries,
		"blocked":         stats.Blocked,
		"allowed":         stats.Allowed,
		"errors":          stats.Errors,
		"lastQuery":       stats.LastQuery,
		"avgResponseTime": stats.AvgResponseTime,
		"activeIPs":       len(stats.SourceIPs),
		"startTime":       statusProcessStartTime,
		"uptime":          uptime.Milliseconds(), // milliseconds, convenient for the frontend
		"cpuUsage":        stats.CpuUsage,
		"timestamp":       time.Now(),
	}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(status)
}
// handleConfig serves (GET) and updates (POST) the runtime configuration.
//
// GET re-reads config.ini on every call, falling back to the in-memory config
// if loading fails. It returns the settings grouped as Shield / GFWList /
// DNSServer / HTTPServer.
//
// POST applies a partial update:
//   - Scalar fields: "" or 0 leaves the current value unchanged.
//   - upstreamServers, dnssecUpstreamServers, noDNSSECDomains: an empty list
//     leaves the value unchanged.
//   - enableFastReturn, domainSpecificDNS, blacklists: an absent or null field
//     leaves the value unchanged.
//   - enableIPv6, enableDNSSEC, enableAPI and the whole gfwList section: always
//     overwritten with what the request carries.
//
// The request is validated before anything is modified, so a 400 response leaves
// the configuration untouched. Changes are pushed to the DNS cache and saved to
// config.ini. A save failure is only logged, because the in-memory update has
// already succeeded.
func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	switch r.Method {
	case http.MethodGet:
		// Re-read the file each time so the latest saved configuration is returned.
		cfg, err := config.LoadConfig("config.ini")
		if err != nil {
			logger.Error("加载配置文件失败", "error", err)
			// Fall back to the in-memory configuration.
			cfg = s.globalConfig
		}
		// Key names must match what the frontend expects.
		// NOTE(review): this exposes the HTTP password in clear text. POST treats an
		// empty password as "unchanged", so it could probably be omitted — confirm with the UI.
		resp := map[string]interface{}{
			"Shield": map[string]interface{}{
				"blockMethod":       cfg.Shield.BlockMethod,
				"customBlockIP":     cfg.Shield.CustomBlockIP,
				"blacklists":        cfg.Shield.Blacklists,
				"updateInterval":    cfg.Shield.UpdateInterval,
				"statsSaveInterval": cfg.Shield.StatsSaveInterval,
			},
			"GFWList": map[string]interface{}{
				"ip":      cfg.GFWList.IP,
				"content": cfg.GFWList.Content,
				"enabled": cfg.GFWList.Enabled,
			},
			"DNSServer": map[string]interface{}{
				"port":                  cfg.DNS.Port,
				"QueryMode":             cfg.DNS.QueryMode,
				"UpstreamServers":       cfg.DNS.UpstreamDNS,
				"DNSSECUpstreamServers": cfg.DNS.DNSSECUpstreamDNS,
				"saveInterval":          cfg.DNS.SaveInterval,
				"queryTimeout":          cfg.DNS.QueryTimeout,
				"enableIPv6":            cfg.DNS.EnableIPv6,
				"enableDNSSEC":          cfg.DNS.EnableDNSSEC,
				"enableFastReturn":      cfg.DNS.EnableFastReturn,
				"noDNSSECDomains":       cfg.DNS.NoDNSSECDomains,
				"CacheMode":             cfg.DNS.CacheMode,
				"CacheSize":             cfg.DNS.CacheSize,
				"MaxCacheTTL":           cfg.DNS.MaxCacheTTL,
				"MinCacheTTL":           cfg.DNS.MinCacheTTL,
				"domainSpecificDNS":     cfg.DNS.DomainSpecificDNS,
			},
			"HTTPServer": map[string]interface{}{
				"port":      cfg.HTTP.Port,
				"host":      cfg.HTTP.Host,
				"enableAPI": cfg.HTTP.EnableAPI,
				"username":  cfg.HTTP.Username,
				"password":  cfg.HTTP.Password,
			},
		}
		json.NewEncoder(w).Encode(resp)

	case http.MethodPost:
		var req struct {
			DNSServer struct {
				Port                  int      `json:"port"`
				QueryMode             string   `json:"queryMode"`
				UpstreamServers       []string `json:"upstreamServers"`
				DnssecUpstreamServers []string `json:"dnssecUpstreamServers"`
				Timeout               int      `json:"timeout"`
				// QueryTimeout accepts the key name the GET response uses; "timeout" wins if both are set.
				QueryTimeout      int                 `json:"queryTimeout"`
				SaveInterval      int                 `json:"saveInterval"`
				EnableIPv6        bool                `json:"enableIPv6"`
				EnableDNSSEC      bool                `json:"enableDNSSEC"`
				EnableFastReturn  *bool               `json:"enableFastReturn"`
				NoDNSSECDomains   []string            `json:"noDNSSECDomains"`
				CacheMode         string              `json:"cacheMode"`
				CacheSize         int                 `json:"cacheSize"`
				MaxCacheTTL       int                 `json:"maxCacheTTL"`
				MinCacheTTL       int                 `json:"minCacheTTL"`
				DomainSpecificDNS map[string][]string `json:"domainSpecificDNS"`
			} `json:"dnsserver"`
			HTTPServer struct {
				Port      int    `json:"port"`
				Host      string `json:"host"`
				EnableAPI bool   `json:"enableAPI"`
				Username  string `json:"username"`
				Password  string `json:"password"`
			} `json:"httpserver"`
			Shield struct {
				BlockMethod       string                  `json:"blockMethod"`
				CustomBlockIP     string                  `json:"customBlockIP"`
				Blacklists        []config.BlacklistEntry `json:"blacklists"`
				UpdateInterval    int                     `json:"updateInterval"`
				StatsSaveInterval int                     `json:"statsSaveInterval"`
			} `json:"shield"`
			GFWList struct {
				IP      string `json:"ip"`
				Content string `json:"content"`
				Enabled bool   `json:"enabled"`
			} `json:"gfwList"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "无效的请求体", http.StatusBadRequest)
			return
		}

		// --- Validation: reject bad input before touching s.globalConfig. ---
		if req.Shield.BlockMethod != "" {
			validMethods := map[string]bool{
				"NXDOMAIN": true,
				"refused":  true,
				"emptyIP":  true,
				"customIP": true,
			}
			if !validMethods[req.Shield.BlockMethod] {
				http.Error(w, "无效的屏蔽方法", http.StatusBadRequest)
				return
			}
			// customIP requires a usable IPv4 address.
			if req.Shield.BlockMethod == "customIP" {
				if req.Shield.CustomBlockIP == "" {
					http.Error(w, "自定义IP不能为空", http.StatusBadRequest)
					return
				}
				if !isValidIP(req.Shield.CustomBlockIP) {
					http.Error(w, "无效的IP地址", http.StatusBadRequest)
					return
				}
			}
		}
		for i, bl := range req.Shield.Blacklists {
			if bl.URL == "" {
				http.Error(w, fmt.Sprintf("黑名单URL不能为空,索引: %d", i), http.StatusBadRequest)
				return
			}
			if !strings.HasPrefix(bl.URL, "http://") && !strings.HasPrefix(bl.URL, "https://") {
				http.Error(w, fmt.Sprintf("黑名单URL必须以http://或https://开头,索引: %d", i), http.StatusBadRequest)
				return
			}
		}

		// --- DNS settings ---
		if req.DNSServer.Port > 0 {
			s.globalConfig.DNS.Port = req.DNSServer.Port
		}
		if len(req.DNSServer.UpstreamServers) > 0 {
			s.globalConfig.DNS.UpstreamDNS = req.DNSServer.UpstreamServers
		}
		if len(req.DNSServer.DnssecUpstreamServers) > 0 {
			s.globalConfig.DNS.DNSSECUpstreamDNS = req.DNSServer.DnssecUpstreamServers
		}
		if req.DNSServer.SaveInterval > 0 {
			s.globalConfig.DNS.SaveInterval = req.DNSServer.SaveInterval
		}
		timeout := req.DNSServer.Timeout
		if timeout <= 0 {
			timeout = req.DNSServer.QueryTimeout
		}
		if timeout > 0 {
			s.globalConfig.DNS.QueryTimeout = timeout
		}
		s.globalConfig.DNS.EnableIPv6 = req.DNSServer.EnableIPv6
		s.globalConfig.DNS.EnableDNSSEC = req.DNSServer.EnableDNSSEC
		if req.DNSServer.QueryMode != "" {
			s.globalConfig.DNS.QueryMode = req.DNSServer.QueryMode
		}
		if req.DNSServer.CacheMode != "" {
			s.globalConfig.DNS.CacheMode = req.DNSServer.CacheMode
		}
		if req.DNSServer.CacheSize > 0 {
			s.globalConfig.DNS.CacheSize = req.DNSServer.CacheSize
		}
		if req.DNSServer.MaxCacheTTL > 0 {
			s.globalConfig.DNS.MaxCacheTTL = req.DNSServer.MaxCacheTTL
		}
		if req.DNSServer.MinCacheTTL > 0 {
			s.globalConfig.DNS.MinCacheTTL = req.DNSServer.MinCacheTTL
		}
		if req.DNSServer.EnableFastReturn != nil {
			s.globalConfig.DNS.EnableFastReturn = *req.DNSServer.EnableFastReturn
		}
		if len(req.DNSServer.NoDNSSECDomains) > 0 {
			s.globalConfig.DNS.NoDNSSECDomains = req.DNSServer.NoDNSSECDomains
		}
		if req.DNSServer.DomainSpecificDNS != nil {
			s.globalConfig.DNS.DomainSpecificDNS = req.DNSServer.DomainSpecificDNS
		}

		// --- HTTP settings ---
		if req.HTTPServer.Port > 0 {
			s.globalConfig.HTTP.Port = req.HTTPServer.Port
		}
		if req.HTTPServer.Host != "" {
			s.globalConfig.HTTP.Host = req.HTTPServer.Host
		}
		s.globalConfig.HTTP.EnableAPI = req.HTTPServer.EnableAPI
		if req.HTTPServer.Username != "" {
			s.globalConfig.HTTP.Username = req.HTTPServer.Username
		}
		if req.HTTPServer.Password != "" {
			s.globalConfig.HTTP.Password = req.HTTPServer.Password
		}

		// --- Shield settings (already validated above) ---
		if req.Shield.BlockMethod != "" {
			s.globalConfig.Shield.BlockMethod = req.Shield.BlockMethod
		}
		if req.Shield.CustomBlockIP != "" {
			s.globalConfig.Shield.CustomBlockIP = req.Shield.CustomBlockIP
		}
		if req.Shield.UpdateInterval > 0 {
			s.globalConfig.Shield.UpdateInterval = req.Shield.UpdateInterval
			// Restart auto-update so the new interval takes effect.
			s.shieldManager.StopAutoUpdate()
			s.shieldManager.StartAutoUpdate()
		}
		if req.Shield.StatsSaveInterval > 0 {
			s.globalConfig.Shield.StatsSaveInterval = req.Shield.StatsSaveInterval
		}

		// --- GFWList (always overwritten) ---
		s.globalConfig.GFWList.IP = req.GFWList.IP
		s.globalConfig.GFWList.Content = req.GFWList.Content
		s.globalConfig.GFWList.Enabled = req.GFWList.Enabled
		if s.gfwManager != nil {
			if err := s.gfwManager.LoadRules(); err != nil {
				logger.Error("重新加载GFWList规则失败", "error", err)
			}
		}

		// --- Blacklists (already validated above) ---
		if req.Shield.Blacklists != nil {
			s.globalConfig.Shield.Blacklists = req.Shield.Blacklists
			s.shieldManager.UpdateBlacklist(req.Shield.Blacklists)
			if err := s.shieldManager.LoadRules(); err != nil {
				logger.Error("重新加载规则失败", "error", err)
			}
		}

		// --- Push cache settings into the live DNSCache ---
		// TTLs are configured in seconds; CacheSize is configured in MiB.
		maxCacheTTL := time.Duration(s.globalConfig.DNS.MaxCacheTTL) * time.Second
		minCacheTTL := time.Duration(s.globalConfig.DNS.MinCacheTTL) * time.Second
		maxCacheSize := int64(s.globalConfig.DNS.CacheSize) * 1024 * 1024
		s.dnsServer.DnsCache.SetMaxCacheTTL(maxCacheTTL)
		s.dnsServer.DnsCache.SetMinCacheTTL(minCacheTTL)
		s.dnsServer.DnsCache.SetCacheMode(s.globalConfig.DNS.CacheMode)
		s.dnsServer.DnsCache.SetMaxCacheSize(maxCacheSize)

		// Persist. A failure is only logged because the in-memory update already succeeded.
		if err := saveConfigToFile(s.globalConfig, "config.ini"); err != nil {
			logger.Error("保存配置到文件失败", "error", err)
		}

		json.NewEncoder(w).Encode(map[string]interface{}{
			"success": true,
			"message": "配置已更新",
		})

	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// isValidIP reports whether ip is a dotted-decimal IPv4 address: exactly four
// "."-separated octets, each 1–3 ASCII digits with a value of 0–255.
//
// Octets with leading zeros (e.g. "010") are rejected. They are ambiguous
// (octal vs. decimal), and net.ParseIP has refused them since Go 1.17, so
// accepting them here would let a value through that later fails to parse.
// IPv6 addresses are not accepted.
func isValidIP(ip string) bool {
	parts := strings.Split(ip, ".")
	if len(parts) != 4 {
		return false
	}
	for _, part := range parts {
		if len(part) == 0 || len(part) > 3 {
			return false
		}
		if len(part) > 1 && part[0] == '0' {
			return false
		}
		value := 0
		for i := 0; i < len(part); i++ {
			c := part[i]
			if c < '0' || c > '9' {
				return false
			}
			value = value*10 + int(c-'0')
		}
		if value > 255 {
			return false
		}
	}
	return true
}
// checkURLExists reports whether url answers with a 2xx or 3xx status within 5s.
//
// It probes with HEAD first. Some servers and CDNs do not implement HEAD and
// answer 405 or 501; in that case it retries with GET. Only the status line is
// inspected, and the GET body is closed without being read. Redirects are
// followed by the default client policy. Transport errors (DNS failure, refused
// connection, timeout, malformed URL) return false.
func checkURLExists(url string) bool {
	client := &http.Client{
		Timeout: 5 * time.Second,
	}

	resp, err := client.Head(url)
	if err != nil {
		return false
	}
	resp.Body.Close()

	if resp.StatusCode == http.StatusMethodNotAllowed || resp.StatusCode == http.StatusNotImplemented {
		resp, err = client.Get(url)
		if err != nil {
			return false
		}
		resp.Body.Close()
	}

	return resp.StatusCode >= 200 && resp.StatusCode < 400
}
// handleLogsStats returns query-log statistics for GET requests.
//
// The base map comes from the DNS server's GetQueryStats. When an archive
// manager exists and listing archives succeeds, it is extended with
// archiveCount, archiveTotalRecords, archiveTotalCompressedSize and
// archiveTotalSize (the compressed size). If "totalQueries" holds an int64, it
// also gets grandTotalRecords (main store + archives). Results are served from
// and stored in statsCache under "logs_stats" when caching is enabled.
func (s *Server) handleLogsStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	const statsKey = "logs_stats"
	writeStats := func(payload interface{}) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(payload)
	}

	if s.cacheEnabled {
		if cached, hit := s.statsCache.Get(statsKey); hit {
			writeStats(cached)
			return
		}
	}

	stats := s.dnsServer.GetQueryStats()

	if archiver := s.dnsServer.GetArchiveManager(); archiver != nil {
		if archives, listErr := archiver.GetArchiveList(); listErr == nil {
			var records, compressed int64
			for _, a := range archives {
				records += a.RecordCount
				compressed += a.CompressedSize
			}
			stats["archiveCount"] = len(archives)
			stats["archiveTotalRecords"] = records
			stats["archiveTotalCompressedSize"] = compressed
			stats["archiveTotalSize"] = compressed
			if mainTotal, isInt64 := stats["totalQueries"].(int64); isInt64 {
				stats["grandTotalRecords"] = mainTotal + records
			}
		}
	}

	if s.cacheEnabled {
		s.statsCache.Set(statsKey, stats)
	}
	writeStats(stats)
}
// handleLogsQuery serves one page of query logs for GET requests.
//
// Query parameters:
//   - limit: page size (default 30)
//   - page: 1-based page number (default 1; values below 1 are treated as 1)
//   - offset: overrides the offset derived from page (negative values become 0)
//   - sort, direction, result, search, queryType: ordering and filters
//
// Logs come from the archive query engine when one is available. If it is
// absent or returns an error, the DNS server's own log store is used instead.
// The response is {"logs","total","page","limit","totalPages"}. On page 1 the
// logs are enriched with domain information. Page-1 responses are cached per
// full parameter set (including offset) when caching is enabled.
func (s *Server) handleLogsQuery(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	params := r.URL.Query()
	limit := 30  // default page size
	pageNum := 1 // default page
	sortField := params.Get("sort")
	sortDirection := params.Get("direction")
	resultFilter := params.Get("result")
	searchTerm := params.Get("search")
	queryType := params.Get("queryType")

	if limitStr := params.Get("limit"); limitStr != "" {
		fmt.Sscanf(limitStr, "%d", &limit)
	}
	if pageStr := params.Get("page"); pageStr != "" {
		fmt.Sscanf(pageStr, "%d", &pageNum)
	}
	if pageNum < 1 {
		pageNum = 1
	}
	offset := (pageNum - 1) * limit
	if offsetStr := params.Get("offset"); offsetStr != "" {
		fmt.Sscanf(offsetStr, "%d", &offset)
	}
	if offset < 0 {
		offset = 0
	}

	// The key includes the effective offset so that different pages never share
	// a cache entry. %q quotes each value so inputs containing "_" cannot collide.
	cacheKey := fmt.Sprintf("logs_query_%q_%q_%q_%q_%q_%d_%d",
		sortField, sortDirection, resultFilter, searchTerm, queryType, limit, offset)

	if s.cacheEnabled {
		if cachedLogs, found := s.queryCache.Get(cacheKey); found {
			w.Header().Set("Content-Type", "application/json")
			json.NewEncoder(w).Encode(cachedLogs)
			return
		}
	}

	filter := log.LogFilter{
		Result:     resultFilter,
		SearchTerm: searchTerm,
		QueryType:  queryType,
	}
	pageParams := log.PageParams{
		Limit:         limit,
		Offset:        offset,
		SortField:     sortField,
		SortDirection: sortDirection,
	}

	var logs []log.QueryLog
	var total int64
	fromArchive := false

	// Prefer the archive query engine; fall back to the plain log store on error.
	if archiveQueryEngine := s.dnsServer.GetArchiveQueryEngine(); archiveQueryEngine != nil {
		var err error
		logs, total, err = archiveQueryEngine.QueryLogs(filter, pageParams)
		if err == nil {
			fromArchive = true
		} else {
			logger.Error("归档查询失败,降级到普通查询", "error", err)
		}
	}

	if !fromArchive {
		dnsLogs := s.dnsServer.GetQueryLogs(limit, offset, sortField, sortDirection, resultFilter, searchTerm, queryType)
		logs = make([]log.QueryLog, len(dnsLogs))
		for i, logItem := range dnsLogs {
			// log.QueryLog stores Answers as a JSON string. A marshal failure leaves it empty.
			answersJSON, _ := json.Marshal(logItem.Answers)
			logs[i] = log.QueryLog{
				Timestamp:    logItem.Timestamp,
				ClientIP:     logItem.ClientIP,
				Domain:       logItem.Domain,
				QueryType:    logItem.QueryType,
				ResponseTime: logItem.ResponseTime,
				Result:       logItem.Result,
				BlockRule:    logItem.BlockRule,
				BlockType:    logItem.BlockType,
				FromCache:    logItem.FromCache,
				DNSSEC:       logItem.DNSSEC,
				EDNS:         logItem.EDNS,
				DNSServer:    logItem.DNSServer,
				DNSSECServer: logItem.DNSSECServer,
				Answers:      string(answersJSON),
				ResponseCode: logItem.ResponseCode,
			}
		}
		total = int64(s.dnsServer.GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType))
	}

	totalPages := int64(0)
	if limit > 0 {
		totalPages = (total + int64(limit) - 1) / int64(limit)
	}

	response := map[string]interface{}{
		"logs":       logs,
		"total":      total,
		"page":       pageNum,
		"limit":      limit,
		"totalPages": totalPages,
	}

	// Enrich page 1 with domain info, whichever source produced the logs.
	// Other pages skip it to limit overhead.
	if pageNum == 1 && len(logs) > 0 {
		response["logs"] = s.enrichLogsWithDomainInfo(logs)
	}

	// Only page 1 is cached, since it is the page users view most.
	if s.cacheEnabled && pageNum == 1 {
		s.queryCache.Set(cacheKey, response)
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(response)
}
// handleLogsCount returns {"count": n}, the number of query logs matching the
// optional "result", "search" and "queryType" filters, for GET requests.
// Counts are cached in queryCache when caching is enabled.
func (s *Server) handleLogsCount(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	resultFilter := r.URL.Query().Get("result")
	searchTerm := r.URL.Query().Get("search")
	queryType := r.URL.Query().Get("queryType")

	// %q quotes each filter so that values containing "_" cannot map two
	// different filter combinations onto the same cache key.
	cacheKey := fmt.Sprintf("logs_count_%q_%q_%q", resultFilter, searchTerm, queryType)

	if s.cacheEnabled {
		if cachedCount, found := s.queryCache.Get(cacheKey); found {
			if count, ok := cachedCount.(int); ok {
				w.Header().Set("Content-Type", "application/json")
				json.NewEncoder(w).Encode(map[string]int{"count": count})
				return
			}
		}
	}

	count := s.dnsServer.GetQueryLogsCountWithFilter(resultFilter, searchTerm, queryType)
	if s.cacheEnabled {
		s.queryCache.Set(cacheKey, count)
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]int{"count": count})
}
// handleArchiveList returns the archive manager's list of log archives as JSON
// for GET requests.
// It responds 404 when archiving is not enabled (no archive manager) and 500
// when listing fails.
func (s *Server) handleArchiveList(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	archiveMgr := s.dnsServer.GetArchiveManager()
	if archiveMgr == nil {
		http.Error(w, "归档功能未启用", http.StatusNotFound)
		return
	}

	list, listErr := archiveMgr.GetArchiveList()
	if listErr != nil {
		logger.Error("获取归档列表失败", "error", listErr)
		http.Error(w, fmt.Sprintf("获取归档列表失败:%v", listErr), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(list)
}
// handleArchiveCleanup runs the archive manager's CleanupOldArchives on POST.
// It responds with {"deleted": n, "message": ...}, or 404 when archiving is not
// enabled and 500 when the cleanup fails.
func (s *Server) handleArchiveCleanup(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	archiveMgr := s.dnsServer.GetArchiveManager()
	if archiveMgr == nil {
		http.Error(w, "归档功能未启用", http.StatusNotFound)
		return
	}

	removed, cleanupErr := archiveMgr.CleanupOldArchives()
	if cleanupErr != nil {
		logger.Error("清理归档失败", "error", cleanupErr)
		http.Error(w, fmt.Sprintf("清理归档失败:%v", cleanupErr), http.StatusInternalServerError)
		return
	}

	summary := fmt.Sprintf("成功清理 %d 个归档文件", removed)
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]interface{}{
		"deleted": removed,
		"message": summary,
	})
}
// handleRestart restarts the DNS service on POST. In order, it:
//   - stops the DNS server,
//   - reloads the shield rules (a failure is logged, not fatal),
//   - starts the DNS server again in a background goroutine,
//   - restarts shield auto-update.
//
// The success response is sent without waiting for the DNS server to come back
// up. A start failure is only logged.
func (s *Server) handleRestart(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	logger.Info("收到重启服务请求")

	s.dnsServer.Stop()

	if reloadErr := s.shieldManager.LoadRules(); reloadErr != nil {
		logger.Error("重新加载屏蔽规则失败", "error", reloadErr)
	}

	// Run Start in the background; it presumably blocks while serving — TODO confirm.
	startDNS := func() {
		if startErr := s.dnsServer.Start(); startErr != nil {
			logger.Error("DNS服务器重启失败", "error", startErr)
		}
	}
	go startDNS()

	s.shieldManager.StopAutoUpdate()
	s.shieldManager.StartAutoUpdate()

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "服务已重启"})
	logger.Info("服务重启成功")
}
// handleLogin authenticates a POSTed {"username","password"} JSON body against
// the configured credentials.
//
// On success it records a server-side session that expires after s.sessionTTL
// and sets it as the HttpOnly "session_id" cookie with the same expiry.
// Responses: 400 for an undecodable body, 401 for bad credentials.
//
// NOTE(review): the session ID is built from the clock and the session count,
// so it is guessable. It should come from crypto/rand. The credential check is
// also not constant-time (see crypto/subtle.ConstantTimeCompare).
func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var loginData struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}
	if err := json.NewDecoder(r.Body).Decode(&loginData); err != nil {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(map[string]string{"error": "无效的请求体"})
		return
	}

	if loginData.Username != s.config.Username || loginData.Password != s.config.Password {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusUnauthorized)
		json.NewEncoder(w).Encode(map[string]string{"error": "用户名或密码错误"})
		return
	}

	// Compute the expiry once so the session and the cookie agree.
	expires := time.Now().Add(s.sessionTTL)

	// Generate and store the session under one lock: len(s.sessions) must not be
	// read while another request is writing the map.
	s.sessionsMutex.Lock()
	sessionID := fmt.Sprintf("%d_%d", time.Now().UnixNano(), len(s.sessions))
	s.sessions[sessionID] = expires
	s.sessionsMutex.Unlock()

	http.SetCookie(w, &http.Cookie{
		Name:     "session_id",
		Value:    sessionID,
		Path:     "/",
		Expires:  expires,
		HttpOnly: true,
		Secure:   false, // false for development over plain HTTP; production should use true
	})

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "登录成功"})
	logger.Info(fmt.Sprintf("用户 %s 登录成功", loginData.Username))
}
// handleLogout ends the caller's session on POST.
//
// If the request carries a "session_id" cookie, that session is removed from
// s.sessions. The response always overwrites the cookie with an empty,
// already-expired one so the browser discards it.
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	if sessionCookie, cookieErr := r.Cookie("session_id"); cookieErr == nil {
		s.sessionsMutex.Lock()
		delete(s.sessions, sessionCookie.Value)
		s.sessionsMutex.Unlock()
	}

	http.SetCookie(w, &http.Cookie{
		Name:     "session_id",
		Value:    "",
		Path:     "/",
		Expires:  time.Unix(0, 0),
		HttpOnly: true,
		Secure:   false,
	})

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "注销成功"})
	logger.Info("用户注销成功")
}
// handleDomainInfoList serves domain-information lookups for GET requests.
//
// The query parameter picks the mode (first match wins):
//   - "domains": delegated to handleDomainsInfo with the parameter's value as filter.
//   - "trackers": delegated to handleTrackersInfo.
//   - "threats": delegated to handleThreatsInfo.
//   - none of these: the full listing from domainInfoManager.GetAllDomainInfo.
//
// For the full listing, an empty placeholder is returned when there is no
// manager, or when the manager does not answer within 10 seconds. The buffered
// channel lets the worker goroutine finish even after a timeout.
func (s *Server) handleDomainInfoList(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	params := r.URL.Query()
	switch {
	case params.Has("domains"):
		handleDomainsInfo(w, params.Get("domains"))
		return
	case params.Has("trackers"):
		handleTrackersInfo(w, params.Get("trackers"))
		return
	case params.Has("threats"):
		handleThreatsInfo(w, params.Get("threats"))
		return
	}

	// placeholder builds the empty listing payload with the given lastUpdateTime text.
	placeholder := func(lastUpdate string) map[string]interface{} {
		return map[string]interface{}{
			"lists":           []interface{}{},
			"domainInfoCount": 0,
			"threatCount":     0,
			"trackerCount":    0,
			"lastUpdateTime":  lastUpdate,
		}
	}

	w.Header().Set("Content-Type", "application/json")
	if s.domainInfoManager == nil {
		json.NewEncoder(w).Encode(placeholder("从未更新"))
		return
	}

	resultCh := make(chan map[string]interface{}, 1)
	go func() {
		resultCh <- s.domainInfoManager.GetAllDomainInfo()
	}()

	select {
	case info := <-resultCh:
		json.NewEncoder(w).Encode(info)
	case <-time.After(10 * time.Second):
		json.NewEncoder(w).Encode(placeholder("加载中..."))
	}
}
// isService reports whether obj describes a service rather than a group of
// services. An object counts as a service if it has both "name" and "url" keys,
// or if it has a "categoryId" key. Key presence is what matters; values are not
// inspected.
func isService(obj map[string]interface{}) bool {
	has := func(key string) bool {
		_, present := obj[key]
		return present
	}
	return (has("name") && has("url")) || has("categoryId")
}
// handleAlert lists threat alerts for GET requests.
//
// Optional query parameters: "limit" (default 100), "offset" (default 0) and
// "level" (passed through as a filter). Unparseable numbers keep their defaults.
// The response is {"alerts": [...], "total": n}, where total counts all alerts
// matching level.
func (s *Server) handleAlert(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	params := r.URL.Query()
	limit, offset := 100, 0
	level := params.Get("level")
	if raw := params.Get("limit"); raw != "" {
		fmt.Sscanf(raw, "%d", &limit)
	}
	if raw := params.Get("offset"); raw != "" {
		fmt.Sscanf(raw, "%d", &offset)
	}

	json.NewEncoder(w).Encode(map[string]interface{}{
		"alerts": s.dnsServer.GetAlerts(limit, offset, level),
		"total":  s.dnsServer.GetAlertCount(level),
	})
}
// handleAlertResolve resolves a threat alert on POST.
//
// The body is {"alertId": ..., "action": ...}, where action must be
// threat.ActionBlocked or threat.ActionAllowed. Missing fields or an unknown
// action give 400. If the DNS server reports the alert could not be resolved,
// the response is 500.
func (s *Server) handleAlertResolve(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var body struct {
		AlertID string `json:"alertId"`
		Action  string `json:"action"` // blocked, allowed
	}
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	switch {
	case body.AlertID == "" || body.Action == "":
		http.Error(w, "AlertID and Action are required", http.StatusBadRequest)
		return
	case body.Action != threat.ActionBlocked && body.Action != threat.ActionAllowed:
		http.Error(w, "Invalid action", http.StatusBadRequest)
		return
	}

	if !s.dnsServer.ResolveAlert(body.AlertID, body.Action) {
		http.Error(w, "Failed to resolve alert", http.StatusInternalServerError)
		return
	}
	json.NewEncoder(w).Encode(map[string]string{"status": "success"})
}
// handleThreatDomain manages the threat-domain list.
//
//   - GET returns every threat domain.
//   - POST {"type","name","riskLevel","domain"} adds one. domain is required;
//     an empty type or name defaults to "未知" and a riskLevel of 0 defaults to 1.
//   - DELETE {"domain"} removes one. domain is required.
//
// Errors from the DNS server are returned as 500 with the error text.
func (s *Server) handleThreatDomain(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	switch r.Method {
	case http.MethodGet:
		json.NewEncoder(w).Encode(s.dnsServer.GetThreatDomains())

	case http.MethodPost:
		var entry struct {
			Type      string `json:"type"`
			Name      string `json:"name"`
			RiskLevel int    `json:"riskLevel"`
			Domain    string `json:"domain"`
		}
		if err := json.NewDecoder(r.Body).Decode(&entry); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if entry.Domain == "" {
			http.Error(w, "Domain is required", http.StatusBadRequest)
			return
		}

		unknownIfBlank := func(v string) string {
			if v == "" {
				return "未知"
			}
			return v
		}
		entry.Type = unknownIfBlank(entry.Type)
		entry.Name = unknownIfBlank(entry.Name)
		if entry.RiskLevel == 0 {
			entry.RiskLevel = 1
		}

		if addErr := s.dnsServer.AddThreatDomain(entry.Type, entry.Name, entry.RiskLevel, entry.Domain); addErr != nil {
			http.Error(w, addErr.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	case http.MethodDelete:
		var target struct {
			Domain string `json:"domain"`
		}
		if err := json.NewDecoder(r.Body).Decode(&target); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		if target.Domain == "" {
			http.Error(w, "Domain is required", http.StatusBadRequest)
			return
		}
		if removeErr := s.dnsServer.RemoveThreatDomain(target.Domain); removeErr != nil {
			http.Error(w, removeErr.Error(), http.StatusInternalServerError)
			return
		}
		json.NewEncoder(w).Encode(map[string]string{"status": "success"})

	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// processServiceItem walks one entry of the domain-info tree and appends
// matching services to *result.
//
// Non-map entries and the "company" key are ignored. A group (anything
// isService rejects) is descended into recursively, passing companyLevelCompany
// along. For a service, a non-empty domainFilter must either equal serviceName
// or be a substring of its "url". The url may be a string or a map of strings.
// Each appended item has "icon", "name" and "company" (the service's own string
// "company" if present, else companyLevelCompany). It also gets "category" when
// a numeric "categoryId" resolves through categories.
func processServiceItem(
	serviceName string,
	service interface{},
	companyLevelCompany string,
	domainFilter string,
	categories map[string]string,
	result *[]map[string]interface{},
) {
	node, isMap := service.(map[string]interface{})
	if !isMap || serviceName == "company" {
		return
	}

	if !isService(node) {
		// A group: recurse into each child with the inherited company name.
		for childName, child := range node {
			processServiceItem(childName, child, companyLevelCompany, domainFilter, categories, result)
		}
		return
	}

	// urlMatches reports whether the service's url (string or map of strings)
	// contains domainFilter.
	urlMatches := func() bool {
		switch u := node["url"].(type) {
		case string:
			return strings.Contains(u, domainFilter)
		case map[string]interface{}:
			for _, candidate := range u {
				if text, isStr := candidate.(string); isStr && strings.Contains(text, domainFilter) {
					return true
				}
			}
		}
		return false
	}
	if domainFilter != "" && serviceName != domainFilter && !urlMatches() {
		return
	}

	company := companyLevelCompany
	if own, isStr := node["company"].(string); isStr {
		company = own
	}

	item := map[string]interface{}{
		"icon":    node["icon"],
		"name":    node["name"],
		"company": company,
	}
	if id, isNum := node["categoryId"].(float64); isNum {
		if category, known := categories[fmt.Sprintf("%.0f", id)]; known {
			item["category"] = category
		}
	}

	*result = append(*result, item)
}
// handleDomainsInfo writes, as a JSON array, the services from the domain-info
// database that match domainFilter. Matching and the per-item shape
// ({"icon", "name", "company"[, "category"]}) are delegated to
// processServiceItem.
//
// An empty domainFilter short-circuits to an empty array without reading the
// database. The response is always a JSON array: when nothing matches it is
// [] rather than null. Item order follows Go map iteration, so it is not
// stable between calls.
//
// Read or parse failures of the database file produce a 500 and are logged.
// NOTE(review): the file is re-read and re-parsed on every request; consider
// caching it if this endpoint is hot.
func handleDomainsInfo(w http.ResponseWriter, domainFilter string) {
	if domainFilter == "" {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode([]map[string]interface{}{})
		return
	}

	filePath := "./static/domain-info/domains/domain-info.json"
	data, err := os.ReadFile(filePath)
	if err != nil {
		http.Error(w, "Failed to read domain info file", http.StatusInternalServerError)
		logger.Error(fmt.Sprintf("读取域名信息文件失败: %v", err))
		return
	}

	var domainInfo struct {
		Categories map[string]string                 `json:"categories"`
		Domains    map[string]map[string]interface{} `json:"domains"`
	}
	if err := json.Unmarshal(data, &domainInfo); err != nil {
		http.Error(w, "Failed to parse domain info file", http.StatusInternalServerError)
		logger.Error(fmt.Sprintf("解析域名信息文件失败: %v", err))
		return
	}

	// Non-nil on purpose: a nil slice encodes as JSON null, which would differ
	// from the [] returned by the empty-filter branch above.
	result := make([]map[string]interface{}, 0)
	for _, services := range domainInfo.Domains {
		// Company-level default; individual services may override it.
		companyLevelCompany, _ := services["company"].(string)
		// Walk every service, including those nested inside groups.
		for serviceName, service := range services {
			processServiceItem(serviceName, service, companyLevelCompany, domainFilter, domainInfo.Categories, &result)
		}
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(result)
}
// handleTrackersInfo writes, as a JSON array, the trackers whose domain key,
// "name" or "url" contains trackerFilter. Each item has:
//   - "name" and "url", taken from the tracker;
//   - "company", taken from the tracker's "companyId";
//   - "category", when the numeric "categoryId" resolves in the file's
//     categories table.
//
// An empty trackerFilter short-circuits to an empty array without reading the
// tracker file. The response is always a JSON array: when nothing matches it
// is [] rather than null. Item order follows Go map iteration.
//
// Read or parse failures of the tracker file produce a 500 and are logged.
func handleTrackersInfo(w http.ResponseWriter, trackerFilter string) {
	if trackerFilter == "" {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode([]map[string]interface{}{})
		return
	}

	filePath := "./static/domain-info/tracker/trackers.json"
	data, err := os.ReadFile(filePath)
	if err != nil {
		http.Error(w, "Failed to read trackers file", http.StatusInternalServerError)
		logger.Error(fmt.Sprintf("读取跟踪器文件失败: %v", err))
		return
	}

	var trackersInfo struct {
		Categories map[string]string                 `json:"categories"`
		Trackers   map[string]map[string]interface{} `json:"trackers"`
	}
	if err := json.Unmarshal(data, &trackersInfo); err != nil {
		http.Error(w, "Failed to parse trackers file", http.StatusInternalServerError)
		logger.Error(fmt.Sprintf("解析跟踪器文件失败: %v", err))
		return
	}

	// matches reports whether the tracker's domain key, name or URL contains
	// the filter. Non-string name/url values never match.
	matches := func(trackerDomain string, tracker map[string]interface{}) bool {
		if strings.Contains(trackerDomain, trackerFilter) {
			return true
		}
		if name, ok := tracker["name"].(string); ok && strings.Contains(name, trackerFilter) {
			return true
		}
		if rawURL, ok := tracker["url"].(string); ok && strings.Contains(rawURL, trackerFilter) {
			return true
		}
		return false
	}

	// Non-nil on purpose: a nil slice encodes as JSON null rather than [].
	result := make([]map[string]interface{}, 0)
	for trackerDomain, tracker := range trackersInfo.Trackers {
		if !matches(trackerDomain, tracker) {
			continue
		}
		item := map[string]interface{}{
			"name":    tracker["name"],
			"url":     tracker["url"],
			"company": tracker["companyId"],
		}
		if categoryID, ok := tracker["categoryId"].(float64); ok {
			if category, exists := trackersInfo.Categories[fmt.Sprintf("%.0f", categoryID)]; exists {
				item["category"] = category
			}
		}
		result = append(result, item)
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(result)
}
// handleThreatsInfo writes, as a JSON array of {"type", "name", "level",
// "domain"} objects, the rows of the threat CSV whose domain (column 4), name
// (column 2) or type (column 1) contains threatFilter. Column 3 is the level.
//
// The first CSV row is treated as a header and skipped, as are rows with fewer
// than four fields. An empty threatFilter short-circuits to an empty array
// without reading the file. The response is always a JSON array: when nothing
// matches it is [] rather than null.
//
// Read or parse failures of the CSV produce a 500 and are logged.
func handleThreatsInfo(w http.ResponseWriter, threatFilter string) {
	if threatFilter == "" {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode([]map[string]string{})
		return
	}

	filePath := "./static/domain-info/threats/threats-database.csv"
	data, err := os.ReadFile(filePath)
	if err != nil {
		http.Error(w, "Failed to read threats file", http.StatusInternalServerError)
		logger.Error(fmt.Sprintf("读取威胁域名文件失败: %v", err))
		return
	}

	reader := csv.NewReader(bytes.NewReader(data))
	reader.FieldsPerRecord = -1 // rows may have differing numbers of fields
	records, err := reader.ReadAll()
	if err != nil {
		http.Error(w, "Failed to parse threats file", http.StatusInternalServerError)
		logger.Error(fmt.Sprintf("解析威胁域名文件失败: %v", err))
		return
	}

	// Non-nil on purpose: a nil slice encodes as JSON null rather than [].
	result := make([]map[string]string, 0)
	for i, record := range records {
		// Skip the header row and malformed short rows.
		if i == 0 || len(record) < 4 {
			continue
		}
		matched := strings.Contains(record[3], threatFilter) ||
			strings.Contains(record[1], threatFilter) ||
			strings.Contains(record[0], threatFilter)
		if !matched {
			continue
		}
		result = append(result, map[string]string{
			"type":   record[0],
			"name":   record[1],
			"level":  record[2],
			"domain": record[3],
		})
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(result)
}
// handleChangePassword replaces the configured password.
//
// Only POST is accepted. The JSON body carries "currentPassword", which must
// equal s.config.Password, and "newPassword", which must be non-empty. The new
// password is written to s.config and persisted by calling
// saveConfigToFile(s.globalConfig, "./config.json").
//
// If persisting fails, the previous in-memory password is restored so the
// running server keeps agreeing with what the client is told (the change
// failed), and 500 is returned.
//
// Error responses are JSON {"error": ...} with status:
//   - 400 for a malformed body or an empty new password;
//   - 401 for a wrong current password;
//   - 500 for a save failure.
func (s *Server) handleChangePassword(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	writeJSONError := func(status int, msg string) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	var changePasswordData struct {
		CurrentPassword string `json:"currentPassword"`
		NewPassword     string `json:"newPassword"`
	}
	if err := json.NewDecoder(r.Body).Decode(&changePasswordData); err != nil {
		writeJSONError(http.StatusBadRequest, "无效的请求体")
		return
	}

	// NOTE(review): != is not a constant-time comparison; consider
	// crypto/subtle.ConstantTimeCompare for password checks.
	if changePasswordData.CurrentPassword != s.config.Password {
		writeJSONError(http.StatusUnauthorized, "当前密码错误")
		return
	}

	// An empty password would be stored and persisted as-is; reject it.
	if changePasswordData.NewPassword == "" {
		writeJSONError(http.StatusBadRequest, "新密码不能为空")
		return
	}

	// Presumably s.config is part of s.globalConfig, which is what gets saved
	// (TODO confirm), so the field is updated before persisting.
	previousPassword := s.config.Password
	s.config.Password = changePasswordData.NewPassword
	if err := saveConfigToFile(s.globalConfig, "./config.json"); err != nil {
		// Roll back so the live password matches the failure reported below.
		s.config.Password = previousPassword
		logger.Error("保存配置文件失败", "error", err)
		writeJSONError(http.StatusInternalServerError, "保存密码失败")
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]string{"status": "success", "message": "密码修改成功"})
	logger.Info("密码修改成功")
}
// handleThreatQuery looks up a single domain in the threat database.
//
// The exact domain is tried first, then each parent suffix from most to least
// specific. For example, "a.b.example.com" tries "b.example.com", then
// "example.com", then "com". The first hit wins.
//
// NOTE(review): parent-suffix matching here applies to every threat type,
// whereas handleThreatBatch only adds suffix rules for "钓鱼网站"/"仿冒网站"
// rows — confirm which behavior is intended.
//
// @Summary Query threat information for a domain
// @Description Looks up the domain, then each of its parent suffixes, in the threat database and returns the threat type, name and risk level
// @Tags threat
// @Accept json
// @Produce json
// @Param domain query string true "Domain to look up"
// @Success 200 {object} map[string]interface{} "domain and isThreat; on a hit also data, threatType, threatName, riskLevel, plus matchedDomain for a parent-suffix hit"
// @Failure 400 {object} map[string]string "Missing domain parameter"
// @Router /api/threat [get]
func (s *Server) handleThreatQuery(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	domain := r.URL.Query().Get("domain")
	if domain == "" {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusBadRequest)
		json.NewEncoder(w).Encode(map[string]string{"error": "需要提供 domain 参数"})
		return
	}

	result := map[string]interface{}{
		"domain":   domain,
		"isThreat": false,
	}

	// Candidate i is the domain with its first i labels removed. Candidate 0 is
	// the domain itself, because Split followed by Join round-trips exactly.
	labels := strings.Split(domain, ".")
	for i := range labels {
		candidate := strings.Join(labels[i:], ".")
		threatInfo := s.domainInfoManager.GetThreatInfo(candidate)
		if threatInfo == nil {
			continue
		}
		result["isThreat"] = true
		result["data"] = threatInfo
		result["threatType"] = threatInfo["type"]
		result["threatName"] = threatInfo["name"]
		result["riskLevel"] = threatInfo["riskLevel"]
		if i > 0 {
			// Only parent-suffix hits report which entry matched.
			result["matchedDomain"] = candidate
		}
		break
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(result)
}
// handleThreatBatch checks many domains against the threat CSV in one request.
//
// The request body is {"domains": [...]}. Each domain gets a result object:
//   - on a hit: {"domain", "isThreat": true, "data": "type,name,riskLevel,<queried domain>"};
//   - otherwise: {"domain", "isThreat": false}.
// The results are returned as {"results": [...]} in request order.
//
// Matching rules:
//   - Every CSV row gives an exact-name rule.
//   - Rows typed "钓鱼网站" or "仿冒网站" also give a suffix rule that covers
//     every subdomain of that name.
//   - When several suffix rules apply, the most specific one (the longest
//     suffix) wins.
//
// NOTE(review): the CSV is re-read and re-parsed on every request; consider
// caching it if this endpoint is hot.
//
// @Summary Batch threat-domain lookup
// @Description Checks whether each of several domains is a threat domain
// @Tags threat
// @Accept json
// @Produce json
// @Param domains body []string true "Domain list, sent as the domains field of a JSON object"
// @Success 200 {object} map[string]interface{} "Batch lookup results"
// @Router /api/threat/batch [post]
func (s *Server) handleThreatBatch(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	writeJSONError := func(status int, msg string) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	var req struct {
		Domains []string `json:"domains"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeJSONError(http.StatusBadRequest, "请求格式错误")
		return
	}

	filePath := "./static/domain-info/threats/threats-database.csv"
	data, err := os.ReadFile(filePath)
	if err != nil {
		writeJSONError(http.StatusInternalServerError, "读取威胁数据库失败")
		logger.Error(fmt.Sprintf("读取威胁数据库文件失败:%v", err))
		return
	}

	reader := csv.NewReader(bytes.NewReader(data))
	reader.FieldsPerRecord = -1 // rows may have differing numbers of fields
	records, err := reader.ReadAll()
	if err != nil {
		writeJSONError(http.StatusInternalServerError, "解析威胁数据库失败")
		logger.Error(fmt.Sprintf("解析威胁数据库文件失败:%v", err))
		return
	}

	// threatMap maps each rule key to {type, name, riskLevel}. Keys come in
	// two kinds:
	//   - "example.com": exact-name rule, added for every row. Later rows
	//     overwrite earlier ones.
	//   - ".example.com": suffix rule covering every subdomain, added only for
	//     phishing/counterfeit rows. The first such row wins.
	threatMap := make(map[string][]string, len(records))
	for i, record := range records {
		// Skip the header row and malformed short rows.
		if i == 0 || len(record) < 4 {
			continue
		}
		threatType, threatName, riskLevel, domain := record[0], record[1], record[2], record[3]
		// An empty domain would create a "." suffix rule matching every
		// trailing-dot FQDN; skip such rows.
		if domain == "" {
			continue
		}
		threatInfo := []string{threatType, threatName, riskLevel}
		threatMap[domain] = threatInfo
		if threatType == "钓鱼网站" || threatType == "仿冒网站" {
			suffixKey := "." + domain
			if _, exists := threatMap[suffixKey]; !exists {
				threatMap[suffixKey] = threatInfo
			}
		}
	}

	// lookupThreat tries the exact name first. It then looks up each ".suffix"
	// of the name, scanning dots from left to right, so the longest (most
	// specific) suffix rule wins. A suffix always begins at a label boundary,
	// so "notexample.com" can never hit ".example.com". Each name costs one map
	// lookup per label instead of a scan of the whole map.
	lookupThreat := func(name string) ([]string, bool) {
		if info, ok := threatMap[name]; ok {
			return info, true
		}
		for i := 0; i < len(name); i++ {
			if name[i] != '.' {
				continue
			}
			if info, ok := threatMap[name[i:]]; ok {
				return info, true
			}
		}
		return nil, false
	}

	results := make([]map[string]interface{}, 0, len(req.Domains))
	for _, domain := range req.Domains {
		threat, found := lookupThreat(domain)
		if !found {
			results = append(results, map[string]interface{}{
				"domain":   domain,
				"isThreat": false,
			})
			continue
		}
		results = append(results, map[string]interface{}{
			"domain":   domain,
			"isThreat": true,
			// The queried name is reported, not the rule key that matched it.
			"data": fmt.Sprintf("%s,%s,%s,%s", threat[0], threat[1], threat[2], domain),
		})
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]interface{}{
		"results": results,
	})
}